diff --git a/async_retrying.py b/async_retrying.py index 0adcca9..910a829 100644 --- a/async_retrying.py +++ b/async_retrying.py @@ -49,7 +49,7 @@ def callback(attempt, exc, args, kwargs, delay=None, *, loop): callback.delay = 0.5 -def retry( +def factory( fn=None, *, attempts=3, @@ -193,3 +193,29 @@ def wrapped(*fn_args, **fn_kwargs): return wrapper(fn) raise NotImplementedError + + +class retry: + def __init__(self, fn=None, *args, **kwargs): + self._fn = fn + self._wrapper = factory(self._fn, *args, **kwargs) + + def __enter__(self): + return self._wrapper + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + @asyncio.coroutine + def __aenter__(self): + return self._wrapper + + @asyncio.coroutine + def __aexit__(self, exc_type, exc_val, exc_tb): + return self.__exit__ + + def __call__(self, fn=None): + if not self._fn and callable(fn): + return self._wrapper(fn) + + return self._wrapper() diff --git a/tests/test_context_manager.py b/tests/test_context_manager.py new file mode 100644 index 0000000..5c8e6f3 --- /dev/null +++ b/tests/test_context_manager.py @@ -0,0 +1,29 @@ +import asyncio +from functools import partial + +import pytest + +from async_retrying import retry + +@pytest.mark.run_loop +@asyncio.coroutine +def test_context_manager(loop): + counter = 0 + + @asyncio.coroutine + def fn(): + nonlocal counter + + counter += 1 + + if counter == 1: + raise RuntimeError + + return counter + + with retry(fn, loop=loop) as context: + yield from context() + + ret = yield from partial(fn)() + + assert ret == counter