From 0d8f3a481b28af121f6e3d2509ad9adf7d30dd70 Mon Sep 17 00:00:00 2001 From: Julien Danjou Date: Tue, 24 Feb 2026 22:45:16 +0100 Subject: [PATCH] fix: make Retrying and retry strategies picklable for multiprocessing Add __getstate__/__setstate__ to BaseRetrying to exclude the unpicklable threading.local object during serialization. Replace lambdas in retry strategy classes with bound methods so they can be pickled with standard pickle. Fixes #147 Co-Authored-By: Claude Opus 4.6 Change-Id: Ia3cb014c783078a08492014bb7088f10f95e4ae9 --- tenacity/__init__.py | 8 ++++++ tenacity/retry.py | 52 ++++++++++++++++++-------------------- tests/test_tenacity.py | 57 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 28 deletions(-) diff --git a/tenacity/__init__.py b/tenacity/__init__.py index 591f9703..dd4eb480 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -291,6 +291,14 @@ def copy( enabled=_first_set(enabled, self.enabled), ) + def __getstate__(self) -> dict[str, t.Any]: + # Exclude threading.local which cannot be pickled + return {k: v for k, v in self.__dict__.items() if k != "_local"} + + def __setstate__(self, state: dict[str, t.Any]) -> None: + self.__dict__.update(state) + self._local = threading.local() + def __str__(self) -> str: return self._name if self._name is not None else "" diff --git a/tenacity/retry.py b/tenacity/retry.py index 6e9fc3e8..df0cc4d6 100644 --- a/tenacity/retry.py +++ b/tenacity/retry.py @@ -108,7 +108,10 @@ def __init__( | tuple[type[BaseException], ...] = Exception, ) -> None: self.exception_types = exception_types - super().__init__(lambda e: isinstance(e, exception_types)) + super().__init__(self._check) + + def _check(self, e: BaseException) -> bool: + return isinstance(e, self.exception_types) class retry_if_not_exception_type(retry_if_exception): @@ -120,7 +123,10 @@ def __init__( | tuple[type[BaseException], ...] = Exception, ) -> None: self.exception_types = exception_types - super().__init__(lambda e: not isinstance(e, exception_types)) + super().__init__(self._check) + + def _check(self, e: BaseException) -> bool: + return not isinstance(e, self.exception_types) class retry_unless_exception_type(retry_if_exception): @@ -132,7 +138,10 @@ def __init__( | tuple[type[BaseException], ...] = Exception, ) -> None: self.exception_types = exception_types - super().__init__(lambda e: not isinstance(e, exception_types)) + super().__init__(self._check) + + def _check(self, e: BaseException) -> bool: + return not isinstance(e, self.exception_types) def __call__(self, retry_state: "RetryCallState") -> bool: if retry_state.outcome is None: @@ -219,40 +228,27 @@ def __init__( f"{self.__class__.__name__}() takes either 'message' or 'match', not both" ) - # set predicate - if message: - - def message_fnc(exception: BaseException) -> bool: - return message == str(exception) - - predicate = message_fnc - elif match: - prog = re.compile(match) - - def match_fnc(exception: BaseException) -> bool: - return bool(prog.match(str(exception))) - - predicate = match_fnc - else: + if not message and not match: raise TypeError( f"{self.__class__.__name__}() missing 1 required argument 'message' or 'match'" ) - super().__init__(predicate) + self.message = message + self.match = re.compile(match) if match else None + super().__init__(self._check) + + def _check(self, exception: BaseException) -> bool: + if self.message: + return self.message == str(exception) + assert self.match is not None + return bool(self.match.match(str(exception))) class retry_if_not_exception_message(retry_if_exception_message): """Retries until an exception message equals or matches.""" - def __init__( - self, - message: str | None = None, - match: None | str | re.Pattern[str] = None, - ) -> None: - super().__init__(message, match) - # invert predicate - if_predicate = self.predicate - self.predicate = lambda *args_, **kwargs_: not if_predicate(*args_, **kwargs_) + def _check(self, exception: BaseException) -> bool: + return not super()._check(exception) def __call__(self, retry_state: "RetryCallState") -> bool: if retry_state.outcome is None: diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py index 80060a36..34c7adad 100644 --- a/tests/test_tenacity.py +++ b/tests/test_tenacity.py @@ -16,6 +16,7 @@ import contextlib import datetime import logging +import pickle import re import time import typing @@ -2009,5 +2010,61 @@ def test_decorated_retry_with(self, mock_sleep: typing.Any) -> None: assert mock_sleep.call_count == 1 +class TestPickle(unittest.TestCase): + def test_retrying_picklable(self) -> None: + """Retrying objects can be pickled for multiprocessing support.""" + retrying = Retrying(stop=tenacity.stop_after_attempt(3)) + pickled = pickle.dumps(retrying) + restored = pickle.loads(pickled) + assert isinstance(restored, Retrying) + assert isinstance(restored.stop, tenacity.stop_after_attempt) + + def test_retrying_picklable_after_run(self) -> None: + """Retrying objects can be pickled even after being used.""" + retrying = Retrying(stop=tenacity.stop_after_attempt(3)) + # Access statistics to populate _local + _ = retrying.statistics + pickled = pickle.dumps(retrying) + restored = pickle.loads(pickled) + assert isinstance(restored, Retrying) + # Statistics should be reset on the restored object + assert restored.statistics == {} + + def test_retry_strategies_picklable(self) -> None: + """All built-in retry strategies can be pickled.""" + strategies = [ + tenacity.retry_if_exception_type(ValueError), + tenacity.retry_if_not_exception_type(ValueError), + tenacity.retry_if_exception_message(message="fail"), + tenacity.retry_if_exception_message(match="fail.*"), + tenacity.retry_if_not_exception_message(message="fail"), + ] + for strategy in strategies: + restored = pickle.loads(pickle.dumps(strategy)) + assert type(restored) is type(strategy) + + def test_retrying_pickle_round_trip_works(self) -> None: + """A pickled-then-restored Retrying object retries correctly.""" + retrying = Retrying( + stop=tenacity.stop_after_attempt(3), + retry=tenacity.retry_if_exception_type(ValueError), + reraise=True, + ) + restored = pickle.loads(pickle.dumps(retrying)) + + calls = 0 + + def succeed_on_third() -> str: + nonlocal calls + calls += 1 + if calls < 3: + raise ValueError("not yet") + return "ok" + + result = restored(succeed_on_third) + assert result == "ok" + assert calls == 3 + + if __name__ == "__main__": unittest.main()