From 637c30753d48a469a852a6ab05fd8f0b1d28f29d Mon Sep 17 00:00:00 2001
From: eclipse1605
Date: Thu, 29 Jan 2026 21:59:21 +0530
Subject: [PATCH 1/3] relax scan broadcastability checks

---
 pytensor/scan/op.py      | 25 ++++++++++++++++++++-----
 tests/scan/test_basic.py | 19 ++++++++++++++++---
 2 files changed, 36 insertions(+), 8 deletions(-)

diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py
index 553c538296..764dd79184 100644
--- a/pytensor/scan/op.py
+++ b/pytensor/scan/op.py
@@ -180,7 +180,7 @@ def check_broadcast(v1, v2):
     for n, (b1, b2) in enumerate(
         zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:], strict=False)
     ):
-        if b1 != b2:
+        if b1 and not b2:
             a1 = n + size - v1.type.ndim + 1
             a2 = n + size - v2.type.ndim + 1
             raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))
@@ -712,11 +712,26 @@ def validate_inner_graph(self):
         for inner_iidx, inner_oidx in product(inner_iidxs, inner_oidxs):
             type_input = self.inner_inputs[inner_iidx].type
             type_output = self.inner_outputs[inner_oidx].type
-            if (
-                # TODO: Use the `Type` interface for this
-                type_input.dtype != type_output.dtype
-                or type_input.broadcastable != type_output.broadcastable
+            if type_input.dtype != type_output.dtype:
+                raise TypeError(
+                    "Inconsistency in the inner graph of "
+                    f"scan '{self.name}' : an input and an output are "
+                    "associated with the same recurrent state "
+                    "and should have compatible types but have "
+                    f"type '{type_input}' and '{type_output}' respectively."
+                )
+            if isinstance(type_input, TensorType) and isinstance(
+                type_output, TensorType
             ):
+                if not type_input.is_super(type_output):
+                    raise TypeError(
+                        "Inconsistency in the inner graph of "
+                        f"scan '{self.name}' : an input and an output are "
+                        "associated with the same recurrent state "
+                        "and should have compatible types but have "
+                        f"type '{type_input}' and '{type_output}' respectively."
+                    )
+            elif type_input != type_output:
                 raise TypeError(
                     "Inconsistency in the inner graph of "
                     f"scan '{self.name}' : an input and an output are "
diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py
index f121dc9e58..5146798c92 100644
--- a/tests/scan/test_basic.py
+++ b/tests/scan/test_basic.py
@@ -37,7 +37,7 @@
 from pytensor.link.vm import VMLinker
 from pytensor.raise_op import assert_op
 from pytensor.scan.basic import scan
-from pytensor.scan.op import Scan, ScanInfo
+from pytensor.scan.op import Scan, ScanInfo, check_broadcast
 from pytensor.scan.utils import until
 from pytensor.tensor import as_tensor
 from pytensor.tensor.math import all as pt_all
@@ -2544,9 +2544,22 @@ def test_inconsistent_broadcast_error():
         outputs_info=[dict(initial=initial_x)],
         return_updates=False,
     )
-    # Error, because the broadcast patterns are inconsistent.
+    # This should now work with relaxed broadcast checks
+    g = grad(y.sum(), x)
+    assert g is not None
+
+
+def test_check_broadcast_allows_more_specific_inner():
+    outer = TensorType(config.floatX, shape=(None,))("outer")
+    inner = TensorType(config.floatX, shape=(None, 1))("inner")
+    check_broadcast(outer, inner)
+
+
+def test_check_broadcast_rejects_more_specific_outer():
+    outer = TensorType(config.floatX, shape=(None, 1))("outer")
+    inner = TensorType(config.floatX, shape=(None,))("inner")
     with pytest.raises(TypeError):
-        grad(y.sum(), x)
+        check_broadcast(outer, inner)
 
 
 def test_missing_input_error():

From 800f3cc21ba484d7c7099bc76333b72c4c57c54a Mon Sep 17 00:00:00 2001
From: eclipse1605
Date: Fri, 30 Jan 2026 17:48:51 +0530
Subject: [PATCH 2/3] validate inner graph types via filter_variable

---
 pytensor/scan/op.py      | 54 +++++++----------------------------------
 tests/scan/test_basic.py | 20 ++++++++-------
 2 files changed, 18 insertions(+), 56 deletions(-)

diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py
index 764dd79184..b56984f578 100644
--- a/pytensor/scan/op.py
+++ b/pytensor/scan/op.py
@@ -153,39 +153,6 @@
 )
 
 
-def check_broadcast(v1, v2):
-    """Checks that the broadcast pattern of v1 and v2.
-
-    Controls that the broadcast pattern of the variable provided as
-    input to `scan` matches the broadcast pattern provided in
-    `output_info`. It raises an error when they don't match. The
-    typical case is when the user provides either the input or the
-    `output_info` (but not both) with a dimension fixed to 1,
-    which may wrongly be interpreted as broadcastable.
-
-    """
-    if not isinstance(v1.type, TensorType) and not isinstance(v2.type, TensorType):
-        return
-
-    msg = (
-        "The broadcast pattern of the output of scan (%s) is "
-        "inconsistent with the one provided in `output_info` "
-        "(%s). The output on axis %d is `%r`, but it is `%r` on "
-        "axis %d in `output_info`. This can happen if one of the "
-        "dimension is fixed to 1 in the input, while it is still "
-        "variable in the output, or vice-verca. You have to make "
-        "them consistent, e.g. using pytensor.tensor.specify_broadcastable."
-    )
-    size = min(v1.type.ndim, v2.type.ndim)
-    for n, (b1, b2) in enumerate(
-        zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:], strict=False)
-    ):
-        if b1 and not b2:
-            a1 = n + size - v1.type.ndim + 1
-            a2 = n + size - v2.type.ndim + 1
-            raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))
-
-
 def copy_var_format(var, as_var):
     """
     This functions ensures that ``var`` has the same dtype as ``as_var`` as
@@ -720,25 +687,19 @@ def validate_inner_graph(self):
                     "and should have compatible types but have "
                     f"type '{type_input}' and '{type_output}' respectively."
                 )
-            if isinstance(type_input, TensorType) and isinstance(
-                type_output, TensorType
-            ):
-                if not type_input.is_super(type_output):
-                    raise TypeError(
-                        "Inconsistency in the inner graph of "
-                        f"scan '{self.name}' : an input and an output are "
-                        "associated with the same recurrent state "
-                        "and should have compatible types but have "
-                        f"type '{type_input}' and '{type_output}' respectively."
-                    )
-            elif type_input != type_output:
+
+            try:
+                type_input.filter_variable(
+                    self.inner_outputs[inner_oidx], allow_convert=True
+                )
+            except Exception as e:
                 raise TypeError(
                     "Inconsistency in the inner graph of "
                     f"scan '{self.name}' : an input and an output are "
                     "associated with the same recurrent state "
                     "and should have compatible types but have "
                     f"type '{type_input}' and '{type_output}' respectively."
-                )
+                ) from e
 
 
 class Scan(Op, ScanMethodsMixin, HasInnerGraph):
@@ -1030,7 +991,6 @@ def make_node(self, *inputs):
         for inner_seq, outer_seq in zip(
             self.inner_seqs(self.inner_inputs), self.outer_seqs(inputs), strict=True
         ):
-            check_broadcast(outer_seq, inner_seq)
             new_inputs.append(copy_var_format(outer_seq, as_var=inner_seq))
 
         argoffset += len(self.outer_seqs(inputs))
diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py
index 5146798c92..4077eaef5b 100644
--- a/tests/scan/test_basic.py
+++ b/tests/scan/test_basic.py
@@ -37,7 +37,7 @@
 from pytensor.link.vm import VMLinker
 from pytensor.raise_op import assert_op
 from pytensor.scan.basic import scan
-from pytensor.scan.op import Scan, ScanInfo, check_broadcast
+from pytensor.scan.op import Scan, ScanInfo
 from pytensor.scan.utils import until
 from pytensor.tensor import as_tensor
 from pytensor.tensor.math import all as pt_all
@@ -2549,17 +2549,19 @@
     assert g is not None
 
 
-def test_check_broadcast_allows_more_specific_inner():
-    outer = TensorType(config.floatX, shape=(None,))("outer")
-    inner = TensorType(config.floatX, shape=(None, 1))("inner")
-    check_broadcast(outer, inner)
+def test_filter_variable_allows_inner_more_specific():
+    outer_t = TensorType(config.floatX, shape=(None,))
+    inner_t = TensorType(config.floatX, shape=(None, 1))
+    inner_var = inner_t()
+    outer_t.filter_variable(inner_var, allow_convert=True)
 
 
-def test_check_broadcast_rejects_more_specific_outer():
-    outer = TensorType(config.floatX, shape=(None, 1))("outer")
-    inner = TensorType(config.floatX, shape=(None,))("inner")
+def test_filter_variable_rejects_incompatible_static_shapes():
+    outer_t = TensorType(config.floatX, shape=(5,))
+    inner_t = TensorType(config.floatX, shape=(2,))
+    inner_var = inner_t()
     with pytest.raises(TypeError):
-        check_broadcast(outer, inner)
+        outer_t.filter_variable(inner_var, allow_convert=True)
 
 
 def test_missing_input_error():

From 59afea5a741f02b0382582010aa432ba6f1324c7 Mon Sep 17 00:00:00 2001
From: eclipse1605
Date: Sat, 31 Jan 2026 15:59:11 +0530
Subject: [PATCH 3/3] fix tests for relaxed broadcast checks

---
 tests/scan/test_basic.py | 37 +++++++++++++------------------------
 1 file changed, 13 insertions(+), 24 deletions(-)

diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py
index 4077eaef5b..59a4022ce3 100644
--- a/tests/scan/test_basic.py
+++ b/tests/scan/test_basic.py
@@ -2535,34 +2535,23 @@ def f(x, y):
         scan(f, x, outputs_info)
 
 
-def test_inconsistent_broadcast_error():
-    x = tensor3()
+def test_relaxed_broadcast_allows_observed_more_specific_and_grad():
     initial_x = pt.constant(np.zeros((1, 10)))
-    y = scan(
-        fn=lambda x, prev_x: x + prev_x,
-        sequences=x,
-        outputs_info=[dict(initial=initial_x)],
-        return_updates=False,
-    )
-    # This should now work with relaxed broadcast checks
-    g = grad(y.sum(), x)
-    assert g is not None
-
-
-def test_filter_variable_allows_inner_more_specific():
-    outer_t = TensorType(config.floatX, shape=(None,))
-    inner_t = TensorType(config.floatX, shape=(None, 1))
-    inner_var = inner_t()
-    outer_t.filter_variable(inner_var, allow_convert=True)
+    def get_sum_of_grad(inp):
+        y = scan(
+            fn=lambda x_i, prev_x: x_i + prev_x,
+            sequences=inp,
+            outputs_info=[dict(initial=initial_x)],
+            return_updates=False,
+        )
+        return y.sum()
 
 
-def test_filter_variable_rejects_incompatible_static_shapes():
-    outer_t = TensorType(config.floatX, shape=(5,))
-    inner_t = TensorType(config.floatX, shape=(2,))
-    inner_var = inner_t()
-    with pytest.raises(TypeError):
-        outer_t.filter_variable(inner_var, allow_convert=True)
+    floatX = config.floatX
+    rng = np.random.default_rng(utt.fetch_seed())
+    sample = rng.random((2, 1, 10)).astype(floatX)
+    utt.verify_grad(get_sum_of_grad, [sample])
 
 
 def test_missing_input_error():
     c = shared(0.0)