From c25b94659bb6f4b067f33963e8fb5b021170b7c6 Mon Sep 17 00:00:00 2001 From: Julien Danjou Date: Mon, 23 Feb 2026 16:47:03 +0100 Subject: [PATCH] fix: expose .retry, .statistics, .retry_with to type checkers Add _RetryDecorated Protocol that combines the original callable signature (via ParamSpec) with the retry control attributes. Update all retry() overloads, _AsyncRetryDecorator, and wraps() methods to return _RetryDecorated[P, R] instead of bare WrappedFn. This lets mypy see .retry, .statistics, and .retry_with on decorated functions without # type: ignore, fixing a long-standing typing gap. Removes ~30 now-unnecessary type: ignore comments from tests. Closes #346 Co-Authored-By: Claude Opus 4.6 Change-Id: I8c2bc7b49a0cb51175a8e0bf33d9742f10bc49a0 --- tenacity/__init__.py | 36 +++++++++++++++++++-------- tenacity/asyncio/__init__.py | 9 ++++--- tests/test_asyncio.py | 21 ++++++++-------- tests/test_tenacity.py | 47 ++++++++++++++++++------------------ 4 files changed, 67 insertions(+), 46 deletions(-) diff --git a/tenacity/__init__.py b/tenacity/__init__.py index f2ec648e..c43475c4 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -323,7 +323,7 @@ def iter_state(self) -> IterState: self._local.iter_state = IterState() return self._local.iter_state # type: ignore[no-any-return] - def wraps(self, f: WrappedFn) -> WrappedFn: + def wraps(self, f: t.Callable[P, R]) -> "_RetryDecorated[P, R]": """Wrap a function for retrying. :param f: A function to wrap for retrying. @@ -339,7 +339,7 @@ def wrapped_f(*args: t.Any, **kw: t.Any) -> t.Any: wrapped_f.statistics = copy.statistics # type: ignore[attr-defined] return copy(f, *args, **kw) - def retry_with(*args: t.Any, **kwargs: t.Any) -> WrappedFn: + def retry_with(*args: t.Any, **kwargs: t.Any) -> "_RetryDecorated[P, R]": return self.copy(*args, **kwargs).wraps(f) # Preserve attributes @@ -347,7 +347,7 @@ def retry_with(*args: t.Any, **kwargs: t.Any) -> WrappedFn: wrapped_f.retry_with = retry_with # type: ignore[attr-defined] wrapped_f.statistics = {} # type: ignore[attr-defined] - return wrapped_f # type: ignore[return-value] + return t.cast("_RetryDecorated[P, R]", wrapped_f) def begin(self) -> None: self.statistics.clear() @@ -604,25 +604,41 @@ def __repr__(self) -> str: return f"<{clsname} {id(self)}: attempt #{self.attempt_number}; slept for {slept}; last result: {result}>" +class _RetryDecorated(t.Protocol[P, R]): + """Protocol for functions decorated with @retry. + + Provides the original callable signature plus retry control attributes. + """ + + retry: "BaseRetrying" + statistics: dict[str, t.Any] + + def retry_with(self, *args: t.Any, **kwargs: t.Any) -> "_RetryDecorated[P, R]": ... + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ... + + class _AsyncRetryDecorator(t.Protocol): @t.overload def __call__( self, fn: "t.Callable[P, types.CoroutineType[t.Any, t.Any, R]]" - ) -> "t.Callable[P, types.CoroutineType[t.Any, t.Any, R]]": ... + ) -> "_RetryDecorated[P, types.CoroutineType[t.Any, t.Any, R]]": ... @t.overload def __call__( self, fn: t.Callable[P, t.Coroutine[t.Any, t.Any, R]] - ) -> t.Callable[P, t.Coroutine[t.Any, t.Any, R]]: ... + ) -> "_RetryDecorated[P, t.Coroutine[t.Any, t.Any, R]]": ... @t.overload def __call__( self, fn: t.Callable[P, t.Awaitable[R]] - ) -> t.Callable[P, t.Awaitable[R]]: ... + ) -> "_RetryDecorated[P, t.Awaitable[R]]": ... @t.overload - def __call__(self, fn: t.Callable[P, R]) -> t.Callable[P, t.Awaitable[R]]: ... + def __call__( + self, fn: t.Callable[P, R] + ) -> "_RetryDecorated[P, t.Awaitable[R]]": ... @t.overload -def retry(func: WrappedFn) -> WrappedFn: ... +def retry(func: t.Callable[P, R]) -> _RetryDecorated[P, R]: ... @t.overload @@ -656,7 +672,7 @@ def retry( retry_error_cls: type["RetryError"] = RetryError, retry_error_callback: t.Callable[["RetryCallState"], t.Any | t.Awaitable[t.Any]] | None = None, -) -> t.Callable[[WrappedFn], WrappedFn]: ... +) -> t.Callable[[t.Callable[P, R]], _RetryDecorated[P, R]]: ... def retry(*dargs: t.Any, **dkw: t.Any) -> t.Any: @@ -669,7 +685,7 @@ def retry(*dargs: t.Any, **dkw: t.Any) -> t.Any: if len(dargs) == 1 and callable(dargs[0]): return retry()(dargs[0]) - def wrap(f: WrappedFn) -> WrappedFn: + def wrap(f: t.Callable[P, R]) -> _RetryDecorated[P, R]: if isinstance(f, retry_base): warnings.warn( f"Got retry_base instance ({f.__class__.__name__}) as callable argument, " diff --git a/tenacity/asyncio/__init__.py b/tenacity/asyncio/__init__.py index 6d882c46..85b216ab 100644 --- a/tenacity/asyncio/__init__.py +++ b/tenacity/asyncio/__init__.py @@ -27,6 +27,7 @@ DoSleep, RetryCallState, RetryError, + _RetryDecorated, _utils, after_nothing, before_nothing, @@ -48,6 +49,8 @@ WrappedFnReturnT = t.TypeVar("WrappedFnReturnT") WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]]) +P = t.ParamSpec("P") +R = t.TypeVar("R") def _portable_async_sleep(seconds: float) -> t.Awaitable[None]: @@ -178,7 +181,7 @@ async def __anext__(self) -> AttemptManager: else: raise StopAsyncIteration - def wraps(self, fn: WrappedFn) -> WrappedFn: + def wraps(self, fn: t.Callable[P, R]) -> _RetryDecorated[P, R]: wrapped = super().wraps(fn) # Ensure wrapper is recognized as a coroutine function. @@ -190,14 +193,14 @@ async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any: # calling the same wrapped functions multiple times in the same stack copy = self.copy() async_wrapped.statistics = copy.statistics # type: ignore[attr-defined] - return await copy(fn, *args, **kwargs) + return await copy(fn, *args, **kwargs) # type: ignore[type-var] # Preserve attributes async_wrapped.retry = self # type: ignore[attr-defined] async_wrapped.retry_with = wrapped.retry_with # type: ignore[attr-defined] async_wrapped.statistics = {} # type: ignore[attr-defined] - return async_wrapped # type: ignore[return-value] + return t.cast("_RetryDecorated[P, R]", async_wrapped) __all__ = [ diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 4447f087..599c4f26 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -124,11 +124,12 @@ async def function_with_kwdefaults(*, a: int = 1) -> int: wrapped_kwdefaults_function = retrying.wraps(function_with_kwdefaults) self.assertEqual( - function_with_defaults.__defaults__, wrapped_defaults_function.__defaults__ + function_with_defaults.__defaults__, + wrapped_defaults_function.__defaults__, # type: ignore[attr-defined] ) self.assertEqual( function_with_kwdefaults.__kwdefaults__, - wrapped_kwdefaults_function.__kwdefaults__, + wrapped_kwdefaults_function.__kwdefaults__, # type: ignore[attr-defined] ) @asynctest @@ -142,8 +143,8 @@ def after(retry_state: RetryCallState) -> None: thing2 = NoIOErrorAfterCount(3) await asyncio.gather( - _retryable_coroutine.retry_with(after=after)(thing1), # type: ignore[attr-defined] - _retryable_coroutine.retry_with(after=after)(thing2), # type: ignore[attr-defined] + _retryable_coroutine.retry_with(after=after)(thing1), + _retryable_coroutine.retry_with(after=after)(thing2), ) # There's no waiting on retry, only a wait in the coroutine, so the @@ -429,16 +430,16 @@ async def test_retry_function_attributes(self) -> None: "start_time": mock.ANY, } self.assertEqual( - _retryable_coroutine_with_2_attempts.statistics, # type: ignore[attr-defined] + _retryable_coroutine_with_2_attempts.statistics, expected_stats, ) self.assertEqual( - _retryable_coroutine_with_2_attempts.retry.statistics, # type: ignore[attr-defined] + _retryable_coroutine_with_2_attempts.retry.statistics, {}, ) with mock.patch.object( - _retryable_coroutine_with_2_attempts.retry, # type: ignore[attr-defined] + _retryable_coroutine_with_2_attempts.retry, "stop", tenacity.stop_after_attempt(1), ): @@ -454,12 +455,12 @@ async def test_retry_function_attributes(self) -> None: "start_time": mock.ANY, } self.assertEqual( - _retryable_coroutine_with_2_attempts.statistics, # type: ignore[attr-defined] + _retryable_coroutine_with_2_attempts.statistics, expected_stats, ) self.assertEqual(exc.last_attempt.attempt_number, 1) self.assertEqual( - _retryable_coroutine_with_2_attempts.retry.statistics, # type: ignore[attr-defined] + _retryable_coroutine_with_2_attempts.retry.statistics, {}, ) else: @@ -493,7 +494,7 @@ async def test_sync_function_with_async_sleep(self) -> None: def sync_function() -> Any: return thing.go() - result = await sync_function() # type: ignore[no-untyped-call] + result = await sync_function() assert result is True assert mock_sleep.await_count == 2 diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py index b7817f7c..017e6831 100644 --- a/tests/test_tenacity.py +++ b/tests/test_tenacity.py @@ -1177,7 +1177,7 @@ def test_retry_until_exception_of_type_attempt_number(self) -> None: _retryable_test_with_unless_exception_type_name(NameErrorUntilCount(5)) ) except NameError as e: - s = _retryable_test_with_unless_exception_type_name.statistics # type: ignore[attr-defined] + s = _retryable_test_with_unless_exception_type_name.statistics self.assertTrue(s["attempt_number"] == 6) print(e) else: @@ -1192,7 +1192,7 @@ def test_retry_until_exception_of_type_no_type(self) -> None: ) ) except NameError as e: - s = _retryable_test_with_unless_exception_type_no_input.statistics # type: ignore[attr-defined] + s = _retryable_test_with_unless_exception_type_no_input.statistics self.assertTrue(s["attempt_number"] == 6) print(e) else: @@ -1215,7 +1215,7 @@ def test_retry_if_exception_message(self) -> None: _retryable_test_if_exception_message_message(NoCustomErrorAfterCount(3)) ) except CustomError: - print(_retryable_test_if_exception_message_message.statistics) # type: ignore[attr-defined] + print(_retryable_test_if_exception_message_message.statistics) self.fail("CustomError should've been retried from errormessage") def test_retry_if_not_exception_message(self) -> None: @@ -1226,7 +1226,7 @@ def test_retry_if_not_exception_message(self) -> None: ) ) except CustomError: - s = _retryable_test_if_not_exception_message_message.statistics # type: ignore[attr-defined] + s = _retryable_test_if_not_exception_message_message.statistics self.assertTrue(s["attempt_number"] == 1) def test_retry_if_not_exception_message_delay(self) -> None: @@ -1235,7 +1235,7 @@ def test_retry_if_not_exception_message_delay(self) -> None: _retryable_test_not_exception_message_delay(NameErrorUntilCount(3)) ) except NameError: - s = _retryable_test_not_exception_message_delay.statistics # type: ignore[attr-defined] + s = _retryable_test_not_exception_message_delay.statistics print(s["attempt_number"]) self.assertTrue(s["attempt_number"] == 4) @@ -1255,7 +1255,7 @@ def test_retry_if_not_exception_message_match(self) -> None: ) ) except CustomError: - s = _retryable_test_if_not_exception_message_message.statistics # type: ignore[attr-defined] + s = _retryable_test_if_not_exception_message_message.statistics self.assertTrue(s["attempt_number"] == 1) def test_retry_if_exception_cause_type(self) -> None: @@ -1283,17 +1283,18 @@ def function_with_kwdefaults(*, a: int = 1) -> int: wrapped_kwdefaults_function = retrying.wraps(function_with_kwdefaults) self.assertEqual( - function_with_defaults.__defaults__, wrapped_defaults_function.__defaults__ + function_with_defaults.__defaults__, + wrapped_defaults_function.__defaults__, # type: ignore[attr-defined] ) self.assertEqual( function_with_kwdefaults.__kwdefaults__, - wrapped_kwdefaults_function.__kwdefaults__, + wrapped_kwdefaults_function.__kwdefaults__, # type: ignore[attr-defined] ) def test_defaults(self) -> None: - self.assertTrue(_retryable_default(NoNameErrorAfterCount(5))) # type: ignore[no-untyped-call] + self.assertTrue(_retryable_default(NoNameErrorAfterCount(5))) self.assertTrue(_retryable_default_f(NoNameErrorAfterCount(5))) - self.assertTrue(_retryable_default(NoCustomErrorAfterCount(5))) # type: ignore[no-untyped-call] + self.assertTrue(_retryable_default(NoCustomErrorAfterCount(5))) self.assertTrue(_retryable_default_f(NoCustomErrorAfterCount(5))) def test_retry_function_object(self) -> None: @@ -1329,11 +1330,11 @@ def test_retry_function_attributes(self) -> None: "idle_for": mock.ANY, "start_time": mock.ANY, } - self.assertEqual(_retryable_test_with_stop.statistics, expected_stats) # type: ignore[attr-defined] - self.assertEqual(_retryable_test_with_stop.retry.statistics, {}) # type: ignore[attr-defined] + self.assertEqual(_retryable_test_with_stop.statistics, expected_stats) + self.assertEqual(_retryable_test_with_stop.retry.statistics, {}) with mock.patch.object( - _retryable_test_with_stop.retry, # type: ignore[attr-defined] + _retryable_test_with_stop.retry, "stop", tenacity.stop_after_attempt(1), ): @@ -1346,9 +1347,9 @@ def test_retry_function_attributes(self) -> None: "idle_for": mock.ANY, "start_time": mock.ANY, } - self.assertEqual(_retryable_test_with_stop.statistics, expected_stats) # type: ignore[attr-defined] + self.assertEqual(_retryable_test_with_stop.statistics, expected_stats) self.assertEqual(exc.last_attempt.attempt_number, 1) - self.assertEqual(_retryable_test_with_stop.retry.statistics, {}) # type: ignore[attr-defined] + self.assertEqual(_retryable_test_with_stop.retry.statistics, {}) else: self.fail("RetryError should have been raised after 1 attempt") @@ -1356,7 +1357,7 @@ def test_retry_function_attributes(self) -> None: class TestRetryWith: def test_redefine_wait(self) -> None: start = current_time_ms() - result = _retryable_test_with_wait.retry_with(wait=tenacity.wait_fixed(0.1))( # type: ignore[attr-defined] + result = _retryable_test_with_wait.retry_with(wait=tenacity.wait_fixed(0.1))( NoneReturnUntilAfterCount(5) ) t = current_time_ms() - start @@ -1364,7 +1365,7 @@ def test_redefine_wait(self) -> None: assert result is True def test_redefine_stop(self) -> None: - result = _retryable_test_with_stop.retry_with( # type: ignore[attr-defined] + result = _retryable_test_with_stop.retry_with( stop=tenacity.stop_after_attempt(5) )(NoneReturnUntilAfterCount(4)) assert result is True @@ -1375,7 +1376,7 @@ def _retryable() -> None: raise Exception("raised for test purposes") with pytest.raises(Exception) as exc_ctx: - _retryable.retry_with(stop=tenacity.stop_after_attempt(2))() # type: ignore[attr-defined] + _retryable.retry_with(stop=tenacity.stop_after_attempt(2))() assert exc_ctx.type is ValueError, "Should remap to specific exception type" @@ -1387,7 +1388,7 @@ def return_text(retry_state: RetryCallState) -> str: def _retryable() -> None: raise Exception("raised for test purposes") - result = _retryable.retry_with(stop=tenacity.stop_after_attempt(5))() # type: ignore[attr-defined] + result = _retryable.retry_with(stop=tenacity.stop_after_attempt(5))() assert result == "Calling _retryable keeps raising errors after 5 attempts" @@ -1619,19 +1620,19 @@ def test_stats(self) -> None: def _foobar() -> int: return 42 - self.assertEqual({}, _foobar.statistics) # type: ignore[attr-defined] + self.assertEqual({}, _foobar.statistics) _foobar() - self.assertEqual(1, _foobar.statistics["attempt_number"]) # type: ignore[attr-defined] + self.assertEqual(1, _foobar.statistics["attempt_number"]) def test_stats_failing(self) -> None: @retry(stop=tenacity.stop_after_attempt(2)) def _foobar() -> None: raise ValueError(42) - self.assertEqual({}, _foobar.statistics) # type: ignore[attr-defined] + self.assertEqual({}, _foobar.statistics) with contextlib.suppress(Exception): _foobar() - self.assertEqual(2, _foobar.statistics["attempt_number"]) # type: ignore[attr-defined] + self.assertEqual(2, _foobar.statistics["attempt_number"]) class TestRetryErrorCallback(unittest.TestCase):