Add rewrite for argmax/argmin of monotonic functions#1869
Open
Jasjeet-Singh-S wants to merge 3 commits intopymc-devs:mainfrom
Open
Add rewrite for argmax/argmin of monotonic functions#1869Jasjeet-Singh-S wants to merge 3 commits intopymc-devs:mainfrom
Jasjeet-Singh-S wants to merge 3 commits intopymc-devs:mainfrom
Conversation
Implements graph rewrite that eliminates redundant monotonic function applications in argmax/argmin operations. For monotonically increasing functions, rewrites argmax(f(x)) → argmax(x) and argmin(f(x)) → argmin(x). For decreasing functions, flips operations: argmax(f(x)) → argmin(x) and argmin(f(x)) → argmax(x). Includes comprehensive tests.
There was a problem hiding this comment.
Pull request overview
This PR adds a graph rewrite optimization that eliminates unnecessary function evaluations when computing argmax or argmin of monotonic functions. The optimization leverages the property that monotonic functions preserve ordering, so argmax(exp(x)) can be simplified to argmax(x).
Changes:
- Adds
MONOTONIC_INCREASINGandMONOTONIC_DECREASINGtuples to classify scalar operations by monotonicity - Implements
local_argmax_argmin_monotonicrewriter that optimizes argmax/argmin of monotonic functions - Adds comprehensive test suite with parametrized tests for different axis values
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
pytensor/tensor/rewriting/math.py |
Adds monotonic function classifications and implements the core rewrite logic for argmax/argmin optimization |
tests/tensor/rewriting/test_math.py |
Adds test class with parametrized tests for increasing and decreasing monotonic functions |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Add rewrite for argmax/argmin/max/min of monotonic functions
Closes #1851
Summary
This PR implements graph rewrites that optimize
argmax/argmin/max/minoperations applied to monotonic functions by eliminating unnecessary function evaluations.Motivation
Computing
argmax(exp(x))ormax(exp(x))is wasteful because the exponential computation doesn't affect which index has the maximum value or the relative ordering - we only care about the ordering relationship. Since monotonic functions preserve ordering, we can skip expensive function applications entirely for argmax/argmin, or move them outside the reduction for max/min.Implementation
New rewrites:
local_argmax_argmin_monotonic- for argmax/argmin operationslocal_max_min_monotonic- for max/min operationsArgmax/Argmin Rewrites
The rewrite handles four transformation paths based on function monotonicity:
Monotonically Increasing Functions
argmax(f(x)) → argmax(x)argmin(f(x)) → argmin(x)Supported increasing functions:
Exp,Exp2,Expm1,Log,Log2,Log10,Log1p,Sqrt,Deg2Rad,Rad2Deg,ArcSin,Tan,ArcTan,ArcCosh,Sinh,ArcSinh,Tanh,ArcTanhMonotonically Decreasing Functions
argmax(f(x)) → argmin(x)argmin(f(x)) → argmax(x)Supported decreasing functions:
Neg,Reciprocal,ArcCosMax/Min Rewrites
The rewrite handles four transformation paths by moving the monotonic function outside the reduction:
Monotonically Increasing Functions
max(f(x)) → f(max(x))min(f(x)) → f(min(x))Monotonically Decreasing Functions
max(f(x)) → f(min(x))min(f(x)) → f(max(x))Same supported increasing and decreasing functions as argmax/argmin.
Key Features
argminwhich is internally represented asArgmax(Neg(...))in PyTensorNone,0,-1, etc.)Elemwisewrapper detection to identify scalar operationscopy_stack_traceChanges
pytensor/tensor/rewriting/math.pyMONOTONIC_INCREASINGtuple containing 18 monotonically increasing scalar operationsMONOTONIC_DECREASINGtuple containing 3 monotonically decreasing scalar operations_is_argmin()helper function to detect argmin patterns (handlesArgmax(Neg(...))representation)local_argmax_argmin_monotonic()rewriter for argmax/argmin with@register_canonicalizedecoratorlocal_max_min_monotonic()rewriter for max/min with@register_canonicalizedecoratortests/tensor/rewriting/test_math.pyTestArgmaxArgminMonotonictest class with comprehensive coverage:test_argmax_increasing_functions- Tests argmax rewrite for increasing functionstest_argmin_increasing_functions- Tests argmin rewrite for increasing functionstest_argmax_decreasing_functions- Tests argmax rewrite for decreasing functions (flips to argmin)test_argmin_decreasing_functions- Tests argmin rewrite for decreasing functions (flips to argmax)test_max_increasing_functions- Tests max rewrite for increasing functionstest_min_increasing_functions- Tests min rewrite for increasing functionstest_max_decreasing_functions- Tests max rewrite for decreasing functionstest_min_decreasing_functions- Tests min rewrite for decreasing functionsNone,0,-1)arccosinstead ofnegfor decreasing function tests to avoid confusion with argmin's internal representationExample
Argmax/Argmin Example
Max/Min Example
Performance Impact
These rewrites provide significant speedups when:
Testing
All 24 tests pass with various configurations:
None,0,-1)The rewrites correctly handle edge cases including:
Argmax(Neg(...))representation forargmin