Add asyncio.Lock mutex to _request
With this change, if simultaneous requests are made to the same URL, the URL is only fetched once and the remaining requests are served from the cache.
rudcode committed Feb 23, 2024
1 parent 3aa2e5e commit 3cb6cc3
Showing 3 changed files with 71 additions and 49 deletions.
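Before the diff itself, a minimal usage sketch (not part of the commit) of the behaviour described above, assuming CachedSession and CacheBackend are importable from the package root and that the library's from_cache flag is set on returned responses; the URL is a placeholder, and the exact counts depend on whether the response is cacheable under the backend defaults:

import asyncio

from aiohttp_client_cache import CachedSession, CacheBackend


async def main():
    # Ten concurrent GETs for the same (initially uncached) URL; with the per-key
    # lock added in this commit, only the first coroutine hits the network and the
    # rest wait on the lock, then read the freshly cached response.
    async with CachedSession(cache=CacheBackend()) as session:
        responses = await asyncio.gather(
            *(session.get('https://example.com/resource') for _ in range(10))
        )
        network_fetches = sum(not r.from_cache for r in responses)
        print(f'{len(responses)} responses, {network_fetches} network fetch(es)')


asyncio.run(main())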
24 changes: 16 additions & 8 deletions aiohttp_client_cache/backends/base.py
@@ -103,27 +103,26 @@ async def is_cacheable(
logger.debug(f'Pre-cache checks for response from {response.url}: {cache_criteria}') # type: ignore
return not any(cache_criteria.values())

async def request(
async def create_cache_actions(
self,
method: str,
key: str,
url: StrOrURL,
expire_after: ExpirationTime = None,
refresh: bool = False,
**kwargs,
) -> Tuple[Optional[CachedResponse], CacheActions]:
"""Fetch a cached response based on request info
):
"""Create cache actions based on request info
Args:
method: HTTP method
key: key from create_key function
url: Request URL
expire_after: Expiration time to set only for this request; overrides
``CachedSession.expire_after``, and accepts all the same values.
refresh: Revalidate with the server before using a cached response, and refresh if needed
(e.g., a "soft refresh", like F5 in a browser)
kwargs: All other request arguments
"""
key = self.create_key(method, url, **kwargs)
actions = CacheActions.from_request(
return CacheActions.from_request(
key,
url=url,
request_expire_after=expire_after,
@@ -135,9 +134,18 @@ async def request(
**kwargs,
)

async def request(
self,
actions: CacheActions,
) -> Optional[CachedResponse]:
"""Fetch a cached response based on cache actions
Args:
actions: CacheActions from create_cache_actions function
"""
# Skip reading from the cache, if specified by request headers
response = None if actions.skip_read else await self.get_response(actions.key)
return response, actions
return response

async def get_response(self, key: str) -> Optional[CachedResponse]:
"""Fetch a cached response based on a cache key"""
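For reference, a short sketch (not from the diff) of how a caller might drive the split backend API shown above; the import path for CacheBackend and the helper name lookup are assumptions, and create_cache_actions is awaited here because this commit declares it as a coroutine:

from aiohttp_client_cache.backends import CacheBackend


async def lookup(backend: CacheBackend, method: str, url: str):
    """Illustrative helper (names are placeholders): drive the split cache API."""
    key = backend.create_key(method, url)              # stable key for this request
    actions = await backend.create_cache_actions(      # declared async in this commit
        key, url, expire_after=None, refresh=False
    )
    cached = await backend.request(actions)            # CachedResponse, or None on miss/skip
    return cached, actions                             # None -> caller makes a real request,
                                                       # then backend.save_response(...)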
78 changes: 46 additions & 32 deletions aiohttp_client_cache/session.py
@@ -1,6 +1,7 @@
"""Core functions for cache configuration"""
import warnings
from contextlib import asynccontextmanager
from asyncio import Lock
from contextlib import asynccontextmanager, AsyncExitStack
from logging import getLogger
from typing import TYPE_CHECKING, Optional, Tuple

@@ -32,6 +33,7 @@ def __init__(
**kwargs,
):
self.cache = cache or CacheBackend()
self._locks: dict[str, Lock] = {}

# Pass along any valid kwargs for ClientSession (or custom session superclass)
session_kwargs = get_valid_kwargs(super().__init__, {**kwargs, 'base_url': base_url})
@@ -48,40 +50,52 @@ async def _request(
) -> AnyResponse:
"""Wrapper around :py:meth:`.SessionClient._request` that adds caching"""
# Attempt to fetch cached response
response, actions = await self.cache.request(
method, str_or_url, expire_after=expire_after, refresh=refresh, **kwargs
key = self.cache.create_key(method, str_or_url)
actions = self.cache.create_cache_actions(
key, str_or_url, expire_after=expire_after, refresh=refresh, **kwargs
)

def restore_cookies(r):
self.cookie_jar.update_cookies(r.cookies or {}, r.url)
for redirect in r.history:
self.cookie_jar.update_cookies(redirect.cookies or {}, redirect.url)

if actions.revalidate and response:
from_cache, new_response = await self._refresh_cached_response(
method, str_or_url, response, actions, **kwargs
)
if not from_cache:
return set_response_defaults(new_response)
else:
restore_cookies(new_response)
return new_response

# Restore any cached cookies to the session
if response:
restore_cookies(response)
return response
# If the response was missing or expired, send and cache a new request
else:
if actions.skip_read:
logger.debug(f'Reading from cache was skipped; making request to {str_or_url}')
lock = AsyncExitStack()
if not actions.skip_read:
try:
lock = self._locks[key]
except KeyError:
self._locks[key] = Lock()
lock = self._locks[key]

async with lock:
response = await self.cache.request(actions)

def restore_cookies(r):
self.cookie_jar.update_cookies(r.cookies or {}, r.url)
for redirect in r.history:
self.cookie_jar.update_cookies(redirect.cookies or {}, redirect.url)

if actions.revalidate and response:
from_cache, new_response = await self._refresh_cached_response(
method, str_or_url, response, actions, **kwargs
)
if not from_cache:
return set_response_defaults(new_response)
else:
restore_cookies(new_response)
return new_response

# Restore any cached cookies to the session
if response:
restore_cookies(response)
return response
# If the response was missing or expired, send and cache a new request
else:
logger.debug(f'Cached response not found; making request to {str_or_url}')
new_response = await super()._request(method, str_or_url, **kwargs)
actions.update_from_response(new_response)
if await self.cache.is_cacheable(new_response, actions):
await self.cache.save_response(new_response, actions.key, actions.expires)
return set_response_defaults(new_response)
if actions.skip_read:
logger.debug(f'Reading from cache was skipped; making request to {str_or_url}')
else:
logger.debug(f'Cached response not found; making request to {str_or_url}')
new_response = await super()._request(method, str_or_url, **kwargs)
actions.update_from_response(new_response)
if await self.cache.is_cacheable(new_response, actions):
await self.cache.save_response(new_response, actions.key, actions.expires)
return set_response_defaults(new_response)

async def _refresh_cached_response(
self,
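The locking in _request above keeps one asyncio.Lock per cache key and substitutes AsyncExitStack as a no-op context manager when the cache read is skipped. Below is a standalone sketch of that pattern (class name, key string, and timing are made up; dict.setdefault is used as an equivalent of the try/except KeyError lookup in the diff):

import asyncio
from contextlib import AsyncExitStack


class PerKeyLocks:
    """Standalone sketch of the per-key locking used in _request above."""

    def __init__(self):
        self._locks: dict[str, asyncio.Lock] = {}

    def get(self, key: str, skip: bool = False):
        # When locking isn't wanted, AsyncExitStack() works as a no-op async context manager
        if skip:
            return AsyncExitStack()
        # Equivalent to the try/except KeyError in the diff: no await between the
        # lookup and the insert, so each key gets exactly one Lock per event loop
        return self._locks.setdefault(key, asyncio.Lock())


async def demo():
    locks = PerKeyLocks()
    fetches = 0

    async def fetch_once(key: str):
        nonlocal fetches
        async with locks.get(key):
            # "Check cache, else fetch": only the first coroutine per key pays for
            # the slow fetch; the others wait on the lock and reuse the result
            if fetches == 0:
                await asyncio.sleep(0.01)  # stand-in for the network request
                fetches += 1

    await asyncio.gather(*(fetch_once('GET:https://example.com') for _ in range(5)))
    print('network fetches:', fetches)  # -> 1


asyncio.run(demo())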
18 changes: 9 additions & 9 deletions test/unit/test_session.py
@@ -57,7 +57,7 @@ async def test_session__init_posarg():
async def test_session__cache_hit(mock_request):
cache = MagicMock(spec=CacheBackend)
response = AsyncMock(is_expired=False, url=URL('https://test.com'))
cache.request.return_value = response, CacheActions()
cache.request.return_value = response

async with CachedSession(cache=cache) as session:
await session.get('http://test.url')
@@ -68,7 +68,7 @@ async def test_session__cache_hit(mock_request):
@patch.object(ClientSession, '_request')
async def test_session__cache_expired_or_invalid(mock_request):
cache = MagicMock(spec=CacheBackend)
cache.request.return_value = None, CacheActions()
cache.request.return_value = None

async with CachedSession(cache=cache) as session:
await session.get('http://test.url')
@@ -79,7 +79,7 @@ async def test_session__cache_expired_or_invalid(mock_request):
@patch.object(ClientSession, '_request')
async def test_session__cache_miss(mock_request):
cache = MagicMock(spec=CacheBackend)
cache.request.return_value = None, CacheActions()
cache.request.return_value = None

async with CachedSession(cache=cache) as session:
await session.get('http://test.url')
@@ -90,7 +90,7 @@ async def test_session__cache_miss(mock_request):
@patch.object(ClientSession, '_request')
async def test_session__request_expire_after(mock_request):
cache = MagicMock(spec=CacheBackend)
cache.request.return_value = None, CacheActions()
cache.request.return_value = None

async with CachedSession(cache=cache) as session:
await session.get('http://test.url', expire_after=10)
@@ -102,7 +102,7 @@ async def test_session__request_expire_after(mock_request):
@patch.object(ClientSession, '_request')
async def test_session__default_attrs(mock_request):
cache = MagicMock(spec=CacheBackend)
cache.request.return_value = None, CacheActions()
cache.request.return_value = None

async with CachedSession(cache=cache) as session:
response = await session.get('http://test.url')
@@ -123,7 +123,7 @@ async def test_session__default_attrs(mock_request):
async def test_all_param_types(mock_request, params) -> None:
"""Ensure that CachedSession.request() acceepts all the same parameter types as aiohttp"""
cache = MagicMock(spec=CacheBackend)
cache.request.return_value = None, CacheActions()
cache.request.return_value = None

async with CachedSession(cache=cache) as session:
response = await session.get('http://test.url', params=params)
@@ -139,7 +139,7 @@ async def test_session__cookies(mock_request):
url=URL('https://test.com'),
cookies=SimpleCookie({'test_cookie': 'value'}),
)
cache.request.return_value = response, CacheActions()
cache.request.return_value = response

async with CachedSession(cache=cache) as session:
session.cookie_jar.clear()
@@ -154,7 +154,7 @@ async def test_session__empty_cookies(mock_request):
"""Previous versions didn't set cookies if they were empty. Just make sure it doesn't explode."""
cache = MagicMock(spec=CacheBackend)
response = AsyncMock(is_expired=False, url=URL('https://test.com'), cookies=None)
cache.request.return_value = response, CacheActions()
cache.request.return_value = response

async with CachedSession(cache=cache) as session:
session.cookie_jar.clear()
@@ -171,7 +171,7 @@ class CustomSession(CacheMixin, ClientSession):

cache = MagicMock(spec=CacheBackend)
response = AsyncMock(is_expired=False, url=URL('https://test.com'))
cache.request.return_value = response, CacheActions()
cache.request.return_value = response

async with CustomSession(cache=cache) as session:
await session.get('http://test.url')
