Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions tenacity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,14 @@ def copy(
enabled=_first_set(enabled, self.enabled),
)

def __getstate__(self) -> dict[str, t.Any]:
# Exclude threading.local which cannot be pickled
return {k: v for k, v in self.__dict__.items() if k != "_local"}

def __setstate__(self, state: dict[str, t.Any]) -> None:
self.__dict__.update(state)
self._local = threading.local()

def __str__(self) -> str:
return self._name if self._name is not None else "<unknown>"

Expand Down
52 changes: 24 additions & 28 deletions tenacity/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ def __init__(
| tuple[type[BaseException], ...] = Exception,
) -> None:
self.exception_types = exception_types
super().__init__(lambda e: isinstance(e, exception_types))
super().__init__(self._check)

def _check(self, e: BaseException) -> bool:
return isinstance(e, self.exception_types)


class retry_if_not_exception_type(retry_if_exception):
Expand All @@ -120,7 +123,10 @@ def __init__(
| tuple[type[BaseException], ...] = Exception,
) -> None:
self.exception_types = exception_types
super().__init__(lambda e: not isinstance(e, exception_types))
super().__init__(self._check)

def _check(self, e: BaseException) -> bool:
return not isinstance(e, self.exception_types)


class retry_unless_exception_type(retry_if_exception):
Expand All @@ -132,7 +138,10 @@ def __init__(
| tuple[type[BaseException], ...] = Exception,
) -> None:
self.exception_types = exception_types
super().__init__(lambda e: not isinstance(e, exception_types))
super().__init__(self._check)

def _check(self, e: BaseException) -> bool:
return not isinstance(e, self.exception_types)

def __call__(self, retry_state: "RetryCallState") -> bool:
if retry_state.outcome is None:
Expand Down Expand Up @@ -219,40 +228,27 @@ def __init__(
f"{self.__class__.__name__}() takes either 'message' or 'match', not both"
)

# set predicate
if message:

def message_fnc(exception: BaseException) -> bool:
return message == str(exception)

predicate = message_fnc
elif match:
prog = re.compile(match)

def match_fnc(exception: BaseException) -> bool:
return bool(prog.match(str(exception)))

predicate = match_fnc
else:
if not message and not match:
raise TypeError(
f"{self.__class__.__name__}() missing 1 required argument 'message' or 'match'"
)

super().__init__(predicate)
self.message = message
self.match = re.compile(match) if match else None
super().__init__(self._check)

def _check(self, exception: BaseException) -> bool:
if self.message:
return self.message == str(exception)
assert self.match is not None
return bool(self.match.match(str(exception)))


class retry_if_not_exception_message(retry_if_exception_message):
"""Retries until an exception message equals or matches."""

def __init__(
self,
message: str | None = None,
match: None | str | re.Pattern[str] = None,
) -> None:
super().__init__(message, match)
# invert predicate
if_predicate = self.predicate
self.predicate = lambda *args_, **kwargs_: not if_predicate(*args_, **kwargs_)
def _check(self, exception: BaseException) -> bool:
return not super()._check(exception)

def __call__(self, retry_state: "RetryCallState") -> bool:
if retry_state.outcome is None:
Expand Down
57 changes: 57 additions & 0 deletions tests/test_tenacity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import contextlib
import datetime
import logging
import pickle
import re
import time
import typing
Expand Down Expand Up @@ -2009,5 +2010,61 @@ def test_decorated_retry_with(self, mock_sleep: typing.Any) -> None:
assert mock_sleep.call_count == 1


class TestPickle(unittest.TestCase):
def test_retrying_picklable(self) -> None:
"""Retrying objects can be pickled for multiprocessing support."""
retrying = Retrying(stop=tenacity.stop_after_attempt(3))
pickled = pickle.dumps(retrying)
restored = pickle.loads(pickled)
assert isinstance(restored, Retrying)
assert isinstance(restored.stop, tenacity.stop_after_attempt)

def test_retrying_picklable_after_run(self) -> None:
"""Retrying objects can be pickled even after being used."""
retrying = Retrying(stop=tenacity.stop_after_attempt(3))
# Access statistics to populate _local
_ = retrying.statistics
pickled = pickle.dumps(retrying)
restored = pickle.loads(pickled)
assert isinstance(restored, Retrying)
# Statistics should be reset on the restored object
assert restored.statistics == {}

def test_retry_strategies_picklable(self) -> None:
"""All built-in retry strategies can be pickled."""
strategies = [
tenacity.retry_if_exception_type(ValueError),
tenacity.retry_if_not_exception_type(ValueError),
tenacity.retry_if_exception_message(message="fail"),
tenacity.retry_if_exception_message(match="fail.*"),
tenacity.retry_if_not_exception_message(message="fail"),
]
for strategy in strategies:
restored = pickle.loads(pickle.dumps(strategy))
assert type(restored) is type(strategy)

def test_retrying_pickle_round_trip_works(self) -> None:
"""A pickled-then-restored Retrying object retries correctly."""
retrying = Retrying(
stop=tenacity.stop_after_attempt(3),
retry=tenacity.retry_if_exception_type(ValueError),
reraise=True,
)
restored = pickle.loads(pickle.dumps(retrying))

calls = 0

def succeed_on_third() -> str:
nonlocal calls
calls += 1
if calls < 3:
raise ValueError("not yet")
return "ok"

result = restored(succeed_on_third)
assert result == "ok"
assert calls == 3


if __name__ == "__main__":
unittest.main()
Loading