diff --git a/tenacity/_asyncio.py b/tenacity/_asyncio.py index 16aec620..27c26642 100644 --- a/tenacity/_asyncio.py +++ b/tenacity/_asyncio.py @@ -25,6 +25,7 @@ from tenacity import DoAttempt from tenacity import DoSleep from tenacity import RetryCallState +from tenacity import _utils WrappedFnReturnT = t.TypeVar("WrappedFnReturnT") WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]]) @@ -46,7 +47,7 @@ async def __call__( # type: ignore[override] retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs) while True: - do = self.iter(retry_state=retry_state) + do = await self.iter(retry_state=retry_state) if isinstance(do, DoAttempt): try: result = await fn(*args, **kwargs) @@ -60,6 +61,47 @@ async def __call__( # type: ignore[override] else: return do # type: ignore[no-any-return] + @classmethod + def _wrap_action_func(cls, fn: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: + if _utils.is_coroutine_callable(fn): + return fn + + async def inner(*args: t.Any, **kwargs: t.Any) -> t.Any: + return fn(*args, **kwargs) + + return inner + + def _add_action_func(self, fn: t.Callable[..., t.Any]) -> None: + self.iter_state.actions.append(self._wrap_action_func(fn)) + + async def _run_retry(self, retry_state: "RetryCallState") -> None: # type: ignore[override] + self.iter_state.retry_run_result = await self._wrap_action_func(self.retry)( + retry_state + ) + + async def _run_wait(self, retry_state: "RetryCallState") -> None: # type: ignore[override] + if self.wait: + sleep = await self._wrap_action_func(self.wait)(retry_state) + else: + sleep = 0.0 + + retry_state.upcoming_sleep = sleep + + async def _run_stop(self, retry_state: "RetryCallState") -> None: # type: ignore[override] + self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start + self.iter_state.stop_run_result = await self._wrap_action_func(self.stop)( + retry_state + ) + + async def iter( + self, retry_state: "RetryCallState" + ) -> t.Union[DoAttempt, DoSleep, t.Any]: # noqa: A003 + self._begin_iter(retry_state) + result = None + for action in self.iter_state.actions: + result = await action(retry_state) + return result + def __iter__(self) -> t.Generator[AttemptManager, None, None]: raise TypeError("AsyncRetrying object is not iterable") @@ -70,7 +112,7 @@ def __aiter__(self) -> "AsyncRetrying": async def __anext__(self) -> AttemptManager: while True: - do = self.iter(retry_state=self._retry_state) + do = await self.iter(retry_state=self._retry_state) if do is None: raise StopAsyncIteration elif isinstance(do, DoAttempt): diff --git a/tenacity/_utils.py b/tenacity/_utils.py index 67ee0dea..4e34115e 100644 --- a/tenacity/_utils.py +++ b/tenacity/_utils.py @@ -13,7 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import functools +import inspect import sys import typing from datetime import timedelta @@ -76,3 +77,13 @@ def to_seconds(time_unit: time_unit_type) -> float: return float( time_unit.total_seconds() if isinstance(time_unit, timedelta) else time_unit ) + + +def is_coroutine_callable(call: typing.Callable[..., typing.Any]) -> bool: + if inspect.isclass(call): + return False + if inspect.iscoroutinefunction(call): + return True + partial_call = isinstance(call, functools.partial) and call.func + dunder_call = partial_call or getattr(call, "__call__", None) + return inspect.iscoroutinefunction(dunder_call) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..ec7b3ee5 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,41 @@ +import functools + +from tenacity import _utils + + +def test_is_coroutine_callable() -> None: + async def async_func() -> None: + pass + + def sync_func() -> None: + pass + + class AsyncClass: + async def __call__(self) -> None: + pass + + class SyncClass: + def __call__(self) -> None: + pass + + lambda_fn = lambda: None # noqa: E731 + + partial_async_func = functools.partial(async_func) + partial_sync_func = functools.partial(sync_func) + partial_async_class = functools.partial(AsyncClass().__call__) + partial_sync_class = functools.partial(SyncClass().__call__) + partial_lambda_fn = functools.partial(lambda_fn) + + assert _utils.is_coroutine_callable(async_func) is True + assert _utils.is_coroutine_callable(sync_func) is False + assert _utils.is_coroutine_callable(AsyncClass) is False + assert _utils.is_coroutine_callable(AsyncClass()) is True + assert _utils.is_coroutine_callable(SyncClass) is False + assert _utils.is_coroutine_callable(SyncClass()) is False + assert _utils.is_coroutine_callable(lambda_fn) is False + + assert _utils.is_coroutine_callable(partial_async_func) is True + assert _utils.is_coroutine_callable(partial_sync_func) is False + assert _utils.is_coroutine_callable(partial_async_class) is True + assert _utils.is_coroutine_callable(partial_sync_class) is False + assert _utils.is_coroutine_callable(partial_lambda_fn) is False