diff --git a/aiohttp_client_cache/backends/base.py b/aiohttp_client_cache/backends/base.py index 44582bc9..a3eb4918 100644 --- a/aiohttp_client_cache/backends/base.py +++ b/aiohttp_client_cache/backends/base.py @@ -103,18 +103,18 @@ async def is_cacheable( logger.debug(f'Pre-cache checks for response from {response.url}: {cache_criteria}') # type: ignore return not any(cache_criteria.values()) - async def request( + async def create_cache_actions( self, - method: str, + key: str, url: StrOrURL, expire_after: ExpirationTime = None, refresh: bool = False, **kwargs, - ) -> Tuple[Optional[CachedResponse], CacheActions]: - """Fetch a cached response based on request info + ): + """Create cache actions based on request info Args: - method: HTTP method + key: key from create_key function url: Request URL expire_after: Expiration time to set only for this request; overrides ``CachedSession.expire_after``, and accepts all the same values. @@ -122,8 +122,7 @@ async def request( (e.g., a "soft refresh", like F5 in a browser) kwargs: All other request arguments """ - key = self.create_key(method, url, **kwargs) - actions = CacheActions.from_request( + return CacheActions.from_request( key, url=url, request_expire_after=expire_after, @@ -135,9 +134,18 @@ async def request( **kwargs, ) + async def request( + self, + actions: CacheActions, + ) -> Optional[CachedResponse]: + """Fetch a cached response based on cache actions + + Args: + actions: CacheActions from create_cache_actions function + """ # Skip reading from the cache, if specified by request headers response = None if actions.skip_read else await self.get_response(actions.key) - return response, actions + return response async def get_response(self, key: str) -> Optional[CachedResponse]: """Fetch a cached response based on a cache key""" diff --git a/aiohttp_client_cache/session.py b/aiohttp_client_cache/session.py index c25e6c51..8dbb4ae7 100644 --- a/aiohttp_client_cache/session.py +++ b/aiohttp_client_cache/session.py @@ -1,6 +1,7 @@ """Core functions for cache configuration""" import warnings -from contextlib import asynccontextmanager +from asyncio import Lock +from contextlib import asynccontextmanager, AsyncExitStack from logging import getLogger from typing import TYPE_CHECKING, Optional, Tuple @@ -32,6 +33,7 @@ def __init__( **kwargs, ): self.cache = cache or CacheBackend() + self._locks: dict[str, Lock] = {} # Pass along any valid kwargs for ClientSession (or custom session superclass) session_kwargs = get_valid_kwargs(super().__init__, {**kwargs, 'base_url': base_url}) @@ -48,40 +50,52 @@ async def _request( ) -> AnyResponse: """Wrapper around :py:meth:`.SessionClient._request` that adds caching""" # Attempt to fetch cached response - response, actions = await self.cache.request( - method, str_or_url, expire_after=expire_after, refresh=refresh, **kwargs + key = self.cache.create_key(method, str_or_url) + actions = self.cache.create_cache_actions( + key, str_or_url, expire_after=expire_after, refresh=refresh, **kwargs ) - def restore_cookies(r): - self.cookie_jar.update_cookies(r.cookies or {}, r.url) - for redirect in r.history: - self.cookie_jar.update_cookies(redirect.cookies or {}, redirect.url) - - if actions.revalidate and response: - from_cache, new_response = await self._refresh_cached_response( - method, str_or_url, response, actions, **kwargs - ) - if not from_cache: - return set_response_defaults(new_response) - else: - restore_cookies(new_response) - return new_response - - # Restore any cached cookies to the session - if response: - restore_cookies(response) - return response - # If the response was missing or expired, send and cache a new request - else: - if actions.skip_read: - logger.debug(f'Reading from cache was skipped; making request to {str_or_url}') + lock = AsyncExitStack() + if not actions.skip_read: + try: + lock = self._locks[key] + except KeyError: + self._locks[key] = Lock() + lock = self._locks[key] + + async with lock: + response = await self.cache.request(actions) + + def restore_cookies(r): + self.cookie_jar.update_cookies(r.cookies or {}, r.url) + for redirect in r.history: + self.cookie_jar.update_cookies(redirect.cookies or {}, redirect.url) + + if actions.revalidate and response: + from_cache, new_response = await self._refresh_cached_response( + method, str_or_url, response, actions, **kwargs + ) + if not from_cache: + return set_response_defaults(new_response) + else: + restore_cookies(new_response) + return new_response + + # Restore any cached cookies to the session + if response: + restore_cookies(response) + return response + # If the response was missing or expired, send and cache a new request else: - logger.debug(f'Cached response not found; making request to {str_or_url}') - new_response = await super()._request(method, str_or_url, **kwargs) - actions.update_from_response(new_response) - if await self.cache.is_cacheable(new_response, actions): - await self.cache.save_response(new_response, actions.key, actions.expires) - return set_response_defaults(new_response) + if actions.skip_read: + logger.debug(f'Reading from cache was skipped; making request to {str_or_url}') + else: + logger.debug(f'Cached response not found; making request to {str_or_url}') + new_response = await super()._request(method, str_or_url, **kwargs) + actions.update_from_response(new_response) + if await self.cache.is_cacheable(new_response, actions): + await self.cache.save_response(new_response, actions.key, actions.expires) + return set_response_defaults(new_response) async def _refresh_cached_response( self, diff --git a/test/unit/test_session.py b/test/unit/test_session.py index cc16e940..660870f8 100644 --- a/test/unit/test_session.py +++ b/test/unit/test_session.py @@ -57,7 +57,7 @@ async def test_session__init_posarg(): async def test_session__cache_hit(mock_request): cache = MagicMock(spec=CacheBackend) response = AsyncMock(is_expired=False, url=URL('https://test.com')) - cache.request.return_value = response, CacheActions() + cache.request.return_value = response async with CachedSession(cache=cache) as session: await session.get('http://test.url') @@ -68,7 +68,7 @@ async def test_session__cache_hit(mock_request): @patch.object(ClientSession, '_request') async def test_session__cache_expired_or_invalid(mock_request): cache = MagicMock(spec=CacheBackend) - cache.request.return_value = None, CacheActions() + cache.request.return_value = None async with CachedSession(cache=cache) as session: await session.get('http://test.url') @@ -79,7 +79,7 @@ async def test_session__cache_expired_or_invalid(mock_request): @patch.object(ClientSession, '_request') async def test_session__cache_miss(mock_request): cache = MagicMock(spec=CacheBackend) - cache.request.return_value = None, CacheActions() + cache.request.return_value = None async with CachedSession(cache=cache) as session: await session.get('http://test.url') @@ -90,7 +90,7 @@ async def test_session__cache_miss(mock_request): @patch.object(ClientSession, '_request') async def test_session__request_expire_after(mock_request): cache = MagicMock(spec=CacheBackend) - cache.request.return_value = None, CacheActions() + cache.request.return_value = None async with CachedSession(cache=cache) as session: await session.get('http://test.url', expire_after=10) @@ -102,7 +102,7 @@ async def test_session__request_expire_after(mock_request): @patch.object(ClientSession, '_request') async def test_session__default_attrs(mock_request): cache = MagicMock(spec=CacheBackend) - cache.request.return_value = None, CacheActions() + cache.request.return_value = None async with CachedSession(cache=cache) as session: response = await session.get('http://test.url') @@ -123,7 +123,7 @@ async def test_session__default_attrs(mock_request): async def test_all_param_types(mock_request, params) -> None: """Ensure that CachedSession.request() acceepts all the same parameter types as aiohttp""" cache = MagicMock(spec=CacheBackend) - cache.request.return_value = None, CacheActions() + cache.request.return_value = None async with CachedSession(cache=cache) as session: response = await session.get('http://test.url', params=params) @@ -139,7 +139,7 @@ async def test_session__cookies(mock_request): url=URL('https://test.com'), cookies=SimpleCookie({'test_cookie': 'value'}), ) - cache.request.return_value = response, CacheActions() + cache.request.return_value = response async with CachedSession(cache=cache) as session: session.cookie_jar.clear() @@ -154,7 +154,7 @@ async def test_session__empty_cookies(mock_request): """Previous versions didn't set cookies if they were empty. Just make sure it doesn't explode.""" cache = MagicMock(spec=CacheBackend) response = AsyncMock(is_expired=False, url=URL('https://test.com'), cookies=None) - cache.request.return_value = response, CacheActions() + cache.request.return_value = response async with CachedSession(cache=cache) as session: session.cookie_jar.clear() @@ -171,7 +171,7 @@ class CustomSession(CacheMixin, ClientSession): cache = MagicMock(spec=CacheBackend) response = AsyncMock(is_expired=False, url=URL('https://test.com')) - cache.request.return_value = response, CacheActions() + cache.request.return_value = response async with CustomSession(cache=cache) as session: await session.get('http://test.url')