diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py
index 553c538296..b56984f578 100644
--- a/pytensor/scan/op.py
+++ b/pytensor/scan/op.py
@@ -153,39 +153,6 @@
 )
 
 
-def check_broadcast(v1, v2):
-    """Checks that the broadcast pattern of v1 and v2.
-
-    Controls that the broadcast pattern of the variable provided as
-    input to `scan` matches the broadcast pattern provided in
-    `output_info`. It raises an error when they don't match. The
-    typical case is when the user provides either the input or the
-    `output_info` (but not both) with a dimension fixed to 1,
-    which may wrongly be interpreted as broadcastable.
-
-    """
-    if not isinstance(v1.type, TensorType) and not isinstance(v2.type, TensorType):
-        return
-
-    msg = (
-        "The broadcast pattern of the output of scan (%s) is "
-        "inconsistent with the one provided in `output_info` "
-        "(%s). The output on axis %d is `%r`, but it is `%r` on "
-        "axis %d in `output_info`. This can happen if one of the "
-        "dimension is fixed to 1 in the input, while it is still "
-        "variable in the output, or vice-verca. You have to make "
-        "them consistent, e.g. using pytensor.tensor.specify_broadcastable."
-    )
-    size = min(v1.type.ndim, v2.type.ndim)
-    for n, (b1, b2) in enumerate(
-        zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:], strict=False)
-    ):
-        if b1 != b2:
-            a1 = n + size - v1.type.ndim + 1
-            a2 = n + size - v2.type.ndim + 1
-            raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))
-
-
 def copy_var_format(var, as_var):
     """
     This functions ensures that ``var`` has the same dtype as ``as_var`` as
@@ -712,11 +679,7 @@ def validate_inner_graph(self):
         for inner_iidx, inner_oidx in product(inner_iidxs, inner_oidxs):
             type_input = self.inner_inputs[inner_iidx].type
             type_output = self.inner_outputs[inner_oidx].type
-            if (
-                # TODO: Use the `Type` interface for this
-                type_input.dtype != type_output.dtype
-                or type_input.broadcastable != type_output.broadcastable
-            ):
+            if type_input.dtype != type_output.dtype:
                 raise TypeError(
                     "Inconsistency in the inner graph of "
                     f"scan '{self.name}' : an input and an output are "
@@ -725,6 +688,19 @@
                     f"type '{type_input}' and '{type_output}' respectively."
                 )
 
+            try:
+                type_input.filter_variable(
+                    self.inner_outputs[inner_oidx], allow_convert=True
+                )
+            except Exception as e:
+                raise TypeError(
+                    "Inconsistency in the inner graph of "
+                    f"scan '{self.name}' : an input and an output are "
+                    "associated with the same recurrent state "
+                    "and should have compatible types but have "
+                    f"type '{type_input}' and '{type_output}' respectively."
+                ) from e
+
 
 class Scan(Op, ScanMethodsMixin, HasInnerGraph):
     r"""An `Op` implementing `for` and `while` loops.
@@ -1015,7 +991,6 @@ def make_node(self, *inputs):
         for inner_seq, outer_seq in zip(
             self.inner_seqs(self.inner_inputs), self.outer_seqs(inputs), strict=True
         ):
-            check_broadcast(outer_seq, inner_seq)
             new_inputs.append(copy_var_format(outer_seq, as_var=inner_seq))
 
         argoffset += len(self.outer_seqs(inputs))
diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py
index f121dc9e58..59a4022ce3 100644
--- a/tests/scan/test_basic.py
+++ b/tests/scan/test_basic.py
@@ -2535,19 +2535,22 @@ def f(x, y):
         scan(f, x, outputs_info)
 
 
-def test_inconsistent_broadcast_error():
-    x = tensor3()
+def test_relaxed_broadcast_allows_observed_more_specific_and_grad():
     initial_x = pt.constant(np.zeros((1, 10)))
-    y = scan(
-        fn=lambda x, prev_x: x + prev_x,
-        sequences=x,
-        outputs_info=[dict(initial=initial_x)],
-        return_updates=False,
-    )
-    # Error, because the broadcast patterns are inconsistent.
-    with pytest.raises(TypeError):
-        grad(y.sum(), x)
+    def get_sum_of_grad(inp):
+        y = scan(
+            fn=lambda x_i, prev_x: x_i + prev_x,
+            sequences=inp,
+            outputs_info=[dict(initial=initial_x)],
+            return_updates=False,
+        )
+        return y.sum()
+
+    floatX = config.floatX
+    rng = np.random.default_rng(utt.fetch_seed())
+    sample = rng.random((2, 1, 10)).astype(floatX)
+    utt.verify_grad(get_sum_of_grad, [sample])
 
 
 def test_missing_input_error():
     c = shared(0.0)
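For context, a minimal sketch (not part of the patch; variable names and shapes are illustrative) of the compatibility rule that the new `Type.filter_variable` check applies in place of the old exact broadcast-pattern comparison: an inner output whose static shape is more specific than the input's type is accepted, while a genuinely incompatible output (e.g. a dtype mismatch) still raises `TypeError`.

import pytensor.tensor as pt

# Recurrent-state input type with no static shape information.
state = pt.matrix("state")  # static shape (None, None)

# An inner output that is *more* specific is now accepted: with
# allow_convert=True, filter_variable returns the variable instead of
# demanding an identical broadcast pattern, as check_broadcast did.
specific = pt.specify_shape(state, (1, None))  # static shape (1, None)
state.type.filter_variable(specific, allow_convert=True)  # OK

# An output with an incompatible dtype still fails; Scan's
# validate_inner_graph catches this and re-raises with its own message.
wrong_dtype = pt.imatrix("wrong_dtype")
try:
    state.type.filter_variable(wrong_dtype, allow_convert=True)
except TypeError:
    pass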