diff --git a/tenacity/__init__.py b/tenacity/__init__.py index 079740f6..57af68a4 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -214,6 +214,17 @@ def __exit__( self.retry_state.set_result(None) return None + async def __aenter__(self) -> None: + pass + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: t.Optional["types.TracebackType"], + ) -> bool | None: + return self.__exit__(exc_type, exc_value, traceback) + class BaseRetrying(ABC): def __init__( diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 599c4f26..0cdb8ba6 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -190,6 +190,21 @@ async def test_do_max_attempts(self) -> None: assert attempts == 3 + @asynctest + async def test_async_with_attempt_manager(self) -> None: + """AttemptManager supports async with for use inside async for.""" + attempts = 0 + retrying = tasyncio.AsyncRetrying(stop=stop_after_attempt(3)) + try: + async for attempt in retrying: + async with attempt: + attempts += 1 + raise Exception + except RetryError: + pass + + assert attempts == 3 + @asynctest async def test_reraise(self) -> None: class CustomError(Exception):