diff --git a/pytensor/tensor/rewriting/reshape.py b/pytensor/tensor/rewriting/reshape.py
index ab330551f3..3fc9372f14 100644
--- a/pytensor/tensor/rewriting/reshape.py
+++ b/pytensor/tensor/rewriting/reshape.py
@@ -1,19 +1,28 @@
 from pytensor.graph import node_rewriter
 from pytensor.graph.rewriting.basic import copy_stack_trace
+from pytensor.tensor.basic import expand_dims
+from pytensor.tensor.extra_ops import squeeze
 from pytensor.tensor.reshape import JoinDims, SplitDims
 from pytensor.tensor.rewriting.basic import register_canonicalize
 
 
 @register_canonicalize
 @node_rewriter([SplitDims])
-def local_split_dims_to_reshape(fgraph, node):
+def local_split_dims(fgraph, node):
     """
     Canonicalize SplitDims Ops to Reshape Ops for further graph reasoning
     (and dispatch to other backends).
+    Special case: an empty ``shape`` (static shape ``(0,)``) is rewritten to a squeeze of the split axis.
     """
     x, shape = node.inputs
     axis = node.op.axis
 
+    # Special case: an empty shape means the split axis has length 1 -> squeeze it
+    if shape.type.shape == (0,):
+        squeezed_x = squeeze(x, axis=axis)
+        copy_stack_trace(x, squeezed_x)
+        return [squeezed_x]
+
     output_shape = [
         *x.shape[:axis],
         *shape,
@@ -28,7 +37,7 @@ def local_split_dims_to_reshape(fgraph, node):
 @register_canonicalize
 @node_rewriter([JoinDims])
-def local_join_dims_to_reshape(fgraph, node):
+def local_join_dims(fgraph, node):
     """
     Canonicalize JoinDims Ops to Reshape Ops for further graph reasoning
     (and dispatch to other backends).
     """
@@ -38,6 +47,15 @@ def local_join_dims_to_reshape(fgraph, node):
     start_axis = op.start_axis
     n_axes = op.n_axes
 
+    if n_axes == 0:
+        expanded_x = expand_dims(x, axis=start_axis)
+        copy_stack_trace(x, expanded_x)
+        return [expanded_x]
+
+    if n_axes == 1:
+        # Joining a single axis is a no-op; return the input unchanged
+        return [x]
+
     output_shape = [
         *x.shape[:start_axis],
         -1,
diff --git a/tests/tensor/rewriting/test_reshape.py b/tests/tensor/rewriting/test_reshape.py
index d18e0b6419..2d761875fb 100644
--- a/tests/tensor/rewriting/test_reshape.py
+++ b/tests/tensor/rewriting/test_reshape.py
@@ -1,7 +1,10 @@
 from pytensor.graph import FunctionGraph, rewrite_graph
+from pytensor.tensor.basic import expand_dims
+from pytensor.tensor.extra_ops import squeeze
 from pytensor.tensor.reshape import JoinDims, SplitDims, join_dims, split_dims
 from pytensor.tensor.shape import Reshape
 from pytensor.tensor.type import tensor
+from tests.unittest_tools import assert_equal_computations
 
 
 def test_local_split_dims_to_reshape():
@@ -32,3 +35,59 @@ def test_local_join_dims_to_reshape():
     assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0
     assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 1
     assert fg.outputs[0].type.shape == (2, 10, 3)
+
+
+def test_local_join_dims_noop():
+    """Test that join_dims with n_axes=1 is rewritten to the identity (a no-op)."""
+    x = tensor("x", shape=(2, 3, 4))
+    x_join = join_dims(x, start_axis=1, n_axes=1)
+
+    fg = FunctionGraph(inputs=[x], outputs=[x_join])
+
+    # Before rewrite: should have 1 JoinDims node
+    assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 1
+
+    rewrite_graph(fg, include=("canonicalize",))
+
+    # Output should be equivalent to the input (identity rewrite)
+    assert_equal_computations([fg.outputs[0]], [x], in_xs=[fg.outputs[0]], in_ys=[x])
+
+
+def test_local_join_dims_expand():
+    """Test that join_dims with n_axes=0 is rewritten to expand_dims."""
+    x = tensor("x", shape=(2, 3, 4))
+    x_join = join_dims(x, start_axis=1, n_axes=0)
+
+    fg = FunctionGraph(inputs=[x], outputs=[x_join])
+
+    # Before rewrite: should have 1 JoinDims node
+    assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 1
+
+    rewrite_graph(fg, include=("canonicalize",))
+
+    # Output shape should be (2, 1, 3, 4) - a new dimension of size 1 inserted at axis 1
+    expected = expand_dims(x, axis=1)
+    assert_equal_computations(
+        [fg.outputs[0]], [expected], in_xs=[fg.outputs[0]], in_ys=[expected]
+    )
+
+
+def test_local_split_dims_to_reshape_squeeze_case():
+    """Test that split_dims with a shape input of static shape (0,) is rewritten to squeeze."""
+    x = tensor("x", shape=(2, 1, 3, 4))
+    # Create a shape input with static shape (0,) to trigger the squeeze special case
+    empty_shape_var = tensor("empty_shape", shape=(0,), dtype="int32")
+    x_split = split_dims(x, axis=1, shape=empty_shape_var)
+
+    fg = FunctionGraph(inputs=[x, empty_shape_var], outputs=[x_split])
+
+    # Before rewrite: should have 1 SplitDims node
+    assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 1
+
+    rewrite_graph(fg, include=("canonicalize",))
+
+    # Output shape should be (2, 3, 4) - the length-1 dimension at axis 1 is removed
+    expected = squeeze(x, axis=1)
+    assert_equal_computations(
+        [fg.outputs[0]], [expected], in_xs=[fg.outputs[0]], in_ys=[expected]
+    )
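
As a sanity check outside the patch itself, a minimal sketch of what the JoinDims special cases are expected to produce after canonicalization (mirroring the new tests above; the assertions here are illustrative and not part of the patch):

    from pytensor.graph import FunctionGraph, rewrite_graph
    from pytensor.tensor.reshape import JoinDims, join_dims
    from pytensor.tensor.type import tensor

    x = tensor("x", shape=(2, 3, 4))

    # n_axes=0: a new length-1 dimension is inserted at start_axis, so the
    # canonical form is expand_dims rather than a Reshape.
    fg = FunctionGraph(inputs=[x], outputs=[join_dims(x, start_axis=1, n_axes=0)])
    rewrite_graph(fg, include=("canonicalize",))
    assert not any(isinstance(node.op, JoinDims) for node in fg.toposort())
    assert fg.outputs[0].type.shape == (2, 1, 3, 4)

    # n_axes=1: joining a single axis changes nothing, so the JoinDims node
    # disappears entirely and the output is equivalent to the input.
    fg = FunctionGraph(inputs=[x], outputs=[join_dims(x, start_axis=1, n_axes=1)])
    rewrite_graph(fg, include=("canonicalize",))
    assert not any(isinstance(node.op, JoinDims) for node in fg.toposort())
    assert fg.outputs[0].type.shape == (2, 3, 4)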