diff --git a/tenacity/retry.py b/tenacity/retry.py index 593b252d..4edbd1f7 100644 --- a/tenacity/retry.py +++ b/tenacity/retry.py @@ -29,16 +29,20 @@ class retry_base(abc.ABC): def __call__(self, retry_state: "RetryCallState") -> bool: pass - def __and__(self, other: "retry_base") -> "retry_all": - return other.__rand__(self) + def __and__(self, other: "RetryBaseT") -> "retry_all": + if isinstance(other, retry_base): + return other.__rand__(self) + return retry_all(self, other) - def __rand__(self, other: "retry_base") -> "retry_all": + def __rand__(self, other: "RetryBaseT") -> "retry_all": return retry_all(other, self) - def __or__(self, other: "retry_base") -> "retry_any": - return other.__ror__(self) + def __or__(self, other: "RetryBaseT") -> "retry_any": + if isinstance(other, retry_base): + return other.__ror__(self) + return retry_any(self, other) - def __ror__(self, other: "retry_base") -> "retry_any": + def __ror__(self, other: "RetryBaseT") -> "retry_any": return retry_any(other, self) @@ -254,7 +258,7 @@ def __call__(self, retry_state: "RetryCallState") -> bool: class retry_any(retry_base): """Retries if any of the retries condition is valid.""" - def __init__(self, *retries: retry_base) -> None: + def __init__(self, *retries: "RetryBaseT") -> None: self.retries = retries def __call__(self, retry_state: "RetryCallState") -> bool: @@ -264,7 +268,7 @@ def __call__(self, retry_state: "RetryCallState") -> bool: class retry_all(retry_base): """Retries if all the retries condition are valid.""" - def __init__(self, *retries: retry_base) -> None: + def __init__(self, *retries: "RetryBaseT") -> None: self.retries = retries def __call__(self, retry_state: "RetryCallState") -> bool: diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py index 0645a835..5a90a467 100644 --- a/tests/test_tenacity.py +++ b/tests/test_tenacity.py @@ -743,6 +743,40 @@ def r(fut: tenacity.Future) -> bool: self.assertFalse(r(tenacity.Future.construct(1, 2.2, False))) self.assertFalse(r(tenacity.Future.construct(1, 42, True))) + def test_retry_or_with_plain_function(self) -> None: + """Plain callables can be composed with retry_base via |.""" + + def my_retry(retry_state: tenacity.RetryCallState) -> bool: + return retry_state.outcome is not None and not retry_state.outcome.failed + + # retry_base | plain_callable (exercises __or__ fallback) + retry = tenacity.retry_if_exception_type(Exception) | my_retry + retry_state = make_retry_state( + 1, 1.0, last_result=tenacity.Future.construct(1, "ok", False) + ) + self.assertTrue(retry(retry_state)) + + # plain_callable | retry_base (exercises __ror__ via reflection) + retry2 = my_retry | tenacity.retry_if_exception_type(Exception) + self.assertTrue(retry2(retry_state)) + + def test_retry_and_with_plain_function(self) -> None: + """Plain callables can be composed with retry_base via &.""" + + def my_retry(retry_state: tenacity.RetryCallState) -> bool: + return True + + # retry_base & plain_callable (exercises __and__ fallback) + retry = tenacity.retry_if_result(lambda x: x == 1) & my_retry + retry_state = make_retry_state( + 1, 1.0, last_result=tenacity.Future.construct(1, 1, False) + ) + self.assertTrue(retry(retry_state)) + + # plain_callable & retry_base (exercises __rand__ via reflection) + retry2 = my_retry & tenacity.retry_if_result(lambda x: x == 1) + self.assertTrue(retry2(retry_state)) + def _raise_try_again(self) -> None: self._attempts += 1 if self._attempts < 3: