Skip to content
22 changes: 20 additions & 2 deletions pytensor/tensor/rewriting/reshape.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
from pytensor.graph import node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace
from pytensor.tensor.basic import expand_dims
from pytensor.tensor.extra_ops import squeeze
from pytensor.tensor.reshape import JoinDims, SplitDims
from pytensor.tensor.rewriting.basic import register_canonicalize


@register_canonicalize
@node_rewriter([SplitDims])
def local_split_dims_to_reshape(fgraph, node):
def local_split_dims(fgraph, node):
"""
Canonicalize SplitDims Ops to Reshape Ops for further graph reasoning (and dispatch to other backends).
Special case: if shape is (0,), converts to squeeze instead.
"""

x, shape = node.inputs
axis = node.op.axis

# Special case: empty shape -> squeeze
if shape.type.shape == (0,):
squeezed_x = squeeze(x, axis=axis)
copy_stack_trace(x, squeezed_x)
return [squeezed_x]

output_shape = [
*x.shape[:axis],
*shape,
Expand All @@ -28,7 +37,7 @@ def local_split_dims_to_reshape(fgraph, node):

@register_canonicalize
@node_rewriter([JoinDims])
def local_join_dims_to_reshape(fgraph, node):
def local_join_dims(fgraph, node):
"""
Canonicalize JoinDims Ops to Reshape Ops for further graph reasoning (and dispatch to other backends).
"""
Expand All @@ -38,6 +47,15 @@ def local_join_dims_to_reshape(fgraph, node):
start_axis = op.start_axis
n_axes = op.n_axes

if n_axes == 0:
expanded_x = expand_dims(x, axis=node.op.start_axis)
copy_stack_trace(x, expanded_x)
return [expanded_x]

if n_axes == 1:
copy_stack_trace(x, x)
return [x]

output_shape = [
*x.shape[:start_axis],
-1,
Expand Down
59 changes: 59 additions & 0 deletions tests/tensor/rewriting/test_reshape.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from pytensor.graph import FunctionGraph, rewrite_graph
from pytensor.tensor.basic import expand_dims
from pytensor.tensor.extra_ops import squeeze
from pytensor.tensor.reshape import JoinDims, SplitDims, join_dims, split_dims
from pytensor.tensor.shape import Reshape
from pytensor.tensor.type import tensor
from tests.unittest_tools import assert_equal_computations


def test_local_split_dims_to_reshape():
Expand Down Expand Up @@ -32,3 +35,59 @@ def test_local_join_dims_to_reshape():
assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0
assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 1
assert fg.outputs[0].type.shape == (2, 10, 3)


def test_local_join_dims_noop():
    """Test that join_dims with n_axes=1 becomes identity (no-op).

    Rewrites the output variable directly (clone=False keeps the original
    ``x`` as the graph root), so the result can be compared to ``x`` without
    any ``in_xs``/``in_ys`` remapping.  The previous version passed the two
    compared outputs as ``in_xs``/``in_ys``, which declared them equivalent
    by fiat and made the assertion vacuous.
    """
    x = tensor("x", shape=(2, 3, 4))
    x_join = join_dims(x, start_axis=1, n_axes=1)

    # Before rewrite: the graph is a single JoinDims node.
    assert isinstance(x_join.owner.op, JoinDims)

    # clone=False ensures the rewritten graph still refers to the original x.
    rewritten = rewrite_graph(x_join, include=("canonicalize",), clone=False)

    # The canonicalization should collapse the no-op join to the input itself.
    assert_equal_computations([rewritten], [x])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `in_xs`, `in_ys` usage is fishy

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I take out the `in_xs` and `in_ys`, I got this error:

E       AssertionError: equal_computations failed
E       
E       Rewritten:
E       x [id A] <Tensor3(float64, shape=(2, 3, 4))>
E       
E       Expected:
E       x [id A] <Tensor3(float64, shape=(2, 3, 4))>

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to create FunctionGraph, pass the variable directly to rewrite_graph(..., clone=False) (if clone=False is not the default already). That should mean the x before and after are the same variable

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm reading more into the assert_equal_computations and equal_computations to understand what the in_xs and in_ys do

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xs and ys in are for cases where you have the same graph but the root variables (like x) are different. This can happen in this case if x is cloned by the FunctionGraph. But even in that case you should just say xs_in=[x], ys_in=[fg.inputs[0]]. (basically saying the original x is not the one in the fgraph.

But I don't think you need fgraph, or if you do you can pass copy_inputs=False or clone=False so that it uses the original x as well

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clone=False

How do I write the test without the FunctionGraph? Can I call rewrite_graph without it? If so, how does the rewrite work?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah you can call rewrite graph on a variable. Internally it creates an fg for you, but I think it doesn't copy x. Give it a try, I could be wrong.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah you can call rewrite graph on a variable. Internally it creates an fg for you, but I think it doesn't copy x. Give it a try, I could be wrong.

Copy link

@mbaldourw mbaldourw Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this is what you mean

rewrite_graph(x, include=("canonicalize",))

I got the following error:

>       assert_equal_computations([x_join], [x])
E       AssertionError: equal_computations failed
E       
E       Rewritten:
E       JoinDims{start_axis=1, n_axes=1} [id A] <Tensor3(float64, shape=(2, 3, 4))>
E        └─ x [id B] <Tensor3(float64, shape=(2, 3, 4))>
E       
E       Expected:
E       x [id A] <Tensor3(float64, shape=(2, 3, 4))>

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to store the output of rewrite graph, that's the one you want to pass to the comparison function. And you want to pass the output to rewrite_graph not the input.

Something like

out = join_dims(x, ...)
rewritten_out = rewrite_graph(out, ...)
expected_out = ...
assert_equal_computations(rewritten_out, expected_out)



def test_local_join_dims_expand():
    """Test that join_dims with n_axes=0 becomes expand_dims.

    Rewrites the output variable directly (clone=False keeps the original
    ``x`` as the graph root), so the rewritten graph can be compared against
    ``expand_dims(x, axis=1)`` without ``in_xs``/``in_ys`` remapping.  The
    previous version passed the two compared outputs as ``in_xs``/``in_ys``,
    which made the assertion vacuous.
    """
    x = tensor("x", shape=(2, 3, 4))
    x_join = join_dims(x, start_axis=1, n_axes=0)

    # Before rewrite: the graph is a single JoinDims node.
    assert isinstance(x_join.owner.op, JoinDims)

    # clone=False ensures the rewritten graph still refers to the original x.
    rewritten = rewrite_graph(x_join, include=("canonicalize",), clone=False)

    # A new dimension of size 1 is inserted at axis 1 -> shape (2, 1, 3, 4).
    expected = expand_dims(x, axis=1)
    assert_equal_computations([rewritten], [expected])


def test_local_split_dims_to_reshape_squeeze_case():
    """Test that split_dims with a shape tensor of static shape (0,) becomes squeeze.

    Rewrites the output variable directly (clone=False keeps the original
    ``x`` as the graph root), so the rewritten graph can be compared against
    ``squeeze(x, axis=1)`` without ``in_xs``/``in_ys`` remapping.  The
    previous version passed the two compared outputs as ``in_xs``/``in_ys``,
    which made the assertion vacuous.
    """
    x = tensor("x", shape=(2, 1, 3, 4))
    # A symbolic shape vector whose static shape (0,) triggers the squeeze path.
    empty_shape_var = tensor("empty_shape", shape=(0,), dtype="int32")
    x_split = split_dims(x, axis=1, shape=empty_shape_var)

    # Before rewrite: the graph is a single SplitDims node.
    assert isinstance(x_split.owner.op, SplitDims)

    # clone=False ensures the rewritten graph still refers to the original x.
    rewritten = rewrite_graph(x_split, include=("canonicalize",), clone=False)

    # Dimension 1 (size 1) is removed -> shape (2, 3, 4).
    expected = squeeze(x, axis=1)
    assert_equal_computations([rewritten], [expected])