From 04e391e88b7a7bb4e361d00e57b0c470f1973642 Mon Sep 17 00:00:00 2001
From: mbaldourw
Date: Tue, 13 Jan 2026 16:22:42 -0500
Subject: [PATCH 01/12] Add canonicalization rewrites for JoinDims/SplitDims

Optimize JoinDims and SplitDims by canonicalizing to simpler operations
(identity, expand_dims, squeeze). Partially fixes #1843
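
Illustrative summary of the canonical forms introduced here (tensor
shapes hypothetical, using only the ops touched by this patch):

    join_dims(x, start_axis=1, n_axes=1)  ->  x  (identity)
    join_dims(x, start_axis=1, n_axes=0)  ->  expand_dims(x, axis=1)
    split_dims(x, axis=1, shape=())       ->  squeeze(x, axis=1)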
---
 pytensor/tensor/rewriting/reshape.py   | 77 +++++++++++++++++++-
 tests/tensor/rewriting/test_reshape.py | 99 ++++++++++++++++++++++++++
 2 files changed, 175 insertions(+), 1 deletion(-)

diff --git a/pytensor/tensor/rewriting/reshape.py b/pytensor/tensor/rewriting/reshape.py
index ab330551f3..38b12e2745 100644
--- a/pytensor/tensor/rewriting/reshape.py
+++ b/pytensor/tensor/rewriting/reshape.py
@@ -1,9 +1,51 @@
-from pytensor.graph import node_rewriter
+from pytensor.graph import Constant, node_rewriter
 from pytensor.graph.rewriting.basic import copy_stack_trace
+from pytensor.tensor.basic import expand_dims
+from pytensor.tensor.extra_ops import squeeze
 from pytensor.tensor.reshape import JoinDims, SplitDims
 from pytensor.tensor.rewriting.basic import register_canonicalize
 
 
+@register_canonicalize
+@node_rewriter([SplitDims])
+def local_split_dims_squeeze(fgraph, node):
+    """
+    Canonicalize SplitDims Ops to Squeeze Ops if shape is ().
+    split_dims(x, axis=axis, shape=()) -> squeeze(x, axis)
+    """
+    x, shape = node.inputs
+    axis = node.op.axis
+
+    if isinstance(shape, Constant) and shape.data.size == 0:
+        squeezed_x = squeeze(x, axis=axis)
+        copy_stack_trace(x, squeezed_x)
+        return [squeezed_x]
+    return None
+
+
+# @register_canonicalize
+# @node_rewriter([SplitDims])
+# def local_split_dims_specify_shape(fgraph, node):
+#     """
+#     Canonicalize SplitDims Ops to SpecifyShape Ops if shape is (dim,).
+#     split_dims(x, axis=axis, shape=(dim,)) -> specify_shape(x, (*[None] * axis, dim, ...))
+#     """
+#     x, shape = node.inputs
+#     axis = node.op.axis
+
+#     if isinstance(shape, Constant) and shape.data.size == 1:
+#         # Extract the dimension value (shape.data is numpy array)
+#         dim_value = int(shape.data[0])
+
+#         # split_dims with shape=(dim,) keeps same number of dimensions
+#         output_ndim = x.type.ndim
+
+#         specify_shape_x = specify_shape(x, (*[None] * axis, dim_value, ...))
+#         copy_stack_trace(x, specify_shape_x)
+#         return [specify_shape_x]
+#     return None
+
+
 @register_canonicalize
 @node_rewriter([SplitDims])
 def local_split_dims_to_reshape(fgraph, node):
@@ -26,6 +68,39 @@ def local_split_dims_to_reshape(fgraph, node):
     return [new_x]
 
 
+@register_canonicalize
+@node_rewriter([JoinDims])
+def local_join_dims_noop(fgraph, node):
+    """
+    Canonicalize JoinDims Ops to identity if n_axes=1.
+    join_dims(x, axis=axis, n_axes=1) -> x
+    """
+    (x,) = node.inputs
+    op = node.op
+
+    if op.n_axes == 1:
+        copy_stack_trace(x, x)
+        return [x]
+    return None
+
+
+@register_canonicalize
+@node_rewriter([JoinDims])
+def local_join_dims_expand(fgraph, node):
+    """
+    Canonicalize JoinDims Ops to expand dims if n_axes=0.
+    join_dims(x, axis=axis, n_axes=0) -> expand_dims(x, axis)
+    """
+    (x,) = node.inputs
+    op = node.op
+
+    if op.n_axes == 0:
+        expanded_x = expand_dims(x, axis=node.op.start_axis)
+        copy_stack_trace(x, expanded_x)
+        return [expanded_x]
+    return None
+
+
 @register_canonicalize
 @node_rewriter([JoinDims])
 def local_join_dims_to_reshape(fgraph, node):
diff --git a/tests/tensor/rewriting/test_reshape.py b/tests/tensor/rewriting/test_reshape.py
index d18e0b6419..3f84e2d18b 100644
--- a/tests/tensor/rewriting/test_reshape.py
+++ b/tests/tensor/rewriting/test_reshape.py
@@ -1,4 +1,7 @@
+import numpy as np
+
 from pytensor.graph import FunctionGraph, rewrite_graph
+from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.reshape import JoinDims, SplitDims, join_dims, split_dims
 from pytensor.tensor.shape import Reshape
 from pytensor.tensor.type import tensor
@@ -32,3 +35,99 @@ def test_local_join_dims_to_reshape():
     assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0
     assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 1
     assert fg.outputs[0].type.shape == (2, 10, 3)
+
+
+def test_local_join_dims_noop():
+    """Test that join_dims with n_axes=1 becomes identity (no-op)."""
+    x = tensor("x", shape=(2, 3, 4))
+    x_join = join_dims(x, start_axis=1, n_axes=1)
+
+    fg = FunctionGraph(inputs=[x], outputs=[x_join])
+
+    # Before rewrite: should have 1 JoinDims node
+    assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 1
+
+    rewrite_graph(fg, include=("canonicalize",))
+
+    # After rewrite: should have 0 JoinDims nodes
+    assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0
+    # Output should be equivalent to input (identity rewrite)
+    # The rewrite returns the input variable, so output should match input shape/type
+    assert fg.outputs[0].type.shape == x.type.shape
+    assert fg.outputs[0].type.dtype == x.type.dtype
+    assert fg.outputs[0].type.ndim == x.type.ndim
+
+
+def test_local_join_dims_expand():
+    """Test that join_dims with n_axes=0 becomes expand_dims."""
+    x = tensor("x", shape=(2, 3, 4))
+    x_join = join_dims(x, start_axis=1, n_axes=0)
+
+    fg = FunctionGraph(inputs=[x], outputs=[x_join])
+
+    # Before rewrite: should have 1 JoinDims node
+    assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 1
+
+    rewrite_graph(fg, include=("canonicalize",))
+
+    # After rewrite: should have 0 JoinDims nodes
+    assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0
+    # Should have 1 DimShuffle node with is_expand_dims=True
+    expand_nodes = [
+        node
+        for node in fg.toposort()
+        if isinstance(node.op, DimShuffle) and node.op.is_expand_dims
+    ]
+    assert len(expand_nodes) == 1
+    # Output shape should be (2, 1, 3, 4) - new dimension of size 1 inserted at axis 1
+    assert fg.outputs[0].type.shape == (2, 1, 3, 4)
+
+
+def test_local_split_dims_squeeze():
+    """Test that split_dims with shape=() becomes squeeze."""
+    x = tensor("x", shape=(2, 1, 3, 4))
+    # Create a constant empty shape array - split_dims will convert it to a tensor
+    empty_shape = np.array([], dtype="int32")
+    x_split = split_dims(x, axis=1, shape=empty_shape)
+
+    fg = FunctionGraph(inputs=[x], outputs=[x_split])
+
+    # Before rewrite: should have 1 SplitDims node
+    assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 1
+
+    rewrite_graph(fg, include=("canonicalize",))
+
+    # After rewrite: should have 0 SplitDims nodes
+    assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 0
+    # Should have 1 DimShuffle node with is_squeeze=True
+    squeeze_nodes = [
+        node
+        for node in fg.toposort()
+        if isinstance(node.op, DimShuffle) and node.op.is_squeeze
+    ]
+    assert len(squeeze_nodes) == 1
+    # Output shape should be (2, 3, 4) - dimension 1 removed
+    assert fg.outputs[0].type.shape == (2, 3, 4)
+
+
+# def test_local_split_dims_specify_shape():
+#     """Test that split_dims with shape=(dim,) becomes specify_shape (when input shape is None)."""
+#     # Create input with unknown shape at axis 1
+#     x = tensor("x", shape=(2, None, 4))
+#     # Create a constant shape with single dimension - split_dims will convert it to a tensor
+#     dim_shape = np.array([5], dtype="int32")
+#     x_split = split_dims(x, axis=1, shape=dim_shape)
+#
+#     fg = FunctionGraph(inputs=[x], outputs=[x_split])
+#
+#     # Before rewrite: should have 1 SplitDims node
+#     assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 1
+#
+#     rewrite_graph(fg, include=("canonicalize",))
+#
+#     # After rewrite: should have 0 SplitDims nodes
+#     assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 0
+#     # Should have 1 SpecifyShape node
+#     assert sum([1 for node in fg.toposort() if isinstance(node.op, SpecifyShape)]) == 1
+#     # Output shape should be (2, 5, 4) - dimension 1 specified as 5
+#     assert fg.outputs[0].type.shape == (2, 5, 4)

From 9a5e536122383bb209e570377720f8e096da6633 Mon Sep 17 00:00:00 2001
From: mbaldourw
Date: Wed, 14 Jan 2026 14:33:17 -0500
Subject: [PATCH 02/12] removed constant; merged func
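
The squeeze special case now keys on the static type of the shape
input (shape.type.shape == (0,)) instead of requiring a Constant, so a
symbolic shape vector that is statically known to be empty also takes
the squeeze path. Minimal sketch (variable names hypothetical):

    empty_shape = tensor("empty_shape", shape=(0,), dtype="int32")
    split_dims(x, axis=1, shape=empty_shape)  # canonicalizes to squeeze(x, axis=1)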
""" x, shape = node.inputs axis = node.op.axis + # Special case: empty shape -> squeeze + if shape.type.shape == (0,): + squeezed_x = squeeze(x, axis=axis) + copy_stack_trace(x, squeezed_x) + return [squeezed_x] + output_shape = [ *x.shape[:axis], *shape, diff --git a/tests/tensor/rewriting/test_reshape.py b/tests/tensor/rewriting/test_reshape.py index 3f84e2d18b..41979172bc 100644 --- a/tests/tensor/rewriting/test_reshape.py +++ b/tests/tensor/rewriting/test_reshape.py @@ -1,5 +1,3 @@ -import numpy as np - from pytensor.graph import FunctionGraph, rewrite_graph from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.reshape import JoinDims, SplitDims, join_dims, split_dims @@ -8,6 +6,7 @@ def test_local_split_dims_to_reshape(): + """Test that split_dims converts to reshape for general shapes.""" x = tensor("x", shape=(2, 10, 3)) x_split = split_dims(x, axis=1, shape=(2, 5, 1)) @@ -83,14 +82,14 @@ def test_local_join_dims_expand(): assert fg.outputs[0].type.shape == (2, 1, 3, 4) -def test_local_split_dims_squeeze(): - """Test that split_dims with shape=() becomes squeeze.""" +def test_local_split_dims_to_reshape_squeeze_case(): + """Test that split_dims with shape tensor of static shape (0,) becomes squeeze via merged rewrite.""" x = tensor("x", shape=(2, 1, 3, 4)) - # Create a constant empty shape array - split_dims will convert it to a tensor - empty_shape = np.array([], dtype="int32") - x_split = split_dims(x, axis=1, shape=empty_shape) + # Create a tensor variable with static shape (0,) + empty_shape_var = tensor("empty_shape", shape=(0,), dtype="int32") + x_split = split_dims(x, axis=1, shape=empty_shape_var) - fg = FunctionGraph(inputs=[x], outputs=[x_split]) + fg = FunctionGraph(inputs=[x, empty_shape_var], outputs=[x_split]) # Before rewrite: should have 1 SplitDims node assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 1 @@ -99,13 +98,15 @@ def test_local_split_dims_squeeze(): # After rewrite: should have 0 SplitDims nodes assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 0 - # Should have 1 DimShuffle node with is_squeeze=True + # Should have 1 DimShuffle node with is_squeeze=True (not Reshape) squeeze_nodes = [ node for node in fg.toposort() if isinstance(node.op, DimShuffle) and node.op.is_squeeze ] assert len(squeeze_nodes) == 1 + # Should NOT have a Reshape node + assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 0 # Output shape should be (2, 3, 4) - dimension 1 removed assert fg.outputs[0].type.shape == (2, 3, 4) From 273ecc3434cef8a06222d351c6cd591ee0b465a0 Mon Sep 17 00:00:00 2001 From: mbaldourw Date: Wed, 14 Jan 2026 14:39:26 -0500 Subject: [PATCH 03/12] removed specify shape --- pytensor/tensor/rewriting/reshape.py | 29 ++++++-------------------- tests/tensor/rewriting/test_reshape.py | 23 -------------------- 2 files changed, 6 insertions(+), 46 deletions(-) diff --git a/pytensor/tensor/rewriting/reshape.py b/pytensor/tensor/rewriting/reshape.py index 42b8a52ae9..8919ec0c0b 100644 --- a/pytensor/tensor/rewriting/reshape.py +++ b/pytensor/tensor/rewriting/reshape.py @@ -6,29 +6,6 @@ from pytensor.tensor.rewriting.basic import register_canonicalize -# @register_canonicalize -# @node_rewriter([SplitDims]) -# def local_split_dims_specify_shape(fgraph, node): -# """ -# Canonicalize SplitDims Ops to SpecifyShape Ops if shape is (dim,). 
-# split_dims(x, axis=axis, shape=(dim,)) -> specify_shape(x, (*[None] * axis, dim, ...)) -# """ -# x, shape = node.inputs -# axis = node.op.axis - -# if isinstance(shape, Constant) and shape.data.size == 1: -# # Extract the dimension value (shape.data is numpy array) -# dim_value = int(shape.data[0]) - -# # split_dims with shape=(dim,) keeps same number of dimensions -# output_ndim = x.type.ndim - -# specify_shape_x = specify_shape(x, (*[None] * axis, dim_value, ...)) -# copy_stack_trace(x, specify_shape_x) -# return [specify_shape_x] -# return None - - @register_canonicalize @node_rewriter([SplitDims]) def local_split_dims_to_reshape(fgraph, node): @@ -40,6 +17,12 @@ def local_split_dims_to_reshape(fgraph, node): x, shape = node.inputs axis = node.op.axis + # Special case: empty shape -> squeeze + if shape.type.shape == (0,): + squeezed_x = squeeze(x, axis=axis) + copy_stack_trace(x, squeezed_x) + return [squeezed_x] + # Special case: empty shape -> squeeze if shape.type.shape == (0,): squeezed_x = squeeze(x, axis=axis) diff --git a/tests/tensor/rewriting/test_reshape.py b/tests/tensor/rewriting/test_reshape.py index 41979172bc..1921201338 100644 --- a/tests/tensor/rewriting/test_reshape.py +++ b/tests/tensor/rewriting/test_reshape.py @@ -109,26 +109,3 @@ def test_local_split_dims_to_reshape_squeeze_case(): assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 0 # Output shape should be (2, 3, 4) - dimension 1 removed assert fg.outputs[0].type.shape == (2, 3, 4) - - -# def test_local_split_dims_specify_shape(): -# """Test that split_dims with shape=(dim,) becomes specify_shape (when input shape is None).""" -# # Create input with unknown shape at axis 1 -# x = tensor("x", shape=(2, None, 4)) -# # Create a constant shape with single dimension - split_dims will convert it to a tensor -# dim_shape = np.array([5], dtype="int32") -# x_split = split_dims(x, axis=1, shape=dim_shape) -# -# fg = FunctionGraph(inputs=[x], outputs=[x_split]) -# -# # Before rewrite: should have 1 SplitDims node -# assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 1 -# -# rewrite_graph(fg, include=("canonicalize",)) -# -# # After rewrite: should have 0 SplitDims nodes -# assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 0 -# # Should have 1 SpecifyShape node -# assert sum([1 for node in fg.toposort() if isinstance(node.op, SpecifyShape)]) == 1 -# # Output shape should be (2, 5, 4) - dimension 1 specified as 5 -# assert fg.outputs[0].type.shape == (2, 5, 4) From 0037dfa9c7986c774695943f330f71912c756852 Mon Sep 17 00:00:00 2001 From: mbaldourw Date: Fri, 16 Jan 2026 14:28:58 -0500 Subject: [PATCH 04/12] merged functions; updated test --- pytensor/tensor/rewriting/reshape.py | 48 +++++--------------------- tests/tensor/rewriting/test_reshape.py | 6 ++-- 2 files changed, 11 insertions(+), 43 deletions(-) diff --git a/pytensor/tensor/rewriting/reshape.py b/pytensor/tensor/rewriting/reshape.py index 8919ec0c0b..3fc9372f14 100644 --- a/pytensor/tensor/rewriting/reshape.py +++ b/pytensor/tensor/rewriting/reshape.py @@ -8,7 +8,7 @@ @register_canonicalize @node_rewriter([SplitDims]) -def local_split_dims_to_reshape(fgraph, node): +def local_split_dims(fgraph, node): """ Canonicalize SplitDims Ops to Reshape Ops for further graph reasoning (and dispatch to other backends). Special case: if shape is (0,), converts to squeeze instead. 
@@ -17,12 +17,6 @@ def local_split_dims_to_reshape(fgraph, node): x, shape = node.inputs axis = node.op.axis - # Special case: empty shape -> squeeze - if shape.type.shape == (0,): - squeezed_x = squeeze(x, axis=axis) - copy_stack_trace(x, squeezed_x) - return [squeezed_x] - # Special case: empty shape -> squeeze if shape.type.shape == (0,): squeezed_x = squeeze(x, axis=axis) @@ -43,48 +37,24 @@ def local_split_dims_to_reshape(fgraph, node): @register_canonicalize @node_rewriter([JoinDims]) -def local_join_dims_noop(fgraph, node): +def local_join_dims(fgraph, node): """ - Canonicalize JoinDims Ops to identity if n_axes=1. - join_dims(x, axis=axis, n_axes=1) -> x + Canonicalize JoinDims Ops to Reshape Ops for further graph reasoning (and dispatch to other backends). """ - (x,) = node.inputs - op = node.op - - if op.n_axes == 1: - copy_stack_trace(x, x) - return [x] - return None - -@register_canonicalize -@node_rewriter([JoinDims]) -def local_join_dims_expand(fgraph, node): - """ - Canonicalize JoinDims Ops to expand dims if n_axes=0. - join_dims(x, axis=axis, n_axes=0) -> expand_dims(x, axis) - """ (x,) = node.inputs op = node.op + start_axis = op.start_axis + n_axes = op.n_axes - if op.n_axes == 0: + if n_axes == 0: expanded_x = expand_dims(x, axis=node.op.start_axis) copy_stack_trace(x, expanded_x) return [expanded_x] - return None - - -@register_canonicalize -@node_rewriter([JoinDims]) -def local_join_dims_to_reshape(fgraph, node): - """ - Canonicalize JoinDims Ops to Reshape Ops for further graph reasoning (and dispatch to other backends). - """ - (x,) = node.inputs - op = node.op - start_axis = op.start_axis - n_axes = op.n_axes + if n_axes == 1: + copy_stack_trace(x, x) + return [x] output_shape = [ *x.shape[:start_axis], diff --git a/tests/tensor/rewriting/test_reshape.py b/tests/tensor/rewriting/test_reshape.py index 1921201338..4753cfc5b7 100644 --- a/tests/tensor/rewriting/test_reshape.py +++ b/tests/tensor/rewriting/test_reshape.py @@ -3,6 +3,7 @@ from pytensor.tensor.reshape import JoinDims, SplitDims, join_dims, split_dims from pytensor.tensor.shape import Reshape from pytensor.tensor.type import tensor +from tests.unittest_tools import assert_equal_computations def test_local_split_dims_to_reshape(): @@ -51,10 +52,7 @@ def test_local_join_dims_noop(): # After rewrite: should have 0 JoinDims nodes assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0 # Output should be equivalent to input (identity rewrite) - # The rewrite returns the input variable, so output should match input shape/type - assert fg.outputs[0].type.shape == x.type.shape - assert fg.outputs[0].type.dtype == x.type.dtype - assert fg.outputs[0].type.ndim == x.type.ndim + assert_equal_computations([fg.outputs[0]], [x], in_xs=[fg.outputs[0]], in_ys=[x]) def test_local_join_dims_expand(): From cacda87ab9ba848a08625ed8353ce653bef275bc Mon Sep 17 00:00:00 2001 From: mbaldourw Date: Fri, 16 Jan 2026 14:55:43 -0500 Subject: [PATCH 05/12] updated tests --- tests/tensor/rewriting/test_reshape.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/tensor/rewriting/test_reshape.py b/tests/tensor/rewriting/test_reshape.py index 4753cfc5b7..20ea72e663 100644 --- a/tests/tensor/rewriting/test_reshape.py +++ b/tests/tensor/rewriting/test_reshape.py @@ -1,12 +1,14 @@ from pytensor.graph import FunctionGraph, rewrite_graph +from pytensor.tensor.basic import expand_dims from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.extra_ops import 
squeeze from pytensor.tensor.reshape import JoinDims, SplitDims, join_dims, split_dims from pytensor.tensor.shape import Reshape from pytensor.tensor.type import tensor from tests.unittest_tools import assert_equal_computations -def test_local_split_dims_to_reshape(): +def test_local_split_dims(): """Test that split_dims converts to reshape for general shapes.""" x = tensor("x", shape=(2, 10, 3)) x_split = split_dims(x, axis=1, shape=(2, 5, 1)) @@ -22,7 +24,7 @@ def test_local_split_dims_to_reshape(): assert fg.outputs[0].type.shape == (2, 2, 5, 1, 3) -def test_local_join_dims_to_reshape(): +def test_local_join_dims(): x = tensor("x", shape=(2, 2, 5, 1, 3)) x_join = join_dims(x, start_axis=1, n_axes=3) @@ -31,11 +33,13 @@ def test_local_join_dims_to_reshape(): assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 1 rewrite_graph(fg, include=("canonicalize",)) - assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0 assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 1 assert fg.outputs[0].type.shape == (2, 10, 3) + # expected = x.reshape((2, 10, 3)) + # assert_equal_computations([fg.outputs[0]], [expected], in_xs=[fg.outputs[0]], in_ys=[expected]) + def test_local_join_dims_noop(): """Test that join_dims with n_axes=1 becomes identity (no-op).""" @@ -77,7 +81,10 @@ def test_local_join_dims_expand(): ] assert len(expand_nodes) == 1 # Output shape should be (2, 1, 3, 4) - new dimension of size 1 inserted at axis 1 - assert fg.outputs[0].type.shape == (2, 1, 3, 4) + expected = expand_dims(x, axis=1) + assert_equal_computations( + [fg.outputs[0]], [expected], in_xs=[fg.outputs[0]], in_ys=[expected] + ) def test_local_split_dims_to_reshape_squeeze_case(): @@ -106,4 +113,7 @@ def test_local_split_dims_to_reshape_squeeze_case(): # Should NOT have a Reshape node assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 0 # Output shape should be (2, 3, 4) - dimension 1 removed - assert fg.outputs[0].type.shape == (2, 3, 4) + expected = squeeze(x, axis=1) + assert_equal_computations( + [fg.outputs[0]], [expected], in_xs=[fg.outputs[0]], in_ys=[expected] + ) From 69943d57f560e3ab6fc8ca7f7fd47360715ea8a0 Mon Sep 17 00:00:00 2001 From: mbaldourw Date: Wed, 4 Feb 2026 08:31:29 -0500 Subject: [PATCH 06/12] fixed splitdims test; updated function def --- pytensor/tensor/rewriting/reshape.py | 4 +-- tests/tensor/rewriting/test_reshape.py | 41 ++++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/pytensor/tensor/rewriting/reshape.py b/pytensor/tensor/rewriting/reshape.py index 3fc9372f14..facb825216 100644 --- a/pytensor/tensor/rewriting/reshape.py +++ b/pytensor/tensor/rewriting/reshape.py @@ -24,9 +24,9 @@ def local_split_dims(fgraph, node): return [squeezed_x] output_shape = [ - *x.shape[:axis], + *[x.shape[i] for i in range(axis)], *shape, - *x.shape[axis + 1 :], + *[x.shape[i] for i in range(axis + 1, x.type.ndim)], ] new_x = x.reshape(output_shape) diff --git a/tests/tensor/rewriting/test_reshape.py b/tests/tensor/rewriting/test_reshape.py index 20ea72e663..c154210f0d 100644 --- a/tests/tensor/rewriting/test_reshape.py +++ b/tests/tensor/rewriting/test_reshape.py @@ -1,9 +1,13 @@ +import numpy as np + +import pytensor as pt from pytensor.graph import FunctionGraph, rewrite_graph +from pytensor.tensor import shape from pytensor.tensor.basic import expand_dims from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.extra_ops import squeeze from 
pytensor.tensor.reshape import JoinDims, SplitDims, join_dims, split_dims -from pytensor.tensor.shape import Reshape +from pytensor.tensor.shape import Reshape, specify_shape from pytensor.tensor.type import tensor from tests.unittest_tools import assert_equal_computations @@ -17,11 +21,37 @@ def test_local_split_dims(): assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 1 - rewrite_graph(fg, include=("canonicalize",)) + with pt.config.change_flags(optimizer_verbose=True): + rewrite_graph( + fg, + include=("canonicalize",), + exclude=( + "local_subtensor_merge", + "local_subtensor_remove_broadcastable_index", + ), + ) assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 0 assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 1 - assert fg.outputs[0].type.shape == (2, 2, 5, 1, 3) + + # Build the expected computation manually + x_shape = shape(x) + + # 1. Build shape vector for reshape: (x.shape[0], 2, 5, x.shape[2]) = (2, 2, 5, 3) + # The split shape (2, 5, 1) has the 1 removed for reshape, then expand_dims adds it back + shape_vector = pt.tensor.stack([x_shape[0], np.int64(2), np.int64(5), x_shape[2]]) + + # 2. Replicate the Reshape and ExpandDims + reshaped = Reshape(4)(x, shape_vector) + expanded = expand_dims(reshaped, axis=3) + + # 3. SpecifyShape to lock in the output shape + expected_shape_tuple = (2, 2, 5, 1, 3) + expected = specify_shape(expanded, expected_shape_tuple) + + assert_equal_computations( + [fg.outputs[0]], [expected], in_xs=[fg.outputs[0]], in_ys=[expected] + ) def test_local_join_dims(): @@ -33,10 +63,15 @@ def test_local_join_dims(): assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 1 rewrite_graph(fg, include=("canonicalize",)) + assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0 assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 1 assert fg.outputs[0].type.shape == (2, 10, 3) + expected = x.reshape((2, 10, 3)) + assert_equal_computations( + [fg.outputs[0]], [expected], in_xs=[fg.outputs[0]], in_ys=[expected] + ) # expected = x.reshape((2, 10, 3)) # assert_equal_computations([fg.outputs[0]], [expected], in_xs=[fg.outputs[0]], in_ys=[expected]) From 1c77c0e3c1f664c9ff269cf3f80de1f53d73054e Mon Sep 17 00:00:00 2001 From: mbaldourw Date: Wed, 4 Feb 2026 08:40:37 -0500 Subject: [PATCH 07/12] fixed join dims def and test --- pytensor/tensor/rewriting/reshape.py | 4 ++-- tests/tensor/rewriting/test_reshape.py | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/rewriting/reshape.py b/pytensor/tensor/rewriting/reshape.py index facb825216..55e6f85048 100644 --- a/pytensor/tensor/rewriting/reshape.py +++ b/pytensor/tensor/rewriting/reshape.py @@ -57,9 +57,9 @@ def local_join_dims(fgraph, node): return [x] output_shape = [ - *x.shape[:start_axis], + *[x.shape[i] for i in range(start_axis)], -1, - *x.shape[start_axis + n_axes :], + *[x.shape[i] for i in range(start_axis + n_axes, x.type.ndim)], ] new_x = x.reshape(output_shape) diff --git a/tests/tensor/rewriting/test_reshape.py b/tests/tensor/rewriting/test_reshape.py index c154210f0d..e4bdfdb0c9 100644 --- a/tests/tensor/rewriting/test_reshape.py +++ b/tests/tensor/rewriting/test_reshape.py @@ -62,7 +62,14 @@ def test_local_join_dims(): assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 1 - rewrite_graph(fg, include=("canonicalize",)) + rewrite_graph( + fg, + include=("canonicalize",), + exclude=( + 
"local_subtensor_merge", + "local_subtensor_remove_broadcastable_index", + ), + ) assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0 assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 1 From 104ab1a87f5836f7e709fcc13a034f5d6dd330a0 Mon Sep 17 00:00:00 2001 From: mbaldourw Date: Fri, 6 Feb 2026 10:09:34 -0500 Subject: [PATCH 08/12] removed redundant assertion in test --- tests/tensor/rewriting/test_reshape.py | 49 +++++--------------------- 1 file changed, 8 insertions(+), 41 deletions(-) diff --git a/tests/tensor/rewriting/test_reshape.py b/tests/tensor/rewriting/test_reshape.py index e4bdfdb0c9..9d3c19b19a 100644 --- a/tests/tensor/rewriting/test_reshape.py +++ b/tests/tensor/rewriting/test_reshape.py @@ -4,7 +4,6 @@ from pytensor.graph import FunctionGraph, rewrite_graph from pytensor.tensor import shape from pytensor.tensor.basic import expand_dims -from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.extra_ops import squeeze from pytensor.tensor.reshape import JoinDims, SplitDims, join_dims, split_dims from pytensor.tensor.shape import Reshape, specify_shape @@ -19,20 +18,14 @@ def test_local_split_dims(): fg = FunctionGraph(inputs=[x], outputs=[x_split]) - assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 1 - - with pt.config.change_flags(optimizer_verbose=True): - rewrite_graph( - fg, - include=("canonicalize",), - exclude=( - "local_subtensor_merge", - "local_subtensor_remove_broadcastable_index", - ), - ) - - assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 0 - assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 1 + rewrite_graph( + fg, + include=("canonicalize",), + exclude=( + "local_subtensor_merge", + "local_subtensor_remove_broadcastable_index", + ), + ) # Build the expected computation manually x_shape = shape(x) @@ -60,8 +53,6 @@ def test_local_join_dims(): fg = FunctionGraph(inputs=[x], outputs=[x_join]) - assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 1 - rewrite_graph( fg, include=("canonicalize",), @@ -71,10 +62,6 @@ def test_local_join_dims(): ), ) - assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0 - assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 1 - assert fg.outputs[0].type.shape == (2, 10, 3) - expected = x.reshape((2, 10, 3)) assert_equal_computations( [fg.outputs[0]], [expected], in_xs=[fg.outputs[0]], in_ys=[expected] @@ -113,15 +100,6 @@ def test_local_join_dims_expand(): rewrite_graph(fg, include=("canonicalize",)) - # After rewrite: should have 0 JoinDims nodes - assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0 - # Should have 1 DimShuffle node with is_expand_dims=True - expand_nodes = [ - node - for node in fg.toposort() - if isinstance(node.op, DimShuffle) and node.op.is_expand_dims - ] - assert len(expand_nodes) == 1 # Output shape should be (2, 1, 3, 4) - new dimension of size 1 inserted at axis 1 expected = expand_dims(x, axis=1) assert_equal_computations( @@ -143,17 +121,6 @@ def test_local_split_dims_to_reshape_squeeze_case(): rewrite_graph(fg, include=("canonicalize",)) - # After rewrite: should have 0 SplitDims nodes - assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 0 - # Should have 1 DimShuffle node with is_squeeze=True (not Reshape) - squeeze_nodes = [ - node - for node in fg.toposort() - if isinstance(node.op, DimShuffle) and node.op.is_squeeze 
- ] - assert len(squeeze_nodes) == 1 - # Should NOT have a Reshape node - assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 0 # Output shape should be (2, 3, 4) - dimension 1 removed expected = squeeze(x, axis=1) assert_equal_computations( From e26249c5af811c77b597858a7c092101233750cb Mon Sep 17 00:00:00 2001 From: mbaldourw Date: Fri, 6 Feb 2026 13:14:09 -0500 Subject: [PATCH 09/12] restored changes to reshape and updated test --- pytensor/tensor/rewriting/reshape.py | 8 +++--- tests/tensor/rewriting/test_reshape.py | 35 +++++++++++++++++++++++--- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/pytensor/tensor/rewriting/reshape.py b/pytensor/tensor/rewriting/reshape.py index 55e6f85048..3fc9372f14 100644 --- a/pytensor/tensor/rewriting/reshape.py +++ b/pytensor/tensor/rewriting/reshape.py @@ -24,9 +24,9 @@ def local_split_dims(fgraph, node): return [squeezed_x] output_shape = [ - *[x.shape[i] for i in range(axis)], + *x.shape[:axis], *shape, - *[x.shape[i] for i in range(axis + 1, x.type.ndim)], + *x.shape[axis + 1 :], ] new_x = x.reshape(output_shape) @@ -57,9 +57,9 @@ def local_join_dims(fgraph, node): return [x] output_shape = [ - *[x.shape[i] for i in range(start_axis)], + *x.shape[:start_axis], -1, - *[x.shape[i] for i in range(start_axis + n_axes, x.type.ndim)], + *x.shape[start_axis + n_axes :], ] new_x = x.reshape(output_shape) diff --git a/tests/tensor/rewriting/test_reshape.py b/tests/tensor/rewriting/test_reshape.py index 9d3c19b19a..4e426032fc 100644 --- a/tests/tensor/rewriting/test_reshape.py +++ b/tests/tensor/rewriting/test_reshape.py @@ -30,9 +30,18 @@ def test_local_split_dims(): # Build the expected computation manually x_shape = shape(x) - # 1. Build shape vector for reshape: (x.shape[0], 2, 5, x.shape[2]) = (2, 2, 5, 3) - # The split shape (2, 5, 1) has the 1 removed for reshape, then expand_dims adds it back - shape_vector = pt.tensor.stack([x_shape[0], np.int64(2), np.int64(5), x_shape[2]]) + # Use slice notation like the code does: x.shape[:axis] and x.shape[axis+1:] + shape_before = x_shape[:1] # Creates Subtensor{:stop} + shape_after = x_shape[2:] # Creates Subtensor{start:} + + # Concatenate: [x.shape[:1], 2, 5, x.shape[2:]] + shape_vector = pt.tensor.concatenate( + [ + shape_before, + pt.tensor.as_tensor([np.int64(2), np.int64(5)]), + shape_after, + ] + ) # 2. 
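
The output shape is again built with symbolic slices rather than
per-dimension indexing, e.g. for split_dims (sketch mirroring the
restored code):

    output_shape = [*x.shape[:axis], *shape, *x.shape[axis + 1 :]]

One symbolic subtensor now covers the leading/trailing dims instead
of one scalar shape lookup per axis.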
---
 pytensor/tensor/rewriting/reshape.py   |  8 +++---
 tests/tensor/rewriting/test_reshape.py | 35 +++++++++++++++++++++++---
 2 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/pytensor/tensor/rewriting/reshape.py b/pytensor/tensor/rewriting/reshape.py
index 55e6f85048..3fc9372f14 100644
--- a/pytensor/tensor/rewriting/reshape.py
+++ b/pytensor/tensor/rewriting/reshape.py
@@ -24,9 +24,9 @@ def local_split_dims(fgraph, node):
         return [squeezed_x]
 
     output_shape = [
-        *[x.shape[i] for i in range(axis)],
+        *x.shape[:axis],
         *shape,
-        *[x.shape[i] for i in range(axis + 1, x.type.ndim)],
+        *x.shape[axis + 1 :],
     ]
 
     new_x = x.reshape(output_shape)
@@ -57,9 +57,9 @@ def local_join_dims(fgraph, node):
         return [x]
 
     output_shape = [
-        *[x.shape[i] for i in range(start_axis)],
+        *x.shape[:start_axis],
         -1,
-        *[x.shape[i] for i in range(start_axis + n_axes, x.type.ndim)],
+        *x.shape[start_axis + n_axes :],
     ]
 
     new_x = x.reshape(output_shape)
diff --git a/tests/tensor/rewriting/test_reshape.py b/tests/tensor/rewriting/test_reshape.py
index 9d3c19b19a..4e426032fc 100644
--- a/tests/tensor/rewriting/test_reshape.py
+++ b/tests/tensor/rewriting/test_reshape.py
@@ -30,9 +30,18 @@ def test_local_split_dims():
     # Build the expected computation manually
     x_shape = shape(x)
 
-    # 1. Build shape vector for reshape: (x.shape[0], 2, 5, x.shape[2]) = (2, 2, 5, 3)
-    # The split shape (2, 5, 1) has the 1 removed for reshape, then expand_dims adds it back
-    shape_vector = pt.tensor.stack([x_shape[0], np.int64(2), np.int64(5), x_shape[2]])
+    # Use slice notation like the code does: x.shape[:axis] and x.shape[axis+1:]
+    shape_before = x_shape[:1]  # Creates Subtensor{:stop}
+    shape_after = x_shape[2:]  # Creates Subtensor{start:}
+
+    # Concatenate: [x.shape[:1], 2, 5, x.shape[2:]]
+    shape_vector = pt.tensor.concatenate(
+        [
+            shape_before,
+            pt.tensor.as_tensor([np.int64(2), np.int64(5)]),
+            shape_after,
+        ]
+    )
 
     # 2. Replicate the Reshape and ExpandDims
     reshaped = Reshape(4)(x, shape_vector)
@@ -62,7 +71,25 @@ def test_local_join_dims():
         ),
     )
 
-    expected = x.reshape((2, 10, 3))
+    x_shape = shape(x)
+    shape_before = x_shape[:1]  # x.shape[:start_axis]
+    shape_after = x_shape[4:]  # x.shape[start_axis + n_axes:]
+
+    shape_vector = pt.tensor.concatenate(
+        [
+            shape_before,
+            pt.tensor.as_tensor([np.int64(-1)]),
+            shape_after,
+        ]
+    )
+
+    # Squeeze then reshape
+    squeezed = squeeze(x, axis=3)
+    reshaped = Reshape(3)(squeezed, shape_vector)
+
+    # SpecifyShape to lock in output
+    expected = specify_shape(reshaped, (2, 10, 3))
+
     assert_equal_computations(
         [fg.outputs[0]], [expected], in_xs=[fg.outputs[0]], in_ys=[expected]
     )

From a268570ae5f5b28311ed131751bbd09f60920960 Mon Sep 17 00:00:00 2001
From: mbaldourw
Date: Thu, 12 Feb 2026 09:17:58 -0500
Subject: [PATCH 10/12] cleaned up tests

---
 tests/tensor/rewriting/test_reshape.py | 86 ++++----------------------
 1 file changed, 13 insertions(+), 73 deletions(-)

diff --git a/tests/tensor/rewriting/test_reshape.py b/tests/tensor/rewriting/test_reshape.py
index 4e426032fc..85512ca0ae 100644
--- a/tests/tensor/rewriting/test_reshape.py
+++ b/tests/tensor/rewriting/test_reshape.py
@@ -1,100 +1,40 @@
-import numpy as np
-
-import pytensor as pt
 from pytensor.graph import FunctionGraph, rewrite_graph
-from pytensor.tensor import shape
 from pytensor.tensor.basic import expand_dims
 from pytensor.tensor.extra_ops import squeeze
 from pytensor.tensor.reshape import JoinDims, SplitDims, join_dims, split_dims
-from pytensor.tensor.shape import Reshape, specify_shape
+from pytensor.tensor.shape import Reshape
 from pytensor.tensor.type import tensor
 from tests.unittest_tools import assert_equal_computations
 
 
-def test_local_split_dims():
-    """Test that split_dims converts to reshape for general shapes."""
+def test_local_split_dims_to_reshape():
     x = tensor("x", shape=(2, 10, 3))
     x_split = split_dims(x, axis=1, shape=(2, 5, 1))
 
     fg = FunctionGraph(inputs=[x], outputs=[x_split])
 
-    rewrite_graph(
-        fg,
-        include=("canonicalize",),
-        exclude=(
-            "local_subtensor_merge",
-            "local_subtensor_remove_broadcastable_index",
-        ),
-    )
+    assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 1
 
-    # Build the expected computation manually
-    x_shape = shape(x)
+    rewrite_graph(fg, include=("canonicalize",))
 
-    # Use slice notation like the code does: x.shape[:axis] and x.shape[axis+1:]
-    shape_before = x_shape[:1]  # Creates Subtensor{:stop}
-    shape_after = x_shape[2:]  # Creates Subtensor{start:}
-
-    # Concatenate: [x.shape[:1], 2, 5, x.shape[2:]]
-    shape_vector = pt.tensor.concatenate(
-        [
-            shape_before,
-            pt.tensor.as_tensor([np.int64(2), np.int64(5)]),
-            shape_after,
-        ]
-    )
-
-    # 2. Replicate the Reshape and ExpandDims
-    reshaped = Reshape(4)(x, shape_vector)
-    expanded = expand_dims(reshaped, axis=3)
-
-    # 3. SpecifyShape to lock in the output shape
-    expected_shape_tuple = (2, 2, 5, 1, 3)
-    expected = specify_shape(expanded, expected_shape_tuple)
-
-    assert_equal_computations(
-        [fg.outputs[0]], [expected], in_xs=[fg.outputs[0]], in_ys=[expected]
-    )
+    assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 0
+    assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 1
+    assert fg.outputs[0].type.shape == (2, 2, 5, 1, 3)
 
 
-def test_local_join_dims():
+def test_local_join_dims_to_reshape():
     x = tensor("x", shape=(2, 2, 5, 1, 3))
     x_join = join_dims(x, start_axis=1, n_axes=3)
 
     fg = FunctionGraph(inputs=[x], outputs=[x_join])
 
-    rewrite_graph(
-        fg,
-        include=("canonicalize",),
-        exclude=(
-            "local_subtensor_merge",
-            "local_subtensor_remove_broadcastable_index",
-        ),
-    )
+    assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 1
 
-    x_shape = shape(x)
-    shape_before = x_shape[:1]  # x.shape[:start_axis]
-    shape_after = x_shape[4:]  # x.shape[start_axis + n_axes:]
+    rewrite_graph(fg, include=("canonicalize",))
 
-    shape_vector = pt.tensor.concatenate(
-        [
-            shape_before,
-            pt.tensor.as_tensor([np.int64(-1)]),
-            shape_after,
-        ]
-    )
-
-    # Squeeze then reshape
-    squeezed = squeeze(x, axis=3)
-    reshaped = Reshape(3)(squeezed, shape_vector)
-
-    # SpecifyShape to lock in output
-    expected = specify_shape(reshaped, (2, 10, 3))
-
-    assert_equal_computations(
-        [fg.outputs[0]], [expected], in_xs=[fg.outputs[0]], in_ys=[expected]
-    )
+    assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0
+    assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 1
+    assert fg.outputs[0].type.shape == (2, 10, 3)

From 436d17966038931e316e8d21646d6d6998d9bc2 Mon Sep 17 00:00:00 2001
From: mbaldourw
Date: Thu, 12 Feb 2026 09:30:41 -0500
Subject: [PATCH 11/12] cleaned up redundant asserts

---
 tests/tensor/rewriting/test_reshape.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/tests/tensor/rewriting/test_reshape.py b/tests/tensor/rewriting/test_reshape.py
index 85512ca0ae..7d2694f3d3 100644
--- a/tests/tensor/rewriting/test_reshape.py
+++ b/tests/tensor/rewriting/test_reshape.py
@@ -17,8 +17,6 @@ def test_local_split_dims_to_reshape():
 
     rewrite_graph(fg, include=("canonicalize",))
 
-    assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 0
-    assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 1
     assert fg.outputs[0].type.shape == (2, 2, 5, 1, 3)
 
 
@@ -49,8 +47,6 @@ def test_local_join_dims_noop():
 
     rewrite_graph(fg, include=("canonicalize",))
 
-    # After rewrite: should have 0 JoinDims nodes
-    assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0
     # Output should be equivalent to input (identity rewrite)
     assert_equal_computations([fg.outputs[0]], [x], in_xs=[fg.outputs[0]], in_ys=[x])

From 80b45b919aa6841abc5273738e7a9335342cae13 Mon Sep 17 00:00:00 2001
From: mbaldourw
Date: Thu, 12 Feb 2026 09:53:38 -0500
Subject: [PATCH 12/12] restored original test

---
 tests/tensor/rewriting/test_reshape.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/tensor/rewriting/test_reshape.py b/tests/tensor/rewriting/test_reshape.py
index 7d2694f3d3..2d761875fb 100644
--- a/tests/tensor/rewriting/test_reshape.py
+++ b/tests/tensor/rewriting/test_reshape.py
@@ -17,6 +17,8 @@ def test_local_split_dims_to_reshape():
 
     rewrite_graph(fg, include=("canonicalize",))
 
+    assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 0
+    assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 1
     assert fg.outputs[0].type.shape == (2, 2, 5, 1, 3)