diff --git a/doc/source/index.rst b/doc/source/index.rst index 2d1e9f8e..e9c8c199 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -618,6 +618,38 @@ statistics should be read from the function `statistics` attribute. ... +Disabling Retries +~~~~~~~~~~~~~~~~~ + +You can disable retrying entirely by passing ``enabled=False``. When disabled, +the decorated function is called directly without any retry logic. This is +useful during development or testing when you want fast feedback on failures: + +.. testcode:: + + import os + + @retry( + enabled=os.getenv("ENABLE_RETRIES", "1") != "0", + stop=stop_after_attempt(5), + wait=wait_fixed(1), + ) + def call_api(): + pass # your code here + + call_api() + +You can also use ``retry_with`` to disable retries on a per-call basis: + +.. testcode:: + + @retry(stop=stop_after_attempt(5)) + def call_api(): + pass # your code here + + # In tests: + call_api.retry_with(enabled=False)() + Retrying code block ~~~~~~~~~~~~~~~~~~~ diff --git a/tenacity/__init__.py b/tenacity/__init__.py index 57af68a4..591f9703 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -240,6 +240,7 @@ def __init__( retry_error_cls: type[RetryError] = RetryError, retry_error_callback: t.Callable[["RetryCallState"], t.Any] | None = None, name: str | None = None, + enabled: bool = True, ): self.sleep = sleep self.stop = stop @@ -253,6 +254,7 @@ def __init__( self.retry_error_cls = retry_error_cls self.retry_error_callback = retry_error_callback self._name = name + self.enabled = enabled def copy( self, @@ -269,6 +271,7 @@ def copy( | None | object = _unset, name: str | None | object = _unset, + enabled: bool | object = _unset, ) -> "Self": """Copy this object with some parameters changed if needed.""" return self.__class__( @@ -285,6 +288,7 @@ def copy( retry_error_callback, self.retry_error_callback ), name=_first_set(name, self._name), + enabled=_first_set(enabled, self.enabled), ) def __str__(self) -> str: @@ -344,6 +348,8 @@ def wraps(self, f: t.Callable[P, R]) -> "_RetryDecorated[P, R]": f, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__") ) def wrapped_f(*args: t.Any, **kw: t.Any) -> t.Any: + if not self.enabled: + return f(*args, **kw) # Always create a copy to prevent overwriting the local contexts when # calling the same wrapped functions multiple times in the same stack copy = self.copy() @@ -667,6 +673,7 @@ def retry( retry_error_cls: type["RetryError"] = ..., retry_error_callback: t.Callable[["RetryCallState"], t.Any | t.Awaitable[t.Any]] | None = ..., + enabled: bool = ..., ) -> _AsyncRetryDecorator: ... @@ -684,6 +691,7 @@ def retry( retry_error_cls: type["RetryError"] = RetryError, retry_error_callback: t.Callable[["RetryCallState"], t.Any | t.Awaitable[t.Any]] | None = None, + enabled: bool = True, ) -> t.Callable[[t.Callable[P, R]], _RetryDecorated[P, R]]: ... diff --git a/tenacity/asyncio/__init__.py b/tenacity/asyncio/__init__.py index 85b216ab..0ff48855 100644 --- a/tenacity/asyncio/__init__.py +++ b/tenacity/asyncio/__init__.py @@ -91,6 +91,7 @@ def __init__( retry_error_callback: t.Callable[["RetryCallState"], t.Any | t.Awaitable[t.Any]] | None = None, name: str | None = None, + enabled: bool = True, ) -> None: super().__init__( sleep=sleep, # type: ignore[arg-type] @@ -104,6 +105,7 @@ def __init__( retry_error_cls=retry_error_cls, retry_error_callback=retry_error_callback, name=name, + enabled=enabled, ) async def __call__( # type: ignore[override] @@ -189,6 +191,8 @@ def wraps(self, fn: t.Callable[P, R]) -> _RetryDecorated[P, R]: fn, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__") ) async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any: + if not self.enabled: + return await fn(*args, **kwargs) # type: ignore[misc] # Always create a copy to prevent overwriting the local contexts when # calling the same wrapped functions multiple times in the same stack copy = self.copy() diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 0cdb8ba6..749c8834 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -160,6 +160,23 @@ def after(retry_state: RetryCallState) -> None: assert list(attempt_nos2) == [1, 2, 3] +class TestAsyncEnabled(unittest.TestCase): + @asynctest + async def test_enabled_false_skips_retry(self) -> None: + """When enabled=False, async function is called directly without retrying.""" + call_count = 0 + + @retry(enabled=False, stop=stop_after_attempt(3)) + async def always_fails() -> None: + nonlocal call_count + call_count += 1 + raise ValueError("fail") + + with pytest.raises(ValueError, match="fail"): + await always_fails() + assert call_count == 1 + + @unittest.skipIf(not have_trio, "trio not installed") class TestTrio(unittest.TestCase): def test_trio_basic(self) -> None: diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py index 5a90a467..b72831eb 100644 --- a/tests/test_tenacity.py +++ b/tests/test_tenacity.py @@ -1405,6 +1405,64 @@ def succeeds_first_try() -> bool: assert succeeds_first_try.statistics["delay_since_first_attempt"] == 0 +class TestEnabled: + def test_enabled_false_skips_retry(self) -> None: + """When enabled=False, the function is called directly without retrying.""" + call_count = 0 + + @retry(enabled=False, stop=tenacity.stop_after_attempt(3)) + def always_fails() -> None: + nonlocal call_count + call_count += 1 + raise ValueError("fail") + + with pytest.raises(ValueError, match="fail"): + always_fails() + assert call_count == 1 + + def test_enabled_false_preserves_attributes(self) -> None: + """When enabled=False, .retry, .retry_with, .statistics are still available.""" + + @retry(enabled=False, stop=tenacity.stop_after_attempt(3)) + def my_func() -> str: + return "ok" + + assert hasattr(my_func, "retry") + assert hasattr(my_func, "retry_with") + assert hasattr(my_func, "statistics") + assert my_func() == "ok" + + def test_enabled_false_via_retry_with(self) -> None: + """retry_with(enabled=False) disables retrying.""" + call_count = 0 + + @retry(stop=tenacity.stop_after_attempt(3)) + def always_fails() -> None: + nonlocal call_count + call_count += 1 + raise ValueError("fail") + + disabled = always_fails.retry_with(enabled=False) + with pytest.raises(ValueError, match="fail"): + disabled() + assert call_count == 1 + + def test_enabled_true_retries_normally(self) -> None: + """When enabled=True (default), retrying works as usual.""" + call_count = 0 + + @retry(enabled=True, stop=tenacity.stop_after_attempt(3), reraise=True) + def fails_twice() -> bool: + nonlocal call_count + call_count += 1 + if call_count < 3: + raise ValueError("fail") + return True + + assert fails_twice() is True + assert call_count == 3 + + class TestRetryWith: def test_redefine_wait(self) -> None: start = current_time_ms()