Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions tenacity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def iter_state(self) -> IterState:
self._local.iter_state = IterState()
return self._local.iter_state # type: ignore[no-any-return]

def wraps(self, f: WrappedFn) -> WrappedFn:
def wraps(self, f: t.Callable[P, R]) -> "_RetryDecorated[P, R]":
"""Wrap a function for retrying.

:param f: A function to wrap for retrying.
Expand All @@ -339,15 +339,15 @@ def wrapped_f(*args: t.Any, **kw: t.Any) -> t.Any:
wrapped_f.statistics = copy.statistics # type: ignore[attr-defined]
return copy(f, *args, **kw)

def retry_with(*args: t.Any, **kwargs: t.Any) -> WrappedFn:
def retry_with(*args: t.Any, **kwargs: t.Any) -> "_RetryDecorated[P, R]":
return self.copy(*args, **kwargs).wraps(f)

# Preserve attributes
wrapped_f.retry = self # type: ignore[attr-defined]
wrapped_f.retry_with = retry_with # type: ignore[attr-defined]
wrapped_f.statistics = {} # type: ignore[attr-defined]

return wrapped_f # type: ignore[return-value]
return t.cast("_RetryDecorated[P, R]", wrapped_f)

def begin(self) -> None:
self.statistics.clear()
Expand Down Expand Up @@ -604,25 +604,41 @@ def __repr__(self) -> str:
return f"<{clsname} {id(self)}: attempt #{self.attempt_number}; slept for {slept}; last result: {result}>"


class _RetryDecorated(t.Protocol[P, R]):
"""Protocol for functions decorated with @retry.

Provides the original callable signature plus retry control attributes.
"""

retry: "BaseRetrying"
statistics: dict[str, t.Any]

def retry_with(self, *args: t.Any, **kwargs: t.Any) -> "_RetryDecorated[P, R]": ...

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...


class _AsyncRetryDecorator(t.Protocol):
@t.overload
def __call__(
self, fn: "t.Callable[P, types.CoroutineType[t.Any, t.Any, R]]"
) -> "t.Callable[P, types.CoroutineType[t.Any, t.Any, R]]": ...
) -> "_RetryDecorated[P, types.CoroutineType[t.Any, t.Any, R]]": ...
@t.overload
def __call__(
self, fn: t.Callable[P, t.Coroutine[t.Any, t.Any, R]]
) -> t.Callable[P, t.Coroutine[t.Any, t.Any, R]]: ...
) -> "_RetryDecorated[P, t.Coroutine[t.Any, t.Any, R]]": ...
@t.overload
def __call__(
self, fn: t.Callable[P, t.Awaitable[R]]
) -> t.Callable[P, t.Awaitable[R]]: ...
) -> "_RetryDecorated[P, t.Awaitable[R]]": ...
@t.overload
def __call__(self, fn: t.Callable[P, R]) -> t.Callable[P, t.Awaitable[R]]: ...
def __call__(
self, fn: t.Callable[P, R]
) -> "_RetryDecorated[P, t.Awaitable[R]]": ...


@t.overload
def retry(func: WrappedFn) -> WrappedFn: ...
def retry(func: t.Callable[P, R]) -> _RetryDecorated[P, R]: ...


@t.overload
Expand Down Expand Up @@ -656,7 +672,7 @@ def retry(
retry_error_cls: type["RetryError"] = RetryError,
retry_error_callback: t.Callable[["RetryCallState"], t.Any | t.Awaitable[t.Any]]
| None = None,
) -> t.Callable[[WrappedFn], WrappedFn]: ...
) -> t.Callable[[t.Callable[P, R]], _RetryDecorated[P, R]]: ...


def retry(*dargs: t.Any, **dkw: t.Any) -> t.Any:
Expand All @@ -669,7 +685,7 @@ def retry(*dargs: t.Any, **dkw: t.Any) -> t.Any:
if len(dargs) == 1 and callable(dargs[0]):
return retry()(dargs[0])

def wrap(f: WrappedFn) -> WrappedFn:
def wrap(f: t.Callable[P, R]) -> _RetryDecorated[P, R]:
if isinstance(f, retry_base):
warnings.warn(
f"Got retry_base instance ({f.__class__.__name__}) as callable argument, "
Expand Down
9 changes: 6 additions & 3 deletions tenacity/asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
DoSleep,
RetryCallState,
RetryError,
_RetryDecorated,
_utils,
after_nothing,
before_nothing,
Expand All @@ -48,6 +49,8 @@

WrappedFnReturnT = t.TypeVar("WrappedFnReturnT")
WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]])
P = t.ParamSpec("P")
R = t.TypeVar("R")


def _portable_async_sleep(seconds: float) -> t.Awaitable[None]:
Expand Down Expand Up @@ -178,7 +181,7 @@ async def __anext__(self) -> AttemptManager:
else:
raise StopAsyncIteration

def wraps(self, fn: WrappedFn) -> WrappedFn:
def wraps(self, fn: t.Callable[P, R]) -> _RetryDecorated[P, R]:
wrapped = super().wraps(fn)
# Ensure wrapper is recognized as a coroutine function.

Expand All @@ -190,14 +193,14 @@ async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
# calling the same wrapped functions multiple times in the same stack
copy = self.copy()
async_wrapped.statistics = copy.statistics # type: ignore[attr-defined]
return await copy(fn, *args, **kwargs)
return await copy(fn, *args, **kwargs) # type: ignore[type-var]

# Preserve attributes
async_wrapped.retry = self # type: ignore[attr-defined]
async_wrapped.retry_with = wrapped.retry_with # type: ignore[attr-defined]
async_wrapped.statistics = {} # type: ignore[attr-defined]

return async_wrapped # type: ignore[return-value]
return t.cast("_RetryDecorated[P, R]", async_wrapped)


__all__ = [
Expand Down
21 changes: 11 additions & 10 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,12 @@ async def function_with_kwdefaults(*, a: int = 1) -> int:
wrapped_kwdefaults_function = retrying.wraps(function_with_kwdefaults)

self.assertEqual(
function_with_defaults.__defaults__, wrapped_defaults_function.__defaults__
function_with_defaults.__defaults__,
wrapped_defaults_function.__defaults__, # type: ignore[attr-defined]
)
self.assertEqual(
function_with_kwdefaults.__kwdefaults__,
wrapped_kwdefaults_function.__kwdefaults__,
wrapped_kwdefaults_function.__kwdefaults__, # type: ignore[attr-defined]
)

@asynctest
Expand All @@ -142,8 +143,8 @@ def after(retry_state: RetryCallState) -> None:
thing2 = NoIOErrorAfterCount(3)

await asyncio.gather(
_retryable_coroutine.retry_with(after=after)(thing1), # type: ignore[attr-defined]
_retryable_coroutine.retry_with(after=after)(thing2), # type: ignore[attr-defined]
_retryable_coroutine.retry_with(after=after)(thing1),
_retryable_coroutine.retry_with(after=after)(thing2),
)

# There's no waiting on retry, only a wait in the coroutine, so the
Expand Down Expand Up @@ -429,16 +430,16 @@ async def test_retry_function_attributes(self) -> None:
"start_time": mock.ANY,
}
self.assertEqual(
_retryable_coroutine_with_2_attempts.statistics, # type: ignore[attr-defined]
_retryable_coroutine_with_2_attempts.statistics,
expected_stats,
)
self.assertEqual(
_retryable_coroutine_with_2_attempts.retry.statistics, # type: ignore[attr-defined]
_retryable_coroutine_with_2_attempts.retry.statistics,
{},
)

with mock.patch.object(
_retryable_coroutine_with_2_attempts.retry, # type: ignore[attr-defined]
_retryable_coroutine_with_2_attempts.retry,
"stop",
tenacity.stop_after_attempt(1),
):
Expand All @@ -454,12 +455,12 @@ async def test_retry_function_attributes(self) -> None:
"start_time": mock.ANY,
}
self.assertEqual(
_retryable_coroutine_with_2_attempts.statistics, # type: ignore[attr-defined]
_retryable_coroutine_with_2_attempts.statistics,
expected_stats,
)
self.assertEqual(exc.last_attempt.attempt_number, 1)
self.assertEqual(
_retryable_coroutine_with_2_attempts.retry.statistics, # type: ignore[attr-defined]
_retryable_coroutine_with_2_attempts.retry.statistics,
{},
)
else:
Expand Down Expand Up @@ -493,7 +494,7 @@ async def test_sync_function_with_async_sleep(self) -> None:
def sync_function() -> Any:
return thing.go()

result = await sync_function() # type: ignore[no-untyped-call]
result = await sync_function()
assert result is True
assert mock_sleep.await_count == 2

Expand Down
47 changes: 24 additions & 23 deletions tests/test_tenacity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,7 +1177,7 @@ def test_retry_until_exception_of_type_attempt_number(self) -> None:
_retryable_test_with_unless_exception_type_name(NameErrorUntilCount(5))
)
except NameError as e:
s = _retryable_test_with_unless_exception_type_name.statistics # type: ignore[attr-defined]
s = _retryable_test_with_unless_exception_type_name.statistics
self.assertTrue(s["attempt_number"] == 6)
print(e)
else:
Expand All @@ -1192,7 +1192,7 @@ def test_retry_until_exception_of_type_no_type(self) -> None:
)
)
except NameError as e:
s = _retryable_test_with_unless_exception_type_no_input.statistics # type: ignore[attr-defined]
s = _retryable_test_with_unless_exception_type_no_input.statistics
self.assertTrue(s["attempt_number"] == 6)
print(e)
else:
Expand All @@ -1215,7 +1215,7 @@ def test_retry_if_exception_message(self) -> None:
_retryable_test_if_exception_message_message(NoCustomErrorAfterCount(3))
)
except CustomError:
print(_retryable_test_if_exception_message_message.statistics) # type: ignore[attr-defined]
print(_retryable_test_if_exception_message_message.statistics)
self.fail("CustomError should've been retried from errormessage")

def test_retry_if_not_exception_message(self) -> None:
Expand All @@ -1226,7 +1226,7 @@ def test_retry_if_not_exception_message(self) -> None:
)
)
except CustomError:
s = _retryable_test_if_not_exception_message_message.statistics # type: ignore[attr-defined]
s = _retryable_test_if_not_exception_message_message.statistics
self.assertTrue(s["attempt_number"] == 1)

def test_retry_if_not_exception_message_delay(self) -> None:
Expand All @@ -1235,7 +1235,7 @@ def test_retry_if_not_exception_message_delay(self) -> None:
_retryable_test_not_exception_message_delay(NameErrorUntilCount(3))
)
except NameError:
s = _retryable_test_not_exception_message_delay.statistics # type: ignore[attr-defined]
s = _retryable_test_not_exception_message_delay.statistics
print(s["attempt_number"])
self.assertTrue(s["attempt_number"] == 4)

Expand All @@ -1255,7 +1255,7 @@ def test_retry_if_not_exception_message_match(self) -> None:
)
)
except CustomError:
s = _retryable_test_if_not_exception_message_message.statistics # type: ignore[attr-defined]
s = _retryable_test_if_not_exception_message_message.statistics
self.assertTrue(s["attempt_number"] == 1)

def test_retry_if_exception_cause_type(self) -> None:
Expand Down Expand Up @@ -1283,17 +1283,18 @@ def function_with_kwdefaults(*, a: int = 1) -> int:
wrapped_kwdefaults_function = retrying.wraps(function_with_kwdefaults)

self.assertEqual(
function_with_defaults.__defaults__, wrapped_defaults_function.__defaults__
function_with_defaults.__defaults__,
wrapped_defaults_function.__defaults__, # type: ignore[attr-defined]
)
self.assertEqual(
function_with_kwdefaults.__kwdefaults__,
wrapped_kwdefaults_function.__kwdefaults__,
wrapped_kwdefaults_function.__kwdefaults__, # type: ignore[attr-defined]
)

def test_defaults(self) -> None:
self.assertTrue(_retryable_default(NoNameErrorAfterCount(5))) # type: ignore[no-untyped-call]
self.assertTrue(_retryable_default(NoNameErrorAfterCount(5)))
self.assertTrue(_retryable_default_f(NoNameErrorAfterCount(5)))
self.assertTrue(_retryable_default(NoCustomErrorAfterCount(5))) # type: ignore[no-untyped-call]
self.assertTrue(_retryable_default(NoCustomErrorAfterCount(5)))
self.assertTrue(_retryable_default_f(NoCustomErrorAfterCount(5)))

def test_retry_function_object(self) -> None:
Expand Down Expand Up @@ -1329,11 +1330,11 @@ def test_retry_function_attributes(self) -> None:
"idle_for": mock.ANY,
"start_time": mock.ANY,
}
self.assertEqual(_retryable_test_with_stop.statistics, expected_stats) # type: ignore[attr-defined]
self.assertEqual(_retryable_test_with_stop.retry.statistics, {}) # type: ignore[attr-defined]
self.assertEqual(_retryable_test_with_stop.statistics, expected_stats)
self.assertEqual(_retryable_test_with_stop.retry.statistics, {})

with mock.patch.object(
_retryable_test_with_stop.retry, # type: ignore[attr-defined]
_retryable_test_with_stop.retry,
"stop",
tenacity.stop_after_attempt(1),
):
Expand All @@ -1346,25 +1347,25 @@ def test_retry_function_attributes(self) -> None:
"idle_for": mock.ANY,
"start_time": mock.ANY,
}
self.assertEqual(_retryable_test_with_stop.statistics, expected_stats) # type: ignore[attr-defined]
self.assertEqual(_retryable_test_with_stop.statistics, expected_stats)
self.assertEqual(exc.last_attempt.attempt_number, 1)
self.assertEqual(_retryable_test_with_stop.retry.statistics, {}) # type: ignore[attr-defined]
self.assertEqual(_retryable_test_with_stop.retry.statistics, {})
else:
self.fail("RetryError should have been raised after 1 attempt")


class TestRetryWith:
def test_redefine_wait(self) -> None:
start = current_time_ms()
result = _retryable_test_with_wait.retry_with(wait=tenacity.wait_fixed(0.1))( # type: ignore[attr-defined]
result = _retryable_test_with_wait.retry_with(wait=tenacity.wait_fixed(0.1))(
NoneReturnUntilAfterCount(5)
)
t = current_time_ms() - start
assert t >= 500
assert result is True

def test_redefine_stop(self) -> None:
result = _retryable_test_with_stop.retry_with( # type: ignore[attr-defined]
result = _retryable_test_with_stop.retry_with(
stop=tenacity.stop_after_attempt(5)
)(NoneReturnUntilAfterCount(4))
assert result is True
Expand All @@ -1375,7 +1376,7 @@ def _retryable() -> None:
raise Exception("raised for test purposes")

with pytest.raises(Exception) as exc_ctx:
_retryable.retry_with(stop=tenacity.stop_after_attempt(2))() # type: ignore[attr-defined]
_retryable.retry_with(stop=tenacity.stop_after_attempt(2))()

assert exc_ctx.type is ValueError, "Should remap to specific exception type"

Expand All @@ -1387,7 +1388,7 @@ def return_text(retry_state: RetryCallState) -> str:
def _retryable() -> None:
raise Exception("raised for test purposes")

result = _retryable.retry_with(stop=tenacity.stop_after_attempt(5))() # type: ignore[attr-defined]
result = _retryable.retry_with(stop=tenacity.stop_after_attempt(5))()
assert result == "Calling _retryable keeps raising errors after 5 attempts"


Expand Down Expand Up @@ -1619,19 +1620,19 @@ def test_stats(self) -> None:
def _foobar() -> int:
return 42

self.assertEqual({}, _foobar.statistics) # type: ignore[attr-defined]
self.assertEqual({}, _foobar.statistics)
_foobar()
self.assertEqual(1, _foobar.statistics["attempt_number"]) # type: ignore[attr-defined]
self.assertEqual(1, _foobar.statistics["attempt_number"])

def test_stats_failing(self) -> None:
@retry(stop=tenacity.stop_after_attempt(2))
def _foobar() -> None:
raise ValueError(42)

self.assertEqual({}, _foobar.statistics) # type: ignore[attr-defined]
self.assertEqual({}, _foobar.statistics)
with contextlib.suppress(Exception):
_foobar()
self.assertEqual(2, _foobar.statistics["attempt_number"]) # type: ignore[attr-defined]
self.assertEqual(2, _foobar.statistics["attempt_number"])


class TestRetryErrorCallback(unittest.TestCase):
Expand Down
Loading