From 952b9dfcf7a2547901de8604b572e8ac94262616 Mon Sep 17 00:00:00 2001 From: Julien Danjou Date: Tue, 24 Feb 2026 16:08:18 +0100 Subject: [PATCH] fix: restore plain function composition with retry_base via | and & The delegation to __ror__/__rand__ introduced for async retry support broke composing retry_base instances with plain callables. Add an isinstance guard so delegation only happens for retry_base subclasses, while plain callables fall through to direct retry_all/retry_any wrapping. Also widen type annotations on retry_any/retry_all and the operator methods to accept RetryBaseT (the union of retry_base and plain callables), matching the runtime behavior. Fixes #481 Co-Authored-By: Claude Opus 4.6 Change-Id: I24e081b67dfa7e184266621f73cec4de5177ea25 --- tenacity/retry.py | 20 ++++++++++++-------- tests/test_tenacity.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/tenacity/retry.py b/tenacity/retry.py index 593b252d..4edbd1f7 100644 --- a/tenacity/retry.py +++ b/tenacity/retry.py @@ -29,16 +29,20 @@ class retry_base(abc.ABC): def __call__(self, retry_state: "RetryCallState") -> bool: pass - def __and__(self, other: "retry_base") -> "retry_all": - return other.__rand__(self) + def __and__(self, other: "RetryBaseT") -> "retry_all": + if isinstance(other, retry_base): + return other.__rand__(self) + return retry_all(self, other) - def __rand__(self, other: "retry_base") -> "retry_all": + def __rand__(self, other: "RetryBaseT") -> "retry_all": return retry_all(other, self) - def __or__(self, other: "retry_base") -> "retry_any": - return other.__ror__(self) + def __or__(self, other: "RetryBaseT") -> "retry_any": + if isinstance(other, retry_base): + return other.__ror__(self) + return retry_any(self, other) - def __ror__(self, other: "retry_base") -> "retry_any": + def __ror__(self, other: "RetryBaseT") -> "retry_any": return retry_any(other, self) @@ -254,7 +258,7 @@ def __call__(self, retry_state: "RetryCallState") -> bool: class retry_any(retry_base): """Retries if any of the retries condition is valid.""" - def __init__(self, *retries: retry_base) -> None: + def __init__(self, *retries: "RetryBaseT") -> None: self.retries = retries def __call__(self, retry_state: "RetryCallState") -> bool: @@ -264,7 +268,7 @@ def __call__(self, retry_state: "RetryCallState") -> bool: class retry_all(retry_base): """Retries if all the retries condition are valid.""" - def __init__(self, *retries: retry_base) -> None: + def __init__(self, *retries: "RetryBaseT") -> None: self.retries = retries def __call__(self, retry_state: "RetryCallState") -> bool: diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py index 0645a835..5a90a467 100644 --- a/tests/test_tenacity.py +++ b/tests/test_tenacity.py @@ -743,6 +743,40 @@ def r(fut: tenacity.Future) -> bool: self.assertFalse(r(tenacity.Future.construct(1, 2.2, False))) self.assertFalse(r(tenacity.Future.construct(1, 42, True))) + def test_retry_or_with_plain_function(self) -> None: + """Plain callables can be composed with retry_base via |.""" + + def my_retry(retry_state: tenacity.RetryCallState) -> bool: + return retry_state.outcome is not None and not retry_state.outcome.failed + + # retry_base | plain_callable (exercises __or__ fallback) + retry = tenacity.retry_if_exception_type(Exception) | my_retry + retry_state = make_retry_state( + 1, 1.0, last_result=tenacity.Future.construct(1, "ok", False) + ) + self.assertTrue(retry(retry_state)) + + # plain_callable | retry_base (exercises __ror__ via reflection) + retry2 = my_retry | tenacity.retry_if_exception_type(Exception) + self.assertTrue(retry2(retry_state)) + + def test_retry_and_with_plain_function(self) -> None: + """Plain callables can be composed with retry_base via &.""" + + def my_retry(retry_state: tenacity.RetryCallState) -> bool: + return True + + # retry_base & plain_callable (exercises __and__ fallback) + retry = tenacity.retry_if_result(lambda x: x == 1) & my_retry + retry_state = make_retry_state( + 1, 1.0, last_result=tenacity.Future.construct(1, 1, False) + ) + self.assertTrue(retry(retry_state)) + + # plain_callable & retry_base (exercises __rand__ via reflection) + retry2 = my_retry & tenacity.retry_if_result(lambda x: x == 1) + self.assertTrue(retry2(retry_state)) + def _raise_try_again(self) -> None: self._attempts += 1 if self._attempts < 3: