From 8601f7688d231ef78a5bd43bd2a93593a7b34a16 Mon Sep 17 00:00:00 2001 From: Julien Danjou Date: Tue, 24 Feb 2026 17:40:51 +0100 Subject: [PATCH] feat: support retrying generator and async generator functions Fixes #63 Add retry support for sync and async generator functions. When a generator decorated with @retry raises an exception, the retry logic kicks in and re-calls the generator function, continuing to yield values from the new generator. - Add is_generator_callable() and is_async_gen_callable() detection utilities in _utils.py - Add sync generator wrapper in BaseRetrying.wraps() - Add async generator wrapper in AsyncRetrying.wraps() - Route async generators to AsyncRetrying in retry() Co-Authored-By: Claude Opus 4.6 Change-Id: Ie49981e21ff7cae602b69e63c567d007d8b78a65 --- tenacity/__init__.py | 46 ++++++++++++++-- tenacity/_utils.py | 20 +++++++ tenacity/asyncio/__init__.py | 47 ++++++++++++++-- tests/test_asyncio.py | 88 ++++++++++++++++++++++++++++++ tests/test_tenacity.py | 103 +++++++++++++++++++++++++++++++++++ 5 files changed, 294 insertions(+), 10 deletions(-) diff --git a/tenacity/__init__.py b/tenacity/__init__.py index 591f9703..32ae11d2 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -356,15 +356,47 @@ def wrapped_f(*args: t.Any, **kw: t.Any) -> t.Any: wrapped_f.statistics = copy.statistics # type: ignore[attr-defined] return copy(f, *args, **kw) + @functools.wraps( + f, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__") + ) + def wrapped_gen_f( + *args: t.Any, **kw: t.Any + ) -> t.Generator[t.Any, t.Any, t.Any]: + if not self.enabled: + yield from f(*args, **kw) + return + copy = self.copy() + wrapped_gen_f.statistics = copy.statistics # type: ignore[attr-defined] + copy.begin() + retry_state = RetryCallState(retry_object=copy, fn=f, args=args, kwargs=kw) + while True: + do = copy.iter(retry_state=retry_state) + if isinstance(do, DoAttempt): + try: + result = yield from f(*args, **kw) + except GeneratorExit: + raise + except BaseException: + retry_state.set_exception(sys.exc_info()) # type: ignore[arg-type] + else: + retry_state.set_result(result) + elif isinstance(do, DoSleep): + retry_state.prepare_for_next_attempt() + copy.sleep(do) + else: + return do + + result_f = wrapped_gen_f if _utils.is_generator_callable(f) else wrapped_f + def retry_with(*args: t.Any, **kwargs: t.Any) -> "_RetryDecorated[P, R]": return self.copy(*args, **kwargs).wraps(f) # Preserve attributes - wrapped_f.retry = self # type: ignore[attr-defined] - wrapped_f.retry_with = retry_with # type: ignore[attr-defined] - wrapped_f.statistics = {} # type: ignore[attr-defined] + result_f.retry = self # type: ignore[attr-defined] + result_f.retry_with = retry_with # type: ignore[attr-defined] + result_f.statistics = {} # type: ignore[attr-defined] - return t.cast("_RetryDecorated[P, R]", wrapped_f) + return t.cast("_RetryDecorated[P, R]", result_f) def begin(self) -> None: self.statistics.clear() @@ -714,8 +746,10 @@ def wrap(f: t.Callable[P, R]) -> _RetryDecorated[P, R]: ) r: BaseRetrying sleep = dkw.get("sleep") - if _utils.is_coroutine_callable(f) or ( - sleep is not None and _utils.is_coroutine_callable(sleep) + if ( + _utils.is_coroutine_callable(f) + or _utils.is_async_gen_callable(f) + or (sleep is not None and _utils.is_coroutine_callable(sleep)) ): r = AsyncRetrying(*dargs, **dkw) elif ( diff --git a/tenacity/_utils.py b/tenacity/_utils.py index 6a12678b..bb3a059e 100644 --- a/tenacity/_utils.py +++ b/tenacity/_utils.py @@ -99,6 +99,26 @@ def is_coroutine_callable(call: typing.Callable[..., typing.Any]) -> bool: return inspect.iscoroutinefunction(dunder_call) +def is_generator_callable(call: typing.Callable[..., typing.Any]) -> bool: + if inspect.isclass(call): + return False + if inspect.isgeneratorfunction(call): + return True + partial_call = isinstance(call, functools.partial) and call.func + dunder_call = partial_call or getattr(call, "__call__", None) # noqa: B004 + return inspect.isgeneratorfunction(dunder_call) + + +def is_async_gen_callable(call: typing.Callable[..., typing.Any]) -> bool: + if inspect.isclass(call): + return False + if inspect.isasyncgenfunction(call): + return True + partial_call = isinstance(call, functools.partial) and call.func + dunder_call = partial_call or getattr(call, "__call__", None) # noqa: B004 + return inspect.isasyncgenfunction(dunder_call) + + def wrap_to_async_func( call: typing.Callable[..., typing.Any], ) -> typing.Callable[..., typing.Awaitable[typing.Any]]: diff --git a/tenacity/asyncio/__init__.py b/tenacity/asyncio/__init__.py index 0ff48855..389da591 100644 --- a/tenacity/asyncio/__init__.py +++ b/tenacity/asyncio/__init__.py @@ -199,12 +199,51 @@ async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any: async_wrapped.statistics = copy.statistics # type: ignore[attr-defined] return await copy(fn, *args, **kwargs) # type: ignore[type-var] + @functools.wraps( + fn, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__") + ) + async def async_wrapped_gen( + *args: t.Any, **kwargs: t.Any + ) -> t.AsyncGenerator[t.Any, t.Any]: + if not self.enabled: + async for item in fn(*args, **kwargs): # type: ignore[misc] + yield item + return + copy = self.copy() + async_wrapped_gen.statistics = copy.statistics # type: ignore[attr-defined] + copy.begin() + retry_state = RetryCallState( + retry_object=copy, fn=fn, args=args, kwargs=kwargs + ) + while True: + do = await copy.iter(retry_state=retry_state) + if isinstance(do, DoAttempt): + try: + async for item in fn(*args, **kwargs): # type: ignore[misc] + yield item + except GeneratorExit: + raise + except BaseException: + retry_state.set_exception(sys.exc_info()) # type: ignore[arg-type] + else: + retry_state.set_result(None) + elif isinstance(do, DoSleep): + retry_state.prepare_for_next_attempt() + await self.sleep(do) # type: ignore[misc] + else: + return + + if _utils.is_async_gen_callable(fn): + result_f = async_wrapped_gen + else: + result_f = async_wrapped + # Preserve attributes - async_wrapped.retry = self # type: ignore[attr-defined] - async_wrapped.retry_with = wrapped.retry_with # type: ignore[attr-defined] - async_wrapped.statistics = {} # type: ignore[attr-defined] + result_f.retry = self # type: ignore[attr-defined] + result_f.retry_with = wrapped.retry_with # type: ignore[attr-defined] + result_f.statistics = {} # type: ignore[attr-defined] - return t.cast("_RetryDecorated[P, R]", async_wrapped) + return t.cast("_RetryDecorated[P, R]", result_f) __all__ = [ diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 749c8834..ce693ed7 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -531,5 +531,93 @@ def sync_function() -> Any: assert mock_sleep.await_count == 2 +class TestAsyncGeneratorRetry(unittest.TestCase): + @asynctest + async def test_async_generator_retry_on_exception(self) -> None: + attempts = 0 + + @retry( + stop=stop_after_attempt(3), + retry=tenacity.retry_if_exception_type(ValueError), + reraise=True, + ) + async def gen_with_errors() -> Any: + nonlocal attempts + attempts += 1 + if attempts < 3: + raise ValueError("not yet") + yield 1 + yield 2 + yield 3 + + result = [item async for item in gen_with_errors()] + assert result == [1, 2, 3] + assert attempts == 3 + + @asynctest + async def test_async_generator_yields_all_values(self) -> None: + @retry + async def simple_gen() -> Any: + yield 10 + yield 20 + yield 30 + + result = [item async for item in simple_gen()] + assert result == [10, 20, 30] + + @asynctest + async def test_async_generator_stop_after_attempt(self) -> None: + @retry( + stop=stop_after_attempt(2), + retry=tenacity.retry_if_exception_type(RuntimeError), + ) + async def always_fails() -> Any: + raise RuntimeError("always") + yield # make it an async generator + + with pytest.raises(RetryError): + async for _ in always_fails(): + pass + + def test_async_generator_has_retry_attributes(self) -> None: + @retry + async def my_gen() -> Any: + yield 1 + + assert hasattr(my_gen, "retry") + assert hasattr(my_gen, "statistics") + assert hasattr(my_gen, "retry_with") + + @asynctest + async def test_async_generator_statistics_updated(self) -> None: + attempts = 0 + + @retry( + stop=stop_after_attempt(3), + retry=tenacity.retry_if_exception_type(ValueError), + reraise=True, + ) + async def gen_stats() -> Any: + nonlocal attempts + attempts += 1 + if attempts < 2: + raise ValueError("retry") + yield 42 + + result = [item async for item in gen_stats()] + assert result == [42] + assert gen_stats.statistics["attempt_number"] == 2 # type: ignore[attr-defined] + + @asynctest + async def test_async_generator_enabled_false(self) -> None: + @retry(enabled=False) + async def my_gen() -> Any: + yield 1 + yield 2 + + result = [item async for item in my_gen()] + assert result == [1, 2] + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py index 80060a36..d6938923 100644 --- a/tests/test_tenacity.py +++ b/tests/test_tenacity.py @@ -2009,5 +2009,108 @@ def test_decorated_retry_with(self, mock_sleep: typing.Any) -> None: assert mock_sleep.call_count == 1 +class TestGeneratorRetry: + def test_generator_retry_on_exception(self) -> None: + attempts = 0 + + @retry( + stop=tenacity.stop_after_attempt(3), + retry=tenacity.retry_if_exception_type(ValueError), + reraise=True, + ) + def gen_with_errors() -> typing.Generator[int, None, None]: + nonlocal attempts + attempts += 1 + if attempts < 3: + raise ValueError("not yet") + yield 1 + yield 2 + yield 3 + + result = list(gen_with_errors()) + assert result == [1, 2, 3] + assert attempts == 3 + + def test_generator_yields_all_values(self) -> None: + @retry + def simple_gen() -> typing.Generator[int, None, None]: + yield 10 + yield 20 + yield 30 + + result = list(simple_gen()) + assert result == [10, 20, 30] + + def test_generator_stop_after_attempt(self) -> None: + @retry( + stop=tenacity.stop_after_attempt(2), + retry=tenacity.retry_if_exception_type(RuntimeError), + ) + def always_fails() -> typing.Generator[int, None, None]: + raise RuntimeError("always") + yield # make it a generator + + with pytest.raises(RetryError): + list(always_fails()) + + def test_generator_has_retry_attributes(self) -> None: + @retry + def my_gen() -> typing.Generator[int, None, None]: + yield 1 + + assert hasattr(my_gen, "retry") + assert hasattr(my_gen, "statistics") + assert hasattr(my_gen, "retry_with") + + def test_generator_statistics_updated(self) -> None: + attempts = 0 + + @retry( + stop=tenacity.stop_after_attempt(3), + retry=tenacity.retry_if_exception_type(ValueError), + reraise=True, + ) + def gen_stats() -> typing.Generator[int, None, None]: + nonlocal attempts + attempts += 1 + if attempts < 2: + raise ValueError("retry") + yield 42 + + result = list(gen_stats()) + assert result == [42] + assert gen_stats.statistics["attempt_number"] == 2 # type: ignore[attr-defined] + + def test_generator_enabled_false(self) -> None: + @retry(enabled=False) + def my_gen() -> typing.Generator[int, None, None]: + yield 1 + yield 2 + + result = list(my_gen()) + assert result == [1, 2] + + def test_generator_retry_with(self) -> None: + attempts = 0 + + @retry( + stop=tenacity.stop_after_attempt(5), + retry=tenacity.retry_if_exception_type(ValueError), + ) + def gen_retry_with() -> typing.Generator[int, None, None]: + nonlocal attempts + attempts += 1 + if attempts < 2: + raise ValueError("retry") + yield 1 + + faster = gen_retry_with.retry_with( # type: ignore[attr-defined] + stop=tenacity.stop_after_attempt(1), + ) + attempts = 0 + with pytest.raises(RetryError): + list(faster()) + + if __name__ == "__main__": unittest.main()