From 1fcd9fc505343b9006e7fff940094d4c952700e1 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 11 Feb 2026 18:57:37 +0100 Subject: [PATCH] Stabilize logdiffexp_to_log1mexpdiff -inf, -inf case --- pytensor/tensor/rewriting/math.py | 7 +++++-- tests/tensor/rewriting/test_math.py | 7 ++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index a9bc7b15cb..a7ef88c715 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -55,6 +55,7 @@ deg2rad, digamma, dot, + eq, erf, erfc, exp, @@ -3812,12 +3813,14 @@ def logmexpm1_to_log1mexp(fgraph, node): # log(exp(a) - exp(b)) -> a + log1mexp(b - a) +# special care is taken for a == b == -inf, by wrapping -> switch(b == -inf, a, ...) logdiffexp_to_log1mexpdiff = PatternNodeRewriter( (log, (sub, (exp, "x"), (exp, "y"))), - (add, "x", (log1mexp, (sub, "y", "x"))), + (switch, (eq, "y", -np.inf), "x", (add, "x", (log1mexp, (sub, "y", "x")))), allow_multiple_clients=True, + name="logdiffexp_to_log1mexpdiff", ) -register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff") +register_stabilize(logdiffexp_to_log1mexpdiff) # log(sigmoid(x) / (1 - sigmoid(x))) -> x # i.e logit(sigmoid(x)) -> x diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index bf2160aaf1..9c7319fda3 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4581,7 +4581,7 @@ def test_log1mexp_stabilization(op_name): ) -def test_logdiffexp(): +def test_logdiffexp_stabilization(): rng = np.random.default_rng(3559) mode = Mode("py").including("stabilize").excluding("fusion") @@ -4618,6 +4618,11 @@ def test_logdiffexp(): np.testing.assert_almost_equal( f(x_test, y_test), np.log(np.exp(x_test) - np.exp(y_test)) ) + # Test edge cases + np.testing.assert_array_equal( + f([[-np.inf, -np.inf, -1]], [[-1, -np.inf, -np.inf]]), + [[np.nan, -np.inf, -1]], + ) def test_polygamma_specialization():