diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index a9bc7b15cb..92b30926f1 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -41,12 +41,16 @@
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast
 from pytensor.tensor.math import (
+    Argmax,
     Dot,
+    Max,
+    Min,
     Prod,
     Sum,
     _conj,
     _dot,
     _matmul,
+    argmin,
     add,
     arccosh,
     arcsinh,
@@ -121,6 +125,14 @@
     TensorVariable,
 )
 
+MONOTONIC_INCREASING = (
+    ps.Exp, ps.Exp2, ps.Expm1, ps.Log, ps.Log2, ps.Log10, ps.Log1p,
+    ps.Sqrt, ps.Deg2Rad, ps.Rad2Deg, ps.ArcSin, ps.ArcTan,
+    ps.ArcCosh, ps.Sinh, ps.ArcSinh, ps.Tanh, ps.ArcTanh,
+)
+
+MONOTONIC_DECREASING = (ps.Neg, ps.ArcCos)
+
 
 def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
     """Partition a list of variables into two kinds:
@@ -3885,3 +3897,126 @@ def local_useless_conj(fgraph, node):
 )
 
 register_stabilize(local_log_kv)
+
+
+def _is_argmin(node):
+    """Check whether ``node`` computes an argmin, i.e. ``Argmax(Neg(x))``.
+
+    argmin(x) is canonicalized to Argmax(Neg(x)); for unsigned-integer and
+    bool inputs it becomes Argmax(imax - x) or Argmax(~x), which this helper
+    does not handle.
+    """
+    if not isinstance(node.op, Argmax):
+        return False
+
+    input_node = node.inputs[0]
+    if not input_node.owner:
+        return False
+
+    inner_op = input_node.owner.op
+    return isinstance(inner_op, Elemwise) and isinstance(inner_op.scalar_op, ps.Neg)
+
+
+@register_canonicalize
+@node_rewriter([Argmax])
+def local_argmax_argmin_monotonic(fgraph, node):
+    """
+    Optimize argmax/argmin of monotonic functions:
+
+    - argmax(f_inc(x)) -> argmax(x) for monotonically increasing f
+    - argmin(f_inc(x)) -> argmin(x) for monotonically increasing f
+    - argmax(f_dec(x)) -> argmin(x) for monotonically decreasing f
+    - argmin(f_dec(x)) -> argmax(x) for monotonically decreasing f
+
+    Note: argmin is represented as Argmax(Neg(...)) internally.
+    """
+    is_argmin = _is_argmin(node)
+    argmax_input = node.inputs[0]
+
+    # For argmin, skip the Neg wrapper to reach the monotonic function.
+    # _is_argmin already guarantees the Neg node exists.
+    if is_argmin:
+        argmax_input = argmax_input.owner.inputs[0]
+
+    if not argmax_input.owner:
+        return False
+
+    inner_op = argmax_input.owner.op
+    if not isinstance(inner_op, Elemwise):
+        return False
+
+    scalar_op = inner_op.scalar_op
+    is_increasing = isinstance(scalar_op, MONOTONIC_INCREASING)
+    is_decreasing = isinstance(scalar_op, MONOTONIC_DECREASING)
+    if not (is_increasing or is_decreasing):
+        return False
+
+    x = argmax_input.owner.inputs[0]
+
+    # An increasing function preserves the location of the extremum;
+    # a decreasing function swaps argmax and argmin.
+    if is_argmin:
+        if is_increasing:
+            # argmin(f_inc(x)) -> argmin(x) = Argmax(Neg(x))
+            new_output = argmin(x, axis=node.op.axis)
+        else:
+            # argmin(f_dec(x)) -> argmax(x)
+            new_output = node.op(x)
+    else:
+        if is_increasing:
+            # argmax(f_inc(x)) -> argmax(x)
+            new_output = node.op(x)
+        else:
+            # argmax(f_dec(x)) -> argmin(x) = Argmax(Neg(x))
+            new_output = argmin(x, axis=node.op.axis)
+
+    copy_stack_trace(node.outputs[0], new_output)
+    return [new_output]
+
+
+@register_canonicalize
+@node_rewriter([Max, Min])
+def local_max_min_monotonic(fgraph, node):
+    """
+    Optimize max/min of monotonic functions by moving the function outside:
+
+    - max(f_inc(x)) -> f_inc(max(x)) for monotonically increasing f
+    - min(f_inc(x)) -> f_inc(min(x)) for monotonically increasing f
+    - max(f_dec(x)) -> f_dec(min(x)) for monotonically decreasing f
+    - min(f_dec(x)) -> f_dec(max(x)) for monotonically decreasing f
+    """
+    is_max = isinstance(node.op, Max)
+    input_arg = node.inputs[0]
+
+    if not input_arg.owner:
+        return False
+
+    inner_op = input_arg.owner.op
+    if not isinstance(inner_op, Elemwise):
+        return False
+
+    scalar_op = inner_op.scalar_op
+    is_increasing = isinstance(scalar_op, MONOTONIC_INCREASING)
+    is_decreasing = isinstance(scalar_op, MONOTONIC_DECREASING)
+    if not (is_increasing or is_decreasing):
+        return False
+
+    x = input_arg.owner.inputs[0]
+
+    # An increasing function commutes with the reduction; a decreasing
+    # function turns max into min and vice versa.
+    if is_increasing:
+        # max(f_inc(x)) -> f_inc(max(x)); min(f_inc(x)) -> f_inc(min(x))
+        inner_result = node.op.make_node(x).outputs[0]
+    elif is_max:
+        # max(f_dec(x)) -> f_dec(min(x))
+        inner_result = Min(axis=node.op.axis)(x)
+    else:
+        # min(f_dec(x)) -> f_dec(max(x))
+        inner_result = Max(axis=node.op.axis)(x)
+
+    # Apply the monotonic function to the reduced result
+    new_output = inner_op.make_node(inner_result).outputs[0]
+
+    copy_stack_trace(node.outputs[0], new_output)
+    return [new_output]
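A minimal usage sketch of the new canonicalization (illustrative only: the variable names and test values are hypothetical, and the graph check assumes the rewrite is picked up by the default compilation mode):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    # argmax(exp(x)) should canonicalize to a bare argmax on x
    fn = pytensor.function([x], pt.argmax(pt.exp(x)))
    pytensor.dprint(fn)  # the Exp node should no longer appear in the compiled graph

    xs = np.array([1.0, 3.0, 2.0])
    np.testing.assert_array_equal(fn(xs), np.argmax(np.exp(xs)))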
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index bf2160aaf1..99b8d487f4 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -5022,3 +5022,242 @@ def test_benchmark(self, benchmark, size, rewrite):
             c_val,
             d_val,
         )
+
+class TestArgmaxArgminMaxMinMonotonic:
+    """Test the argmax/argmin and max/min rewrites for monotonic functions."""
+
+    @pytest.mark.parametrize("axis", [None, 0, -1])
+    def test_argmax_increasing_functions(self, axis):
+        """Test argmax(f_inc(x)) -> argmax(x) for monotonically increasing f."""
+        x = pt.vector("x")
+        test_val = np.array([1.0, 3.0, 2.0, 5.0, 4.0])
+
+        mode = get_default_mode()
+
+        for f in [pt.exp, pt.log1p, pt.sqrt]:
+            # Compile the unrewritten and expected graphs
+            unrewritten = pt.argmax(f(x), axis=axis)
+            expected = pt.argmax(x, axis=axis)
+
+            # Create functions to apply rewrites
+            fn_unrewritten = function([x], unrewritten, mode=mode)
+            fn_expected = function([x], expected, mode=mode)
+
+            # Test numerical equality
+            result_unrewritten = fn_unrewritten(test_val)
+            result_expected = fn_expected(test_val)
+
+            assert result_unrewritten == result_expected, (
+                f"argmax({f}(x), axis={axis}) failed: "
+                f"got {result_unrewritten}, expected {result_expected}"
+            )
+
+            # Verify the rewrite was applied (no Exp/Log1p/Sqrt in final graph)
+            topo = fn_unrewritten.maker.fgraph.toposort()
+            has_eliminated_op = any(
+                isinstance(node.op, Elemwise) and
+                isinstance(node.op.scalar_op, (ps.Exp, ps.Log1p, ps.Sqrt))
+                for node in topo
+            )
+            assert not has_eliminated_op, (
+                f"Rewrite failed to eliminate {f} from argmax graph"
+            )
+
+    @pytest.mark.parametrize("axis", [None, 0, -1])
+    def test_argmin_increasing_functions(self, axis):
+        """Test argmin(f_inc(x)) -> argmin(x) for monotonically increasing f."""
+        x = pt.vector("x")
+        test_val = np.array([1.0, 3.0, 2.0, 5.0, 4.0])
+
+        mode = get_default_mode()
+
+        for f in [pt.exp, pt.log1p, pt.sqrt]:
+            unrewritten = pt.argmin(f(x), axis=axis)
+            expected = pt.argmin(x, axis=axis)
+
+            fn_unrewritten = function([x], unrewritten, mode=mode)
+            fn_expected = function([x], expected, mode=mode)
+
+            result_unrewritten = fn_unrewritten(test_val)
+            result_expected = fn_expected(test_val)
+
+            assert result_unrewritten == result_expected, (
+                f"argmin({f}(x), axis={axis}) failed: "
+                f"got {result_unrewritten}, expected {result_expected}"
+            )
+
+            topo = fn_unrewritten.maker.fgraph.toposort()
+            has_eliminated_op = any(
+                isinstance(node.op, Elemwise) and
+                isinstance(node.op.scalar_op, (ps.Exp, ps.Log1p, ps.Sqrt))
+                for node in topo
+            )
+            assert not has_eliminated_op, (
+                f"Rewrite failed to eliminate {f} from argmin graph"
+            )
+
+    @pytest.mark.parametrize("axis", [None, 0, -1])
+    def test_argmax_decreasing_functions(self, axis):
+        """Test argmax(f_dec(x)) -> argmin(x) for monotonically decreasing f."""
+        x = pt.vector("x")
+        test_val = np.array([0.1, 0.3, 0.2, 0.5, 0.4])  # Values in (0, 1) for arccos
+
+        mode = get_default_mode()
+
+        for f in [pt.arccos]:
+            unrewritten = pt.argmax(f(x), axis=axis)
+            expected = pt.argmin(x, axis=axis)
+
+            fn_unrewritten = function([x], unrewritten, mode=mode)
+            fn_expected = function([x], expected, mode=mode)
+
+            result_unrewritten = fn_unrewritten(test_val)
+            result_expected = fn_expected(test_val)
+
+            assert result_unrewritten == result_expected, (
+                f"argmax({f}(x), axis={axis}) failed: "
+                f"got {result_unrewritten}, expected {result_expected}"
+            )
+
+            # Verify the rewrite was applied (no ArcCos in final graph)
+            topo = fn_unrewritten.maker.fgraph.toposort()
+            has_eliminated_op = any(
+                isinstance(node.op, Elemwise) and
+                isinstance(node.op.scalar_op, ps.ArcCos)  # Neg not checked: the rewritten argmin itself contains one
+                for node in topo
+            )
+            assert not has_eliminated_op, (
+                "Rewrite failed to eliminate arccos from argmax graph"
+            )
+
+    @pytest.mark.parametrize("axis", [None, 0, -1])
+    def test_argmin_decreasing_functions(self, axis):
+        """Test argmin(f_dec(x)) -> argmax(x) for monotonically decreasing f."""
+        x = pt.vector("x")
+        test_val = np.array([0.1, 0.3, 0.2, 0.5, 0.4])  # Values in (0, 1) for arccos
+
+        mode = get_default_mode()
+
+        for f in [pt.arccos]:
+            unrewritten = pt.argmin(f(x), axis=axis)
+            expected = pt.argmax(x, axis=axis)
+
+            fn_unrewritten = function([x], unrewritten, mode=mode)
+            fn_expected = function([x], expected, mode=mode)
+
+            result_unrewritten = fn_unrewritten(test_val)
+            result_expected = fn_expected(test_val)
+
+            assert result_unrewritten == result_expected, (
+                f"argmin({f}(x), axis={axis}) failed: "
+                f"got {result_unrewritten}, expected {result_expected}"
+            )
+
+            # Verify the rewrite was applied (no ArcCos in final graph)
+            topo = fn_unrewritten.maker.fgraph.toposort()
+            has_eliminated_op = any(
+                isinstance(node.op, Elemwise) and
+                isinstance(node.op.scalar_op, ps.ArcCos)
+                for node in topo
+            )
+            assert not has_eliminated_op, (
+                "Rewrite failed to eliminate arccos from argmin graph"
+            )
+
+    @pytest.mark.parametrize("axis", [None, 0, -1])
+    def test_max_increasing_functions(self, axis):
+        """Test max(f_inc(x)) -> f_inc(max(x)) for monotonically increasing f."""
+        x = pt.vector("x")
+        test_val = np.array([1.0, 3.0, 2.0, 5.0, 4.0])
+
+        mode = get_default_mode()
+
+        for f in [pt.exp, pt.log1p, pt.sqrt]:
+            unrewritten = pt.max(f(x), axis=axis)
+            expected = f(pt.max(x, axis=axis))
+
+            fn_unrewritten = function([x], unrewritten, mode=mode)
+            fn_expected = function([x], expected, mode=mode)
+
+            result_unrewritten = fn_unrewritten(test_val)
+            result_expected = fn_expected(test_val)
+
+            np.testing.assert_allclose(result_unrewritten, result_expected)
+
+            # Verify rewrite structure: f should wrap max, not max wrap f
+            topo = fn_unrewritten.maker.fgraph.toposort()
+            # The outer operation should be the monotonic function
+            assert isinstance(topo[-1].op, Elemwise)
+            assert isinstance(topo[-1].op.scalar_op, type(f(pt.scalar()).owner.op.scalar_op))
+
+    @pytest.mark.parametrize("axis", [None, 0, -1])
+    def test_min_increasing_functions(self, axis):
+        """Test min(f_inc(x)) -> f_inc(min(x)) for monotonically increasing f."""
+        x = pt.vector("x")
+        test_val = np.array([1.0, 3.0, 2.0, 5.0, 4.0])
+
+        mode = get_default_mode()
+
+        for f in [pt.exp, pt.log1p, pt.sqrt]:
+            unrewritten = pt.min(f(x), axis=axis)
+            expected = f(pt.min(x, axis=axis))
+
+            fn_unrewritten = function([x], unrewritten, mode=mode)
+            fn_expected = function([x], expected, mode=mode)
+
+            result_unrewritten = fn_unrewritten(test_val)
+            result_expected = fn_expected(test_val)
+
+            np.testing.assert_allclose(result_unrewritten, result_expected)
+
+            topo = fn_unrewritten.maker.fgraph.toposort()
+            assert isinstance(topo[-1].op, Elemwise)
+            assert isinstance(topo[-1].op.scalar_op, type(f(pt.scalar()).owner.op.scalar_op))
+
+    @pytest.mark.parametrize("axis", [None, 0, -1])
+    def test_max_decreasing_functions(self, axis):
+        """Test max(f_dec(x)) -> f_dec(min(x)) for monotonically decreasing f."""
+        x = pt.vector("x")
+        test_val = np.array([0.1, 0.3, 0.2, 0.5, 0.4])
+
+        mode = get_default_mode()
+
+        for f in [pt.arccos]:
+            unrewritten = pt.max(f(x), axis=axis)
+            expected = f(pt.min(x, axis=axis))
+
+            fn_unrewritten = function([x], unrewritten, mode=mode)
+            fn_expected = function([x], expected, mode=mode)
+
+            result_unrewritten = fn_unrewritten(test_val)
+            result_expected = fn_expected(test_val)
+
+            np.testing.assert_allclose(result_unrewritten, result_expected)
+
+            topo = fn_unrewritten.maker.fgraph.toposort()
+            assert isinstance(topo[-1].op, Elemwise)
+            assert isinstance(topo[-1].op.scalar_op, type(f(pt.scalar()).owner.op.scalar_op))
+
+    @pytest.mark.parametrize("axis", [None, 0, -1])
+    def test_min_decreasing_functions(self, axis):
+        """Test min(f_dec(x)) -> f_dec(max(x)) for monotonically decreasing f."""
+        x = pt.vector("x")
+        test_val = np.array([0.1, 0.3, 0.2, 0.5, 0.4])
+
+        mode = get_default_mode()
+
+        for f in [pt.arccos]:
+            unrewritten = pt.min(f(x), axis=axis)
+            expected = f(pt.max(x, axis=axis))
+
+            fn_unrewritten = function([x], unrewritten, mode=mode)
+            fn_expected = function([x], expected, mode=mode)
+
+            result_unrewritten = fn_unrewritten(test_val)
+            result_expected = fn_expected(test_val)
+
+            np.testing.assert_allclose(result_unrewritten, result_expected)
+
+            topo = fn_unrewritten.maker.fgraph.toposort()
+            assert isinstance(topo[-1].op, Elemwise)
+            assert isinstance(topo[-1].op.scalar_op, type(f(pt.scalar()).owner.op.scalar_op))
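A matching sketch for the max/min rewrite, under the same assumptions as the sketch above: arccos is decreasing, so the maximum of arccos(x) is arccos of the minimum of x, and the compiled graph should reduce first and then apply arccos to the reduced result:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    # max(arccos(x)) should become arccos(min(x)): reduce first, apply arccos last
    fn = pytensor.function([x], pt.max(pt.arccos(x)))

    xs = np.array([0.1, 0.3, 0.2])
    np.testing.assert_allclose(fn(xs), np.arccos(xs.min()))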