# Scalar ops that are strictly monotonic on their domain, used to drop (or
# hoist) an elementwise op around argmax/argmin and max/min reductions.
#
# NOTE(review): several of these (Log, Log2, Log10, Log1p, Sqrt, ArcSin,
# ArcCosh, ArcTanh, ArcCos) are only defined on a restricted domain.  The
# rewrites assume the input lies inside that domain; for out-of-domain values
# the original graph yields NaNs and the rewritten graph may disagree —
# TODO confirm this trade-off is intended.
MONOTONIC_INCREASING = (
    ps.Exp,
    ps.Exp2,
    ps.Expm1,
    ps.Log,
    ps.Log2,
    ps.Log10,
    ps.Log1p,
    ps.Sqrt,
    ps.Deg2Rad,
    ps.Rad2Deg,
    ps.ArcSin,
    ps.ArcTan,
    ps.ArcCosh,
    ps.Sinh,
    ps.ArcSinh,
    ps.Tanh,
    ps.ArcTanh,
)

MONOTONIC_DECREASING = (ps.Neg, ps.ArcCos)


def _is_argmin(node):
    """Return True when an ``Argmax`` node actually encodes ``argmin``.

    PyTensor has no Argmin op: ``argmin(x)`` is built as ``Argmax(neg(x))``.
    Only the elementwise ``Neg`` wrapper is recognized here (other
    equivalent encodings, if any, are not detected).
    """
    if not isinstance(node.op, Argmax):
        return False

    wrapped = node.inputs[0]
    if wrapped.owner is None:
        return False

    inner_op = wrapped.owner.op
    return isinstance(inner_op, Elemwise) and isinstance(inner_op.scalar_op, ps.Neg)


@register_canonicalize
@node_rewriter([Argmax])
def local_argmax_argmin_monotonic(fgraph, node):
    """Remove a monotonic elementwise function under argmax/argmin.

    - ``argmax(f_inc(x)) -> argmax(x)`` and ``argmin(f_inc(x)) -> argmin(x)``
      for monotonically increasing ``f_inc``;
    - ``argmax(f_dec(x)) -> argmin(x)`` and ``argmin(f_dec(x)) -> argmax(x)``
      for monotonically decreasing ``f_dec``.

    ``argmin`` is represented internally as ``Argmax(Neg(...))``.
    """
    # node.op is guaranteed to be Argmax by @node_rewriter([Argmax]).
    is_argmin = _is_argmin(node)

    # For argmin, peel off the Neg wrapper to reach the candidate function.
    # _is_argmin already guarantees the owner exists in that case.
    candidate = node.inputs[0]
    if is_argmin:
        candidate = candidate.owner.inputs[0]

    if candidate.owner is None:
        return False

    inner_op = candidate.owner.op
    if not isinstance(inner_op, Elemwise):
        return False

    scalar_op = inner_op.scalar_op
    is_increasing = isinstance(scalar_op, MONOTONIC_INCREASING)
    is_decreasing = isinstance(scalar_op, MONOTONIC_DECREASING)
    if not (is_increasing or is_decreasing):
        return False

    x = candidate.owner.inputs[0]

    # An increasing function preserves the arg-extremum; a decreasing one
    # flips argmax <-> argmin.  XOR collapses the original four branches:
    #   argmin + increasing -> argmin      argmin + decreasing -> argmax
    #   argmax + increasing -> argmax      argmax + decreasing -> argmin
    if is_argmin != is_decreasing:
        # Result is an argmin, i.e. Argmax(Neg(x)).
        new_output = argmin(x, axis=node.op.axis)
    else:
        # Result is an argmax; reuse node.op so the axis is preserved exactly.
        new_output = node.op(x)

    copy_stack_trace(node.outputs[0], new_output)
    return [new_output]
class TestArgmaxArgminMaxMinMonotonic:
    """Tests for the argmax/argmin monotonic-function rewrites.

    PyTensor ops such as ``pt.exp`` and ``pt.arccos`` are Elemwise
    *instances*, not plain functions, so assertion messages use a safe
    label — a bare ``f.__name__`` would raise AttributeError while the
    failure message is being rendered.
    """

    @pytest.mark.parametrize("axis", [None, 0, -1])
    def test_argmax_increasing_functions(self, axis):
        """argmax(f_inc(x)) -> argmax(x) for monotonically increasing f."""
        x = pt.vector("x")
        test_val = np.array([1.0, 3.0, 2.0, 5.0, 4.0])
        mode = get_default_mode()

        for f in [pt.exp, pt.log1p, pt.sqrt]:
            # Safe label: Elemwise instances have no ``__name__``.
            label = getattr(f, "__name__", str(f))

            fn_rewritten = function([x], pt.argmax(f(x), axis=axis), mode=mode)
            fn_expected = function([x], pt.argmax(x, axis=axis), mode=mode)

            np.testing.assert_array_equal(
                fn_rewritten(test_val),
                fn_expected(test_val),
                err_msg=f"argmax({label}(x), axis={axis}) mismatch",
            )

            # The rewrite must have removed the monotonic op from the graph.
            topo = fn_rewritten.maker.fgraph.toposort()
            assert not any(
                isinstance(n.op, Elemwise)
                and isinstance(n.op.scalar_op, (ps.Exp, ps.Log1p, ps.Sqrt))
                for n in topo
            ), f"rewrite failed to eliminate {label} from argmax graph"

    @pytest.mark.parametrize("axis", [None, 0, -1])
    def test_argmin_increasing_functions(self, axis):
        """argmin(f_inc(x)) -> argmin(x) for monotonically increasing f."""
        x = pt.vector("x")
        test_val = np.array([1.0, 3.0, 2.0, 5.0, 4.0])
        mode = get_default_mode()

        for f in [pt.exp, pt.log1p, pt.sqrt]:
            label = getattr(f, "__name__", str(f))

            fn_rewritten = function([x], pt.argmin(f(x), axis=axis), mode=mode)
            fn_expected = function([x], pt.argmin(x, axis=axis), mode=mode)

            np.testing.assert_array_equal(
                fn_rewritten(test_val),
                fn_expected(test_val),
                err_msg=f"argmin({label}(x), axis={axis}) mismatch",
            )

            topo = fn_rewritten.maker.fgraph.toposort()
            assert not any(
                isinstance(n.op, Elemwise)
                and isinstance(n.op.scalar_op, (ps.Exp, ps.Log1p, ps.Sqrt))
                for n in topo
            ), f"rewrite failed to eliminate {label} from argmin graph"

    @pytest.mark.parametrize("axis", [None, 0, -1])
    def test_argmax_decreasing_functions(self, axis):
        """argmax(f_dec(x)) -> argmin(x) for monotonically decreasing f."""
        x = pt.vector("x")
        # Values kept in (0, 1) so arccos is within its domain.
        test_val = np.array([0.1, 0.3, 0.2, 0.5, 0.4])
        mode = get_default_mode()

        for f in [pt.arccos]:
            label = getattr(f, "__name__", str(f))

            fn_rewritten = function([x], pt.argmax(f(x), axis=axis), mode=mode)
            fn_expected = function([x], pt.argmin(x, axis=axis), mode=mode)

            np.testing.assert_array_equal(
                fn_rewritten(test_val),
                fn_expected(test_val),
                err_msg=f"argmax({label}(x), axis={axis}) mismatch",
            )

            # Only check ArcCos was removed: argmin itself introduces a Neg
            # node, so Neg cannot be asserted absent.
            topo = fn_rewritten.maker.fgraph.toposort()
            assert not any(
                isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, ps.ArcCos)
                for n in topo
            ), f"rewrite failed to eliminate {label} from argmax graph"

    @pytest.mark.parametrize("axis", [None, 0, -1])
    def test_argmin_decreasing_functions(self, axis):
        """argmin(f_dec(x)) -> argmax(x) for monotonically decreasing f."""
        x = pt.vector("x")
        # Values kept in (0, 1) so arccos is within its domain.
        test_val = np.array([0.1, 0.3, 0.2, 0.5, 0.4])
        mode = get_default_mode()

        for f in [pt.arccos]:
            label = getattr(f, "__name__", str(f))

            fn_rewritten = function([x], pt.argmin(f(x), axis=axis), mode=mode)
            fn_expected = function([x], pt.argmax(x, axis=axis), mode=mode)

            np.testing.assert_array_equal(
                fn_rewritten(test_val),
                fn_expected(test_val),
                err_msg=f"argmin({label}(x), axis={axis}) mismatch",
            )

            topo = fn_rewritten.maker.fgraph.toposort()
            assert not any(
                isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, ps.ArcCos)
                for n in topo
            ), f"rewrite failed to eliminate {label} from argmin graph"
@register_canonicalize
@node_rewriter([Max, Min])
def local_max_min_monotonic(fgraph, node):
    """Hoist a monotonic elementwise function out of a max/min reduction.

    - ``max(f_inc(x)) -> f_inc(max(x))`` and ``min(f_inc(x)) -> f_inc(min(x))``
      for monotonically increasing ``f_inc``;
    - ``max(f_dec(x)) -> f_dec(min(x))`` and ``min(f_dec(x)) -> f_dec(max(x))``
      for monotonically decreasing ``f_dec``.

    The elementwise op is then evaluated once on the reduced value instead
    of once per element.

    NOTE(review): for restricted-domain functions (log, sqrt, arccos, ...)
    the rewritten graph can differ on out-of-domain inputs, where the
    original graph would propagate NaNs — TODO confirm acceptable.
    """
    # node.op is guaranteed to be Max or Min by @node_rewriter([Max, Min]).
    reduced_input = node.inputs[0]
    if reduced_input.owner is None:
        return False

    inner_op = reduced_input.owner.op
    if not isinstance(inner_op, Elemwise):
        return False

    scalar_op = inner_op.scalar_op
    is_increasing = isinstance(scalar_op, MONOTONIC_INCREASING)
    is_decreasing = isinstance(scalar_op, MONOTONIC_DECREASING)
    if not (is_increasing or is_decreasing):
        return False

    x = reduced_input.owner.inputs[0]

    # An increasing function keeps the reduction kind; a decreasing one
    # swaps Max <-> Min (preserving the reduction axis).
    if is_increasing:
        reduced = node.op.make_node(x).outputs[0]
    elif isinstance(node.op, Max):
        reduced = Min(axis=node.op.axis)(x)
    else:
        reduced = Max(axis=node.op.axis)(x)

    # Re-apply the monotonic function once, to the reduced result.
    new_output = inner_op.make_node(reduced).outputs[0]

    copy_stack_trace(node.outputs[0], new_output)
    return [new_output]
fn_unrewritten.maker.fgraph.toposort() + # The outer operation should be the monotonic function + assert isinstance(topo[-1].op, Elemwise) + assert isinstance(topo[-1].op.scalar_op, type(f(pt.scalar()).owner.op.scalar_op)) + + @pytest.mark.parametrize("axis", [None, 0, -1]) + def test_min_increasing_functions(self, axis): + """Test min(f_inc(x)) -> f_inc(min(x)) for monotonic increasing f.""" + x = pt.vector("x") + test_val = np.array([1.0, 3.0, 2.0, 5.0, 4.0]) + + mode = get_default_mode() + + for f in [pt.exp, pt.log1p, pt.sqrt]: + unrewritten = pt.min(f(x), axis=axis) + expected = f(pt.min(x, axis=axis)) + + fn_unrewritten = function([x], unrewritten, mode=mode) + fn_expected = function([x], expected, mode=mode) + + result_unrewritten = fn_unrewritten(test_val) + result_expected = fn_expected(test_val) + + np.testing.assert_allclose(result_unrewritten, result_expected) + + topo = fn_unrewritten.maker.fgraph.toposort() + assert isinstance(topo[-1].op, Elemwise) + assert isinstance(topo[-1].op.scalar_op, type(f(pt.scalar()).owner.op.scalar_op)) + + @pytest.mark.parametrize("axis", [None, 0, -1]) + def test_max_decreasing_functions(self, axis): + """Test max(f_dec(x)) -> f_dec(min(x)) for monotonic decreasing f.""" + x = pt.vector("x") + test_val = np.array([0.1, 0.3, 0.2, 0.5, 0.4]) + + mode = get_default_mode() + + for f in [pt.arccos]: + unrewritten = pt.max(f(x), axis=axis) + expected = f(pt.min(x, axis=axis)) + + fn_unrewritten = function([x], unrewritten, mode=mode) + fn_expected = function([x], expected, mode=mode) + + result_unrewritten = fn_unrewritten(test_val) + result_expected = fn_expected(test_val) + + np.testing.assert_allclose(result_unrewritten, result_expected) + + topo = fn_unrewritten.maker.fgraph.toposort() + assert isinstance(topo[-1].op, Elemwise) + assert isinstance(topo[-1].op.scalar_op, type(f(pt.scalar()).owner.op.scalar_op)) + + @pytest.mark.parametrize("axis", [None, 0, -1]) + def test_min_decreasing_functions(self, axis): + """Test 
min(f_dec(x)) -> f_dec(max(x)) for monotonic decreasing f.""" + x = pt.vector("x") + test_val = np.array([0.1, 0.3, 0.2, 0.5, 0.4]) + + mode = get_default_mode() + + for f in [pt.arccos]: + unrewritten = pt.min(f(x), axis=axis) + expected = f(pt.max(x, axis=axis)) + + fn_unrewritten = function([x], unrewritten, mode=mode) + fn_expected = function([x], expected, mode=mode) + + result_unrewritten = fn_unrewritten(test_val) + result_expected = fn_expected(test_val) + + np.testing.assert_allclose(result_unrewritten, result_expected) + + topo = fn_unrewritten.maker.fgraph.toposort() + assert isinstance(topo[-1].op, Elemwise) + assert isinstance(topo[-1].op.scalar_op, type(f(pt.scalar()).owner.op.scalar_op)) \ No newline at end of file