Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions tenacity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,15 +356,47 @@ def wrapped_f(*args: t.Any, **kw: t.Any) -> t.Any:
wrapped_f.statistics = copy.statistics # type: ignore[attr-defined]
return copy(f, *args, **kw)

@functools.wraps(
f, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__")
)
def wrapped_gen_f(
*args: t.Any, **kw: t.Any
) -> t.Generator[t.Any, t.Any, t.Any]:
if not self.enabled:
yield from f(*args, **kw)
return
copy = self.copy()
wrapped_gen_f.statistics = copy.statistics # type: ignore[attr-defined]
copy.begin()
retry_state = RetryCallState(retry_object=copy, fn=f, args=args, kwargs=kw)
while True:
do = copy.iter(retry_state=retry_state)
if isinstance(do, DoAttempt):
try:
result = yield from f(*args, **kw)
except GeneratorExit:
raise
except BaseException:
retry_state.set_exception(sys.exc_info()) # type: ignore[arg-type]
else:
retry_state.set_result(result)
elif isinstance(do, DoSleep):
retry_state.prepare_for_next_attempt()
copy.sleep(do)
else:
return do

result_f = wrapped_gen_f if _utils.is_generator_callable(f) else wrapped_f

def retry_with(*args: t.Any, **kwargs: t.Any) -> "_RetryDecorated[P, R]":
return self.copy(*args, **kwargs).wraps(f)

# Preserve attributes
wrapped_f.retry = self # type: ignore[attr-defined]
wrapped_f.retry_with = retry_with # type: ignore[attr-defined]
wrapped_f.statistics = {} # type: ignore[attr-defined]
result_f.retry = self # type: ignore[attr-defined]
result_f.retry_with = retry_with # type: ignore[attr-defined]
result_f.statistics = {} # type: ignore[attr-defined]

return t.cast("_RetryDecorated[P, R]", wrapped_f)
return t.cast("_RetryDecorated[P, R]", result_f)

def begin(self) -> None:
self.statistics.clear()
Expand Down Expand Up @@ -714,8 +746,10 @@ def wrap(f: t.Callable[P, R]) -> _RetryDecorated[P, R]:
)
r: BaseRetrying
sleep = dkw.get("sleep")
if _utils.is_coroutine_callable(f) or (
sleep is not None and _utils.is_coroutine_callable(sleep)
if (
_utils.is_coroutine_callable(f)
or _utils.is_async_gen_callable(f)
or (sleep is not None and _utils.is_coroutine_callable(sleep))
):
r = AsyncRetrying(*dargs, **dkw)
elif (
Expand Down
20 changes: 20 additions & 0 deletions tenacity/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,26 @@ def is_coroutine_callable(call: typing.Callable[..., typing.Any]) -> bool:
return inspect.iscoroutinefunction(dunder_call)


def is_generator_callable(call: typing.Callable[..., typing.Any]) -> bool:
if inspect.isclass(call):
return False
if inspect.isgeneratorfunction(call):
return True
partial_call = isinstance(call, functools.partial) and call.func
dunder_call = partial_call or getattr(call, "__call__", None) # noqa: B004
return inspect.isgeneratorfunction(dunder_call)


def is_async_gen_callable(call: typing.Callable[..., typing.Any]) -> bool:
if inspect.isclass(call):
return False
if inspect.isasyncgenfunction(call):
return True
partial_call = isinstance(call, functools.partial) and call.func
dunder_call = partial_call or getattr(call, "__call__", None) # noqa: B004
return inspect.isasyncgenfunction(dunder_call)


def wrap_to_async_func(
call: typing.Callable[..., typing.Any],
) -> typing.Callable[..., typing.Awaitable[typing.Any]]:
Expand Down
47 changes: 43 additions & 4 deletions tenacity/asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,51 @@ async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
async_wrapped.statistics = copy.statistics # type: ignore[attr-defined]
return await copy(fn, *args, **kwargs) # type: ignore[type-var]

@functools.wraps(
fn, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__")
)
async def async_wrapped_gen(
*args: t.Any, **kwargs: t.Any
) -> t.AsyncGenerator[t.Any, t.Any]:
if not self.enabled:
async for item in fn(*args, **kwargs): # type: ignore[misc]
yield item
return
copy = self.copy()
async_wrapped_gen.statistics = copy.statistics # type: ignore[attr-defined]
copy.begin()
retry_state = RetryCallState(
retry_object=copy, fn=fn, args=args, kwargs=kwargs
)
while True:
do = await copy.iter(retry_state=retry_state)
if isinstance(do, DoAttempt):
try:
async for item in fn(*args, **kwargs): # type: ignore[misc]
yield item
except GeneratorExit:
raise
except BaseException:
retry_state.set_exception(sys.exc_info()) # type: ignore[arg-type]
else:
retry_state.set_result(None)
elif isinstance(do, DoSleep):
retry_state.prepare_for_next_attempt()
await self.sleep(do) # type: ignore[misc]
else:
return

if _utils.is_async_gen_callable(fn):
result_f = async_wrapped_gen
else:
result_f = async_wrapped

# Preserve attributes
async_wrapped.retry = self # type: ignore[attr-defined]
async_wrapped.retry_with = wrapped.retry_with # type: ignore[attr-defined]
async_wrapped.statistics = {} # type: ignore[attr-defined]
result_f.retry = self # type: ignore[attr-defined]
result_f.retry_with = wrapped.retry_with # type: ignore[attr-defined]
result_f.statistics = {} # type: ignore[attr-defined]

return t.cast("_RetryDecorated[P, R]", async_wrapped)
return t.cast("_RetryDecorated[P, R]", result_f)


__all__ = [
Expand Down
88 changes: 88 additions & 0 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,5 +531,93 @@ def sync_function() -> Any:
assert mock_sleep.await_count == 2


class TestAsyncGeneratorRetry(unittest.TestCase):
@asynctest
async def test_async_generator_retry_on_exception(self) -> None:
attempts = 0

@retry(
stop=stop_after_attempt(3),
retry=tenacity.retry_if_exception_type(ValueError),
reraise=True,
)
async def gen_with_errors() -> Any:
nonlocal attempts
attempts += 1
if attempts < 3:
raise ValueError("not yet")
yield 1
yield 2
yield 3

result = [item async for item in gen_with_errors()]
assert result == [1, 2, 3]
assert attempts == 3

@asynctest
async def test_async_generator_yields_all_values(self) -> None:
@retry
async def simple_gen() -> Any:
yield 10
yield 20
yield 30

result = [item async for item in simple_gen()]
assert result == [10, 20, 30]

@asynctest
async def test_async_generator_stop_after_attempt(self) -> None:
@retry(
stop=stop_after_attempt(2),
retry=tenacity.retry_if_exception_type(RuntimeError),
)
async def always_fails() -> Any:
raise RuntimeError("always")
yield # make it an async generator

with pytest.raises(RetryError):
async for _ in always_fails():
pass

def test_async_generator_has_retry_attributes(self) -> None:
@retry
async def my_gen() -> Any:
yield 1

assert hasattr(my_gen, "retry")
assert hasattr(my_gen, "statistics")
assert hasattr(my_gen, "retry_with")

@asynctest
async def test_async_generator_statistics_updated(self) -> None:
attempts = 0

@retry(
stop=stop_after_attempt(3),
retry=tenacity.retry_if_exception_type(ValueError),
reraise=True,
)
async def gen_stats() -> Any:
nonlocal attempts
attempts += 1
if attempts < 2:
raise ValueError("retry")
yield 42

result = [item async for item in gen_stats()]
assert result == [42]
assert gen_stats.statistics["attempt_number"] == 2 # type: ignore[attr-defined]

@asynctest
async def test_async_generator_enabled_false(self) -> None:
@retry(enabled=False)
async def my_gen() -> Any:
yield 1
yield 2

result = [item async for item in my_gen()]
assert result == [1, 2]


if __name__ == "__main__":
unittest.main()
103 changes: 103 additions & 0 deletions tests/test_tenacity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2009,5 +2009,108 @@ def test_decorated_retry_with(self, mock_sleep: typing.Any) -> None:
assert mock_sleep.call_count == 1


class TestGeneratorRetry:
def test_generator_retry_on_exception(self) -> None:
attempts = 0

@retry(
stop=tenacity.stop_after_attempt(3),
retry=tenacity.retry_if_exception_type(ValueError),
reraise=True,
)
def gen_with_errors() -> typing.Generator[int, None, None]:
nonlocal attempts
attempts += 1
if attempts < 3:
raise ValueError("not yet")
yield 1
yield 2
yield 3

result = list(gen_with_errors())
assert result == [1, 2, 3]
assert attempts == 3

def test_generator_yields_all_values(self) -> None:
@retry
def simple_gen() -> typing.Generator[int, None, None]:
yield 10
yield 20
yield 30

result = list(simple_gen())
assert result == [10, 20, 30]

def test_generator_stop_after_attempt(self) -> None:
@retry(
stop=tenacity.stop_after_attempt(2),
retry=tenacity.retry_if_exception_type(RuntimeError),
)
def always_fails() -> typing.Generator[int, None, None]:
raise RuntimeError("always")
yield # make it a generator

with pytest.raises(RetryError):
list(always_fails())

def test_generator_has_retry_attributes(self) -> None:
@retry
def my_gen() -> typing.Generator[int, None, None]:
yield 1

assert hasattr(my_gen, "retry")
assert hasattr(my_gen, "statistics")
assert hasattr(my_gen, "retry_with")

def test_generator_statistics_updated(self) -> None:
attempts = 0

@retry(
stop=tenacity.stop_after_attempt(3),
retry=tenacity.retry_if_exception_type(ValueError),
reraise=True,
)
def gen_stats() -> typing.Generator[int, None, None]:
nonlocal attempts
attempts += 1
if attempts < 2:
raise ValueError("retry")
yield 42

result = list(gen_stats())
assert result == [42]
assert gen_stats.statistics["attempt_number"] == 2 # type: ignore[attr-defined]

def test_generator_enabled_false(self) -> None:
@retry(enabled=False)
def my_gen() -> typing.Generator[int, None, None]:
yield 1
yield 2

result = list(my_gen())
assert result == [1, 2]

def test_generator_retry_with(self) -> None:
attempts = 0

@retry(
stop=tenacity.stop_after_attempt(5),
retry=tenacity.retry_if_exception_type(ValueError),
)
def gen_retry_with() -> typing.Generator[int, None, None]:
nonlocal attempts
attempts += 1
if attempts < 2:
raise ValueError("retry")
yield 1

faster = gen_retry_with.retry_with( # type: ignore[attr-defined]
stop=tenacity.stop_after_attempt(1),
)
attempts = 0
with pytest.raises(RetryError):
list(faster())


if __name__ == "__main__":
unittest.main()
Loading