From 3cb6cc3a09efba91343fea9c4747aed54e4290a4 Mon Sep 17 00:00:00 2001 From: Rudy Nurhadi Date: Fri, 23 Feb 2024 18:43:11 +0700 Subject: [PATCH] Add asyncio.Lock mutex to _request With this change, if there are simultaneous requests to the same URL, the URL is fetched only once and the remaining requests are served from the cache --- aiohttp_client_cache/backends/base.py | 24 ++++++--- aiohttp_client_cache/session.py | 78 ++++++++++++++++----------- test/unit/test_session.py | 18 +++---- 3 files changed, 71 insertions(+), 49 deletions(-) diff --git a/aiohttp_client_cache/backends/base.py b/aiohttp_client_cache/backends/base.py index 44582bc9..a3eb4918 100644 --- a/aiohttp_client_cache/backends/base.py +++ b/aiohttp_client_cache/backends/base.py @@ -103,18 +103,18 @@ async def is_cacheable( logger.debug(f'Pre-cache checks for response from {response.url}: {cache_criteria}') # type: ignore return not any(cache_criteria.values()) - async def request( + async def create_cache_actions( self, - method: str, + key: str, url: StrOrURL, expire_after: ExpirationTime = None, refresh: bool = False, **kwargs, - ) -> Tuple[Optional[CachedResponse], CacheActions]: - """Fetch a cached response based on request info + ): + """Create cache actions based on request info Args: - method: HTTP method + key: key from create_key function url: Request URL expire_after: Expiration time to set only for this request; overrides ``CachedSession.expire_after``, and accepts all the same values. 
@@ -122,8 +122,7 @@ async def request( (e.g., a "soft refresh", like F5 in a browser) kwargs: All other request arguments """ - key = self.create_key(method, url, **kwargs) - actions = CacheActions.from_request( + return CacheActions.from_request( key, url=url, request_expire_after=expire_after, @@ -135,9 +134,18 @@ async def request( **kwargs, ) + async def request( + self, + actions: CacheActions, + ) -> Optional[CachedResponse]: + """Fetch a cached response based on cache actions + + Args: + actions: CacheActions from create_cache_actions function + """ # Skip reading from the cache, if specified by request headers response = None if actions.skip_read else await self.get_response(actions.key) - return response, actions + return response async def get_response(self, key: str) -> Optional[CachedResponse]: """Fetch a cached response based on a cache key""" diff --git a/aiohttp_client_cache/session.py b/aiohttp_client_cache/session.py index c25e6c51..8dbb4ae7 100644 --- a/aiohttp_client_cache/session.py +++ b/aiohttp_client_cache/session.py @@ -1,6 +1,7 @@ """Core functions for cache configuration""" import warnings -from contextlib import asynccontextmanager +from asyncio import Lock +from contextlib import asynccontextmanager, AsyncExitStack from logging import getLogger from typing import TYPE_CHECKING, Optional, Tuple @@ -32,6 +33,7 @@ def __init__( **kwargs, ): self.cache = cache or CacheBackend() + self._locks: dict[str, Lock] = {} # Pass along any valid kwargs for ClientSession (or custom session superclass) session_kwargs = get_valid_kwargs(super().__init__, {**kwargs, 'base_url': base_url}) @@ -48,40 +50,52 @@ async def _request( ) -> AnyResponse: """Wrapper around :py:meth:`.SessionClient._request` that adds caching""" # Attempt to fetch cached response - response, actions = await self.cache.request( - method, str_or_url, expire_after=expire_after, refresh=refresh, **kwargs + key = self.cache.create_key(method, str_or_url) + actions = 
self.cache.create_cache_actions( + key, str_or_url, expire_after=expire_after, refresh=refresh, **kwargs ) - def restore_cookies(r): - self.cookie_jar.update_cookies(r.cookies or {}, r.url) - for redirect in r.history: - self.cookie_jar.update_cookies(redirect.cookies or {}, redirect.url) - - if actions.revalidate and response: - from_cache, new_response = await self._refresh_cached_response( - method, str_or_url, response, actions, **kwargs - ) - if not from_cache: - return set_response_defaults(new_response) - else: - restore_cookies(new_response) - return new_response - - # Restore any cached cookies to the session - if response: - restore_cookies(response) - return response - # If the response was missing or expired, send and cache a new request - else: - if actions.skip_read: - logger.debug(f'Reading from cache was skipped; making request to {str_or_url}') + lock = AsyncExitStack() + if not actions.skip_read: + try: + lock = self._locks[key] + except KeyError: + self._locks[key] = Lock() + lock = self._locks[key] + + async with lock: + response = await self.cache.request(actions) + + def restore_cookies(r): + self.cookie_jar.update_cookies(r.cookies or {}, r.url) + for redirect in r.history: + self.cookie_jar.update_cookies(redirect.cookies or {}, redirect.url) + + if actions.revalidate and response: + from_cache, new_response = await self._refresh_cached_response( + method, str_or_url, response, actions, **kwargs + ) + if not from_cache: + return set_response_defaults(new_response) + else: + restore_cookies(new_response) + return new_response + + # Restore any cached cookies to the session + if response: + restore_cookies(response) + return response + # If the response was missing or expired, send and cache a new request else: - logger.debug(f'Cached response not found; making request to {str_or_url}') - new_response = await super()._request(method, str_or_url, **kwargs) - actions.update_from_response(new_response) - if await 
self.cache.is_cacheable(new_response, actions): - await self.cache.save_response(new_response, actions.key, actions.expires) - return set_response_defaults(new_response) + if actions.skip_read: + logger.debug(f'Reading from cache was skipped; making request to {str_or_url}') + else: + logger.debug(f'Cached response not found; making request to {str_or_url}') + new_response = await super()._request(method, str_or_url, **kwargs) + actions.update_from_response(new_response) + if await self.cache.is_cacheable(new_response, actions): + await self.cache.save_response(new_response, actions.key, actions.expires) + return set_response_defaults(new_response) async def _refresh_cached_response( self, diff --git a/test/unit/test_session.py b/test/unit/test_session.py index cc16e940..660870f8 100644 --- a/test/unit/test_session.py +++ b/test/unit/test_session.py @@ -57,7 +57,7 @@ async def test_session__init_posarg(): async def test_session__cache_hit(mock_request): cache = MagicMock(spec=CacheBackend) response = AsyncMock(is_expired=False, url=URL('https://test.com')) - cache.request.return_value = response, CacheActions() + cache.request.return_value = response async with CachedSession(cache=cache) as session: await session.get('http://test.url') @@ -68,7 +68,7 @@ async def test_session__cache_hit(mock_request): @patch.object(ClientSession, '_request') async def test_session__cache_expired_or_invalid(mock_request): cache = MagicMock(spec=CacheBackend) - cache.request.return_value = None, CacheActions() + cache.request.return_value = None async with CachedSession(cache=cache) as session: await session.get('http://test.url') @@ -79,7 +79,7 @@ async def test_session__cache_expired_or_invalid(mock_request): @patch.object(ClientSession, '_request') async def test_session__cache_miss(mock_request): cache = MagicMock(spec=CacheBackend) - cache.request.return_value = None, CacheActions() + cache.request.return_value = None async with CachedSession(cache=cache) as session: await 
session.get('http://test.url') @@ -90,7 +90,7 @@ async def test_session__cache_miss(mock_request): @patch.object(ClientSession, '_request') async def test_session__request_expire_after(mock_request): cache = MagicMock(spec=CacheBackend) - cache.request.return_value = None, CacheActions() + cache.request.return_value = None async with CachedSession(cache=cache) as session: await session.get('http://test.url', expire_after=10) @@ -102,7 +102,7 @@ async def test_session__request_expire_after(mock_request): @patch.object(ClientSession, '_request') async def test_session__default_attrs(mock_request): cache = MagicMock(spec=CacheBackend) - cache.request.return_value = None, CacheActions() + cache.request.return_value = None async with CachedSession(cache=cache) as session: response = await session.get('http://test.url') @@ -123,7 +123,7 @@ async def test_session__default_attrs(mock_request): async def test_all_param_types(mock_request, params) -> None: """Ensure that CachedSession.request() acceepts all the same parameter types as aiohttp""" cache = MagicMock(spec=CacheBackend) - cache.request.return_value = None, CacheActions() + cache.request.return_value = None async with CachedSession(cache=cache) as session: response = await session.get('http://test.url', params=params) @@ -139,7 +139,7 @@ async def test_session__cookies(mock_request): url=URL('https://test.com'), cookies=SimpleCookie({'test_cookie': 'value'}), ) - cache.request.return_value = response, CacheActions() + cache.request.return_value = response async with CachedSession(cache=cache) as session: session.cookie_jar.clear() @@ -154,7 +154,7 @@ async def test_session__empty_cookies(mock_request): """Previous versions didn't set cookies if they were empty. 
Just make sure it doesn't explode.""" cache = MagicMock(spec=CacheBackend) response = AsyncMock(is_expired=False, url=URL('https://test.com'), cookies=None) - cache.request.return_value = response, CacheActions() + cache.request.return_value = response async with CachedSession(cache=cache) as session: session.cookie_jar.clear() @@ -171,7 +171,7 @@ class CustomSession(CacheMixin, ClientSession): cache = MagicMock(spec=CacheBackend) response = AsyncMock(is_expired=False, url=URL('https://test.com')) - cache.request.return_value = response, CacheActions() + cache.request.return_value = response async with CustomSession(cache=cache) as session: await session.get('http://test.url')