From 27790ec0f423274901064f38abc63550e964b038 Mon Sep 17 00:00:00 2001 From: Hasier Date: Mon, 29 Jan 2024 14:49:21 +0000 Subject: [PATCH] Support async actions --- tenacity/_asyncio.py | 49 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/tenacity/_asyncio.py b/tenacity/_asyncio.py index d901cbd1..9d418a8c 100644 --- a/tenacity/_asyncio.py +++ b/tenacity/_asyncio.py @@ -16,6 +16,7 @@ # limitations under the License. import functools +import inspect import sys import typing as t from asyncio import sleep @@ -30,6 +31,15 @@ WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]]) +def is_coroutine_callable(call: t.Callable[..., t.Any]) -> bool: + if inspect.isroutine(call): + return inspect.iscoroutinefunction(call) + if inspect.isclass(call): + return False + dunder_call = getattr(call, "__call__", None) # noqa: B004 + return inspect.iscoroutinefunction(dunder_call) + + class AsyncRetrying(BaseRetrying): sleep: t.Callable[[float], t.Awaitable[t.Any]] @@ -44,7 +54,7 @@ async def __call__( # type: ignore[override] retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs) while True: - do = self.iter(retry_state=retry_state) + do = await self.iter(retry_state=retry_state) if isinstance(do, DoAttempt): try: result = await fn(*args, **kwargs) @@ -58,6 +68,41 @@ async def __call__( # type: ignore[override] else: return do # type: ignore[no-any-return] + @classmethod + def _wrap_action_func(cls, fn: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: + if is_coroutine_callable(fn): + return fn + + async def inner(*args: t.Any, **kwargs: t.Any) -> t.Any: + return fn(*args, **kwargs) + + return inner + + def _add_action_func(self, fn: t.Callable[..., t.Any]) -> None: + self.iter_state["actions"].append(self._wrap_action_func(fn)) + + async def _run_retry(self, retry_state: "RetryCallState") -> None: # type: ignore[override] + self.iter_state["retry_run_result"] = await self._wrap_action_func(self.retry)(retry_state) + + async def _run_wait(self, retry_state: "RetryCallState") -> None: # type: ignore[override] + if self.wait: + sleep = await self._wrap_action_func(self.wait)(retry_state) + else: + sleep = 0.0 + + retry_state.upcoming_sleep = sleep + + async def _run_stop(self, retry_state: "RetryCallState") -> None: # type: ignore[override] + self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start + self.iter_state["stop_run_result"] = await self._wrap_action_func(self.stop)(retry_state) + + async def iter(self, retry_state: "RetryCallState") -> t.Union[DoAttempt, DoSleep, t.Any]: # noqa: A003 + self._begin_iter(retry_state) + result = None + for action in self.iter_state["actions"]: + result = await action(retry_state) + return result + def __iter__(self) -> t.Generator[AttemptManager, None, None]: raise TypeError("AsyncRetrying object is not iterable") @@ -68,7 +113,7 @@ def __aiter__(self) -> "AsyncRetrying": async def __anext__(self) -> AttemptManager: while True: - do = self.iter(retry_state=self._retry_state) + do = await self.iter(retry_state=self._retry_state) if do is None: raise StopAsyncIteration elif isinstance(do, DoAttempt):