diff --git a/src/cachetools_async/decorators.py b/src/cachetools_async/decorators.py index 281db95..6c29438 100644 --- a/src/cachetools_async/decorators.py +++ b/src/cachetools_async/decorators.py @@ -1,4 +1,4 @@ -from asyncio import Future, Task, get_event_loop, shield +from asyncio import Future, Task, get_running_loop, shield from functools import update_wrapper from inspect import iscoroutinefunction from typing import ( @@ -72,11 +72,17 @@ async def wrapper(*args, **kwargs): if future.exception() is None: return future.result() + # Evict failed futures so they don't occupy cache slots + try: + del cache[k] + except KeyError: + pass + coro = fn(*args, **kwargs) - loop = get_event_loop() + loop = get_running_loop() - # Crete a task that tracks the coroutine execution + # Create a task that tracks the coroutine execution task = loop.create_task(coro) # Create a future and then tie the future and task together @@ -139,11 +145,17 @@ async def wrapper(self, *args, **kwargs): if future.exception() is None: return future.result() + # Evict failed futures so they don't occupy cache slots + try: + del c[k] + except KeyError: + pass + coro = method(self, *args, **kwargs) - loop = get_event_loop() + loop = get_running_loop() - # Crete a task that tracks the coroutine execution + # Create a task that tracks the coroutine execution task = loop.create_task(coro) # Create a future and then tie the future and task together diff --git a/tests/test_cached.py b/tests/test_cached.py index dbebc83..4a14ef2 100644 --- a/tests/test_cached.py +++ b/tests/test_cached.py @@ -126,6 +126,28 @@ async def test_does_not_cache_exceptions(self): assert await decorated_fn() == "example" + async def test_failed_futures_are_evicted_from_cache(self): + cache = {} + mock = AsyncMock() + + mock.side_effect = [ + TypeError(), + "example", + ] + + decorated_fn = cachetools_async.cached(cache)(mock) + + with pytest.raises(TypeError): + await decorated_fn("foo") + + # The failed future should have been evicted on the next call + await decorated_fn("foo") + assert len(cache) == 1 + + # The successful result should now be cached + future = cache[list(cache.keys())[0]] + assert future.result() == "example" + async def test_cache_clear_evicts_everything(self): mock = AsyncMock() diff --git a/tests/test_cachedmethod.py b/tests/test_cachedmethod.py index 87b8886..30ed515 100644 --- a/tests/test_cachedmethod.py +++ b/tests/test_cachedmethod.py @@ -144,6 +144,30 @@ async def test_does_not_cache_exceptions(self, mock_resolver): assert await decorated_fn(mock) == "example" + async def test_failed_futures_are_evicted_from_cache(self): + cache = {} + mock_resolver = MagicMock() + mock_resolver.return_value = cache + + mock = AsyncMock() + mock.func.side_effect = [ + TypeError(), + "example", + ] + + decorated_fn = cachetools_async.cachedmethod(mock_resolver)(mock.func) + + with pytest.raises(TypeError): + await decorated_fn(mock, "foo") + + # The failed future should have been evicted on the next call + await decorated_fn(mock, "foo") + assert len(cache) == 1 + + # The successful result should now be cached + future = cache[list(cache.keys())[0]] + assert future.result() == "example" + async def test_cache_clear_evicts_everything(self, mock_resolver): mock = AsyncMock() mock.return_value = "bar"