Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,16 @@
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast
from pytensor.tensor.math import (
Argmax,
Max,
Min,
Dot,
Prod,
Sum,
_conj,
_dot,
_matmul,
argmin,
add,
arccosh,
arcsinh,
Expand Down Expand Up @@ -121,6 +125,14 @@
TensorVariable,
)

MONOTONIC_INCREASING = (
ps.Exp, ps.Exp2, ps.Expm1, ps.Log, ps.Log2, ps.Log10, ps.Log1p,
ps.Sqrt, ps.Deg2Rad, ps.Rad2Deg, ps.ArcSin, ps.ArcTan,
ps.ArcCosh, ps.Sinh, ps.ArcSinh, ps.Tanh, ps.ArcTanh
)

MONOTONIC_DECREASING = (ps.Neg, ps.ArcCos)


def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
"""Partition a list of variables into two kinds:
Expand Down Expand Up @@ -3885,3 +3897,137 @@ def local_useless_conj(fgraph, node):
)

register_stabilize(local_log_kv)

def _is_argmin(node):
"""Check if node represents argmin by detecting Argmax(Neg(...))"""
if not isinstance(node.op, Argmax):
return False

input_node = node.inputs[0]
if not input_node.owner:
return False

# argmin(x) becomes Argmax(Neg(x)) or Argmax(imax - x) or Argmax(~x)
inner_op = input_node.owner.op
if isinstance(inner_op, Elemwise) and isinstance(inner_op.scalar_op, ps.Neg):
return True

return False

@register_canonicalize
@node_rewriter([Argmax])
def local_argmax_argmin_monotonic(fgraph, node):
"""
Optimize argmax/argmin with monotonic functions:
- argmax(f_inc(x)) -> argmax(x) for monotonically increasing f
- argmin(f_inc(x)) -> argmin(x) for monotonically increasing f
- argmax(f_dec(x)) -> argmin(x) for monotonically decreasing f
- argmin(f_dec(x)) -> argmax(x) for monotonically decreasing f

Note: argmin is represented as Argmax(Neg(...)) internally
"""

if not isinstance(node.op, Argmax):
return False

is_argmin = _is_argmin(node)
argmax_input = node.inputs[0]

# If argmin, skip the Neg wrapper to get to the monotonic function
if is_argmin:
if not argmax_input.owner:
return False
argmax_input = argmax_input.owner.inputs[0] # Skip Neg

if not argmax_input.owner:
return False

inner_op = argmax_input.owner.op

if not isinstance(inner_op, Elemwise):
return False

scalar_op = inner_op.scalar_op

is_increasing = isinstance(scalar_op, MONOTONIC_INCREASING)
is_decreasing = isinstance(scalar_op, MONOTONIC_DECREASING)

if not (is_increasing or is_decreasing):
return False

x = argmax_input.owner.inputs[0]

# Determine new operation based on current op and monotonicity
if is_argmin:
if is_increasing:
# argmin(f_inc(x)) -> argmin(x) = Argmax(Neg(x))
new_output = argmin(x, axis=node.op.axis)
else: # is_decreasing
# argmin(f_dec(x)) -> argmax(x)
new_output = node.op(x)
else: # is argmax
if is_increasing:
# argmax(f_inc(x)) -> argmax(x)
new_output = node.op(x)
else: # is_decreasing
# argmax(f_dec(x)) -> argmin(x) = Argmax(Neg(x))
new_output = argmin(x, axis=node.op.axis)

copy_stack_trace(node.outputs[0], new_output)
return [new_output]

@register_canonicalize
@node_rewriter([Max, Min])
def local_max_min_monotonic(fgraph, node):
"""
Optimize max/min with monotonic functions by moving the function outside:
- max(f_inc(x)) -> f_inc(max(x)) for monotonically increasing f
- min(f_inc(x)) -> f_inc(min(x)) for monotonically increasing f
- max(f_dec(x)) -> f_dec(min(x)) for monotonically decreasing f
- min(f_dec(x)) -> f_dec(max(x)) for monotonically decreasing f
"""
if not isinstance(node.op, (Max, Min)):
return False

is_max = isinstance(node.op, Max)
input_arg = node.inputs[0]

if not input_arg.owner:
return False

inner_op = input_arg.owner.op

if not isinstance(inner_op, Elemwise):
return False

scalar_op = inner_op.scalar_op

is_increasing = isinstance(scalar_op, MONOTONIC_INCREASING)
is_decreasing = isinstance(scalar_op, MONOTONIC_DECREASING)

if not (is_increasing or is_decreasing):
return False

x = input_arg.owner.inputs[0]

# Determine new operation based on current op and monotonicity
if is_max:
if is_increasing:
# max(f_inc(x)) -> f_inc(max(x))
inner_result = node.op.make_node(x).outputs[0]
else: # is_decreasing
# max(f_dec(x)) -> f_dec(min(x))
inner_result = Min(axis=node.op.axis)(x)
else: # is_min
if is_increasing:
# min(f_inc(x)) -> f_inc(min(x))
inner_result = node.op.make_node(x).outputs[0]
else: # is_decreasing
# min(f_dec(x)) -> f_dec(max(x))
inner_result = Max(axis=node.op.axis)(x)

# Apply the monotonic function to the result
new_output = inner_op.make_node(inner_result).outputs[0]

copy_stack_trace(node.outputs[0], new_output)
return [new_output]
Loading