53 changes: 14 additions & 39 deletions pytensor/scan/op.py
@@ -153,39 +153,6 @@
)


-def check_broadcast(v1, v2):
-    """Check that the broadcast patterns of v1 and v2 match.
-
-    Ensures that the broadcast pattern of the variable provided as
-    input to `scan` matches the broadcast pattern provided in
-    `output_info`. It raises an error when they don't match. The
-    typical case is when the user provides either the input or the
-    `output_info` (but not both) with a dimension fixed to 1,
-    which may wrongly be interpreted as broadcastable.
-
-    """
-    if not isinstance(v1.type, TensorType) and not isinstance(v2.type, TensorType):
-        return
-
-    msg = (
-        "The broadcast pattern of the output of scan (%s) is "
-        "inconsistent with the one provided in `output_info` "
-        "(%s). The output on axis %d is `%r`, but it is `%r` on "
-        "axis %d in `output_info`. This can happen if one of the "
-        "dimensions is fixed to 1 in the input, while it is still "
-        "variable in the output, or vice-versa. You have to make "
-        "them consistent, e.g. using pytensor.tensor.specify_broadcastable."
-    )
-    size = min(v1.type.ndim, v2.type.ndim)
-    for n, (b1, b2) in enumerate(
-        zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:], strict=False)
-    ):
-        if b1 != b2:
-            a1 = n + size - v1.type.ndim + 1
-            a2 = n + size - v2.type.ndim + 1
-            raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))


def copy_var_format(var, as_var):
"""
This functions ensures that ``var`` has the same dtype as ``as_var`` as
@@ -712,11 +679,7 @@ def validate_inner_graph(self):
        for inner_iidx, inner_oidx in product(inner_iidxs, inner_oidxs):
            type_input = self.inner_inputs[inner_iidx].type
            type_output = self.inner_outputs[inner_oidx].type
-            if (
-                # TODO: Use the `Type` interface for this
-                type_input.dtype != type_output.dtype
-                or type_input.broadcastable != type_output.broadcastable
-            ):
+            if type_input.dtype != type_output.dtype:
@ricardoV94 (Member) commented on Jan 30, 2026:

    This case is covered by the new try/except below, so we don't need it separately.

The PR author replied:

    Wouldn't removing the dtype check change semantics? The `Type` API would be allowed to perform dtype conversions instead of raising.
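
For context on the question raised above, here is a minimal sketch (not part of this diff; all names hypothetical) of how `Type.filter_variable` handles a variable whose static shape is more specific than the expected type, which is the case this PR starts to allow. Whether it would also insert dtype casts is exactly the semantics question in the thread.

    from pytensor.tensor.type import TensorType

    # Expected type: leading dim unknown; observed type: leading dim fixed to 1.
    generic = TensorType("float64", shape=(None, 10))
    specific = TensorType("float64", shape=(1, 10))

    v = specific()  # a variable of the more specific type

    # The removed check compared `broadcastable` tuples for strict equality,
    # raising on (True, False) vs (False, False). `filter_variable` instead
    # accepts the more specific variable where the generic type is expected,
    # and raises only when the types are genuinely incompatible.
    out = generic.filter_variable(v, allow_convert=True)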

                raise TypeError(
                    "Inconsistency in the inner graph of "
                    f"scan '{self.name}' : an input and an output are "
@@ -725,6 +688,19 @@ def validate_inner_graph(self):
f"type '{type_input}' and '{type_output}' respectively."
)

+            try:
+                type_input.filter_variable(
+                    self.inner_outputs[inner_oidx], allow_convert=True
+                )
+            except Exception as e:
+                raise TypeError(
+                    "Inconsistency in the inner graph of "
+                    f"scan '{self.name}' : an input and an output are "
+                    "associated with the same recurrent state "
+                    "and should have compatible types but have "
+                    f"type '{type_input}' and '{type_output}' respectively."
+                ) from e


class Scan(Op, ScanMethodsMixin, HasInnerGraph):
r"""An `Op` implementing `for` and `while` loops.
@@ -1015,7 +991,6 @@ def make_node(self, *inputs):
        for inner_seq, outer_seq in zip(
            self.inner_seqs(self.inner_inputs), self.outer_seqs(inputs), strict=True
        ):
-            check_broadcast(outer_seq, inner_seq)
            new_inputs.append(copy_var_format(outer_seq, as_var=inner_seq))

        argoffset += len(self.outer_seqs(inputs))
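The error message deleted above pointed users at `pytensor.tensor.specify_broadcastable` to make patterns consistent under the old strict check. A minimal sketch of that workaround (hypothetical shapes; not part of this diff):

    import pytensor.tensor as pt

    x = pt.matrix("x")  # static shape (None, None)

    # Declare that axis 0 has length 1: the resulting type is broadcastable
    # on that axis, consistent with an `output_info` whose first axis is 1.
    x_b = pt.specify_broadcastable(x, 0)
    assert x_b.type.broadcastable == (True, False)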
25 changes: 14 additions & 11 deletions tests/scan/test_basic.py
@@ -2535,19 +2535,22 @@ def f(x, y):
    scan(f, x, outputs_info)


-def test_inconsistent_broadcast_error():
-    x = tensor3()
+def test_relaxed_broadcast_allows_observed_more_specific_and_grad():
    initial_x = pt.constant(np.zeros((1, 10)))
-    y = scan(
-        fn=lambda x, prev_x: x + prev_x,
-        sequences=x,
-        outputs_info=[dict(initial=initial_x)],
-        return_updates=False,
-    )
-    # Error, because the broadcast patterns are inconsistent.
-    with pytest.raises(TypeError):
-        grad(y.sum(), x)
+
+    def get_sum_of_grad(inp):
+        y = scan(
+            fn=lambda x_i, prev_x: x_i + prev_x,
+            sequences=inp,
+            outputs_info=[dict(initial=initial_x)],
+            return_updates=False,
+        )
+        return y.sum()
+
+    floatX = config.floatX
+    rng = np.random.default_rng(utt.fetch_seed())
+    sample = rng.random((2, 1, 10)).astype(floatX)
+    utt.verify_grad(get_sum_of_grad, [sample])

def test_missing_input_error():
    c = shared(0.0)
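End-to-end, the new test exercises exactly the case the old test forbade. A standalone sketch adapted from it (assuming a float64 graph; `return_updates=False` as in the test above):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor import grad, scan

    x = pt.tensor3("x", dtype="float64")
    initial = pt.constant(np.zeros((1, 10)))  # first axis fixed to 1

    y = scan(
        fn=lambda x_i, prev: x_i + prev,
        sequences=x,
        outputs_info=[dict(initial=initial)],
        return_updates=False,
    )

    # Under the removed check_broadcast this raised TypeError, because the
    # (1, 10) initial state is broadcastable on an axis where the inner
    # output is not; with the relaxed check the gradient graph builds.
    g = grad(y.sum(), x)
    f = pytensor.function([x], g)
    print(f(np.zeros((2, 1, 10))).shape)  # -> (2, 1, 10)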