From b0e3b6af9471a3e7ac344fd7d44a2040d6275549 Mon Sep 17 00:00:00 2001 From: Hasier Date: Mon, 29 Jan 2024 14:49:21 +0000 Subject: [PATCH 1/3] Support async actions --- tenacity/_asyncio.py | 49 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/tenacity/_asyncio.py b/tenacity/_asyncio.py index 16aec620..b9713321 100644 --- a/tenacity/_asyncio.py +++ b/tenacity/_asyncio.py @@ -16,6 +16,7 @@ # limitations under the License. import functools +import inspect import sys import typing as t from asyncio import sleep @@ -30,6 +31,15 @@ WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]]) +def is_coroutine_callable(call: t.Callable[..., t.Any]) -> bool: + if inspect.isroutine(call): + return inspect.iscoroutinefunction(call) + if inspect.isclass(call): + return False + dunder_call = getattr(call, "__call__", None) # noqa: B004 + return inspect.iscoroutinefunction(dunder_call) + + class AsyncRetrying(BaseRetrying): sleep: t.Callable[[float], t.Awaitable[t.Any]] @@ -46,7 +56,7 @@ async def __call__( # type: ignore[override] retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs) while True: - do = self.iter(retry_state=retry_state) + do = await self.iter(retry_state=retry_state) if isinstance(do, DoAttempt): try: result = await fn(*args, **kwargs) @@ -60,6 +70,41 @@ async def __call__( # type: ignore[override] else: return do # type: ignore[no-any-return] + @classmethod + def _wrap_action_func(cls, fn: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: + if is_coroutine_callable(fn): + return fn + + async def inner(*args: t.Any, **kwargs: t.Any) -> t.Any: + return fn(*args, **kwargs) + + return inner + + def _add_action_func(self, fn: t.Callable[..., t.Any]) -> None: + self.iter_state["actions"].append(self._wrap_action_func(fn)) + + async def _run_retry(self, retry_state: "RetryCallState") -> None: # type: ignore[override] + self.iter_state["retry_run_result"] = await self._wrap_action_func(self.retry)(retry_state) + + async def _run_wait(self, retry_state: "RetryCallState") -> None: # type: ignore[override] + if self.wait: + sleep = await self._wrap_action_func(self.wait)(retry_state) + else: + sleep = 0.0 + + retry_state.upcoming_sleep = sleep + + async def _run_stop(self, retry_state: "RetryCallState") -> None: # type: ignore[override] + self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start + self.iter_state["stop_run_result"] = await self._wrap_action_func(self.stop)(retry_state) + + async def iter(self, retry_state: "RetryCallState") -> t.Union[DoAttempt, DoSleep, t.Any]: # noqa: A003 + self._begin_iter(retry_state) + result = None + for action in self.iter_state["actions"]: + result = await action(retry_state) + return result + def __iter__(self) -> t.Generator[AttemptManager, None, None]: raise TypeError("AsyncRetrying object is not iterable") @@ -70,7 +115,7 @@ def __aiter__(self) -> "AsyncRetrying": async def __anext__(self) -> AttemptManager: while True: - do = self.iter(retry_state=self._retry_state) + do = await self.iter(retry_state=self._retry_state) if do is None: raise StopAsyncIteration elif isinstance(do, DoAttempt): From f8ed7fbfe286ea9029ebef78ed343b2bcdaf65d2 Mon Sep 17 00:00:00 2001 From: Hasier Date: Tue, 6 Feb 2024 11:31:02 +0000 Subject: [PATCH 2/3] Fixes after main rebase --- tenacity/_asyncio.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tenacity/_asyncio.py b/tenacity/_asyncio.py index b9713321..b314b146 100644 --- a/tenacity/_asyncio.py +++ b/tenacity/_asyncio.py @@ -81,10 +81,12 @@ async def inner(*args: t.Any, **kwargs: t.Any) -> t.Any: return inner def _add_action_func(self, fn: t.Callable[..., t.Any]) -> None: - self.iter_state["actions"].append(self._wrap_action_func(fn)) + self.iter_state.actions.append(self._wrap_action_func(fn)) async def _run_retry(self, retry_state: "RetryCallState") -> None: # type: ignore[override] - self.iter_state["retry_run_result"] = await self._wrap_action_func(self.retry)(retry_state) + self.iter_state.retry_run_result = await self._wrap_action_func(self.retry)( + retry_state + ) async def _run_wait(self, retry_state: "RetryCallState") -> None: # type: ignore[override] if self.wait: @@ -96,12 +98,16 @@ async def _run_wait(self, retry_state: "RetryCallState") -> None: # type: ignor async def _run_stop(self, retry_state: "RetryCallState") -> None: # type: ignore[override] self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start - self.iter_state["stop_run_result"] = await self._wrap_action_func(self.stop)(retry_state) + self.iter_state.stop_run_result = await self._wrap_action_func(self.stop)( + retry_state + ) - async def iter(self, retry_state: "RetryCallState") -> t.Union[DoAttempt, DoSleep, t.Any]: # noqa: A003 + async def iter( + self, retry_state: "RetryCallState" + ) -> t.Union[DoAttempt, DoSleep, t.Any]: # noqa: A003 self._begin_iter(retry_state) result = None - for action in self.iter_state["actions"]: + for action in self.iter_state.actions: result = await action(retry_state) return result From 2616c4ad586f91b9cbad4e777f177fcc199911d2 Mon Sep 17 00:00:00 2001 From: Hasier Date: Tue, 6 Feb 2024 12:45:39 +0000 Subject: [PATCH 3/3] Test is_coroutine_callable --- tenacity/_asyncio.py | 13 ++----------- tenacity/_utils.py | 13 ++++++++++++- tests/test_utils.py | 41 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 12 deletions(-) create mode 100644 tests/test_utils.py diff --git a/tenacity/_asyncio.py b/tenacity/_asyncio.py index b314b146..27c26642 100644 --- a/tenacity/_asyncio.py +++ b/tenacity/_asyncio.py @@ -16,7 +16,6 @@ # limitations under the License. import functools -import inspect import sys import typing as t from asyncio import sleep @@ -26,20 +25,12 @@ from tenacity import DoAttempt from tenacity import DoSleep from tenacity import RetryCallState +from tenacity import _utils WrappedFnReturnT = t.TypeVar("WrappedFnReturnT") WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]]) -def is_coroutine_callable(call: t.Callable[..., t.Any]) -> bool: - if inspect.isroutine(call): - return inspect.iscoroutinefunction(call) - if inspect.isclass(call): - return False - dunder_call = getattr(call, "__call__", None) # noqa: B004 - return inspect.iscoroutinefunction(dunder_call) - - class AsyncRetrying(BaseRetrying): sleep: t.Callable[[float], t.Awaitable[t.Any]] @@ -72,7 +63,7 @@ async def __call__( # type: ignore[override] @classmethod def _wrap_action_func(cls, fn: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: - if is_coroutine_callable(fn): + if _utils.is_coroutine_callable(fn): return fn async def inner(*args: t.Any, **kwargs: t.Any) -> t.Any: diff --git a/tenacity/_utils.py b/tenacity/_utils.py index 67ee0dea..4e34115e 100644 --- a/tenacity/_utils.py +++ b/tenacity/_utils.py @@ -13,7 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import functools +import inspect import sys import typing from datetime import timedelta @@ -76,3 +77,13 @@ def to_seconds(time_unit: time_unit_type) -> float: return float( time_unit.total_seconds() if isinstance(time_unit, timedelta) else time_unit ) + + +def is_coroutine_callable(call: typing.Callable[..., typing.Any]) -> bool: + if inspect.isclass(call): + return False + if inspect.iscoroutinefunction(call): + return True + partial_call = isinstance(call, functools.partial) and call.func + dunder_call = partial_call or getattr(call, "__call__", None) + return inspect.iscoroutinefunction(dunder_call) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..ec7b3ee5 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,41 @@ +import functools + +from tenacity import _utils + + +def test_is_coroutine_callable() -> None: + async def async_func() -> None: + pass + + def sync_func() -> None: + pass + + class AsyncClass: + async def __call__(self) -> None: + pass + + class SyncClass: + def __call__(self) -> None: + pass + + lambda_fn = lambda: None # noqa: E731 + + partial_async_func = functools.partial(async_func) + partial_sync_func = functools.partial(sync_func) + partial_async_class = functools.partial(AsyncClass().__call__) + partial_sync_class = functools.partial(SyncClass().__call__) + partial_lambda_fn = functools.partial(lambda_fn) + + assert _utils.is_coroutine_callable(async_func) is True + assert _utils.is_coroutine_callable(sync_func) is False + assert _utils.is_coroutine_callable(AsyncClass) is False + assert _utils.is_coroutine_callable(AsyncClass()) is True + assert _utils.is_coroutine_callable(SyncClass) is False + assert _utils.is_coroutine_callable(SyncClass()) is False + assert _utils.is_coroutine_callable(lambda_fn) is False + + assert _utils.is_coroutine_callable(partial_async_func) is True + assert _utils.is_coroutine_callable(partial_sync_func) is False + assert _utils.is_coroutine_callable(partial_async_class) is True + assert _utils.is_coroutine_callable(partial_sync_class) is False + assert _utils.is_coroutine_callable(partial_lambda_fn) is False