diff --git a/tenacity/asyncio/retry.py b/tenacity/asyncio/retry.py index cbb332e3..e0d44e82 100644 --- a/tenacity/asyncio/retry.py +++ b/tenacity/asyncio/retry.py @@ -106,6 +106,13 @@ async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore break return result + def __ror__( # type: ignore[misc,override] + self, other: "retry_base | async_retry_base" + ) -> "retry_any": + if isinstance(other, retry_any): + return retry_any(*other.retries, *self.retries) + return retry_any(other, *self.retries) + class retry_all(async_retry_base): """Retries if all the retries condition are valid.""" @@ -120,3 +127,10 @@ async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore if not result: break return result + + def __rand__( # type: ignore[misc,override] + self, other: "retry_base | async_retry_base" + ) -> "retry_all": + if isinstance(other, retry_all): + return retry_all(*other.retries, *self.retries) + return retry_all(other, *self.retries) diff --git a/tenacity/retry.py b/tenacity/retry.py index 4edbd1f7..6e9fc3e8 100644 --- a/tenacity/retry.py +++ b/tenacity/retry.py @@ -32,17 +32,29 @@ def __call__(self, retry_state: "RetryCallState") -> bool: def __and__(self, other: "RetryBaseT") -> "retry_all": if isinstance(other, retry_base): return other.__rand__(self) + # Plain callable: flatten if self is already a retry_all + if isinstance(self, retry_all): + return retry_all(*self.retries, other) return retry_all(self, other) def __rand__(self, other: "RetryBaseT") -> "retry_all": + # Flatten if other is already a retry_all + if isinstance(other, retry_all): + return retry_all(*other.retries, self) return retry_all(other, self) def __or__(self, other: "RetryBaseT") -> "retry_any": if isinstance(other, retry_base): return other.__ror__(self) + # Plain callable: flatten if self is already a retry_any + if isinstance(self, retry_any): + return retry_any(*self.retries, other) return retry_any(self, other) def __ror__(self, other: "RetryBaseT") -> "retry_any": + # Flatten if other is already a retry_any + if isinstance(other, retry_any): + return retry_any(*other.retries, self) return retry_any(other, self) @@ -264,6 +276,11 @@ def __init__(self, *retries: "RetryBaseT") -> None: def __call__(self, retry_state: "RetryCallState") -> bool: return any(r(retry_state) for r in self.retries) + def __ror__(self, other: "RetryBaseT") -> "retry_any": + if isinstance(other, retry_any): + return retry_any(*other.retries, *self.retries) + return retry_any(other, *self.retries) + class retry_all(retry_base): """Retries if all the retries condition are valid.""" @@ -273,3 +290,8 @@ def __init__(self, *retries: "RetryBaseT") -> None: def __call__(self, retry_state: "RetryCallState") -> bool: return all(r(retry_state) for r in self.retries) + + def __rand__(self, other: "RetryBaseT") -> "retry_all": + if isinstance(other, retry_all): + return retry_all(*other.retries, *self.retries) + return retry_all(other, *self.retries) diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py index b72831eb..80060a36 100644 --- a/tests/test_tenacity.py +++ b/tests/test_tenacity.py @@ -31,6 +31,7 @@ import tenacity from tenacity import RetryCallState, RetryError, Retrying, retry +from tenacity.retry import retry_all, retry_any _unset = object() @@ -777,6 +778,26 @@ def my_retry(retry_state: tenacity.RetryCallState) -> bool: retry2 = my_retry & tenacity.retry_if_result(lambda x: x == 1) self.assertTrue(retry2(retry_state)) + def test_retry_or_coalesces(self) -> None: + """Multiple | operations flatten into a single retry_any.""" + a = tenacity.retry_if_exception_type(IOError) + b = tenacity.retry_if_exception_type(OSError) + c = tenacity.retry_if_exception_type(ValueError) + + combined = a | b | c + self.assertIsInstance(combined, retry_any) + self.assertEqual(len(combined.retries), 3) + + def test_retry_and_coalesces(self) -> None: + """Multiple & operations flatten into a single retry_all.""" + a = tenacity.retry_if_result(lambda x: x == 1) + b = tenacity.retry_if_result(lambda x: x > 0) + c = tenacity.retry_if_result(lambda x: x < 10) + + combined = a & b & c + self.assertIsInstance(combined, retry_all) + self.assertEqual(len(combined.retries), 3) + def _raise_try_again(self) -> None: self._attempts += 1 if self._attempts < 3: