Added a cache parameter to threadsafe_async_cache to allow providing an alternative cache object
aiudirog committed May 22, 2023
1 parent 46c03ba commit 02b44c6
Showing 3 changed files with 86 additions and 15 deletions.
99 changes: 84 additions & 15 deletions aiuti/asyncio.py
@@ -38,7 +38,7 @@
from weakref import WeakKeyDictionary as WeakKeyDict, finalize
from time import sleep
from typing import (
Any, AsyncIterable, Awaitable, Callable, Coroutine, Dict,
Any, AsyncIterable, Awaitable, Callable, Coroutine, Dict, MutableMapping,
Generic, Iterable, Iterator, List, Optional, Set,
Tuple, Type, TypeVar, Union,
overload, cast,
@@ -272,9 +272,31 @@ async def _queue_elements() -> None:
future.result()


_CacheMap = MutableMapping[Tuple[Any, ...], Any]
_AsyncFunc = TypeVar('_AsyncFunc', bound=Callable[..., Awaitable[Any]])


@overload
def threadsafe_async_cache(
func: Callable[..., Awaitable[T]],
) -> Callable[..., Awaitable[T]]:
func: None = None,
*,
cache: Optional[_CacheMap] = None,
) -> Callable[[_AsyncFunc], _AsyncFunc]: ...


@overload
def threadsafe_async_cache(
func: _AsyncFunc,
*,
cache: Optional[_CacheMap] = None,
) -> _AsyncFunc: ...


def threadsafe_async_cache(
func: Optional[_AsyncFunc] = None,
*,
cache: Optional[_CacheMap] = None,
) -> Union[Callable[[_AsyncFunc], _AsyncFunc], _AsyncFunc]:
"""
Simple thread-safe asynchronous cache decorator which ensures that
the decorated/wrapped function is only ever called once for a given
@@ -315,27 +337,74 @@ def threadsafe_async_cache(
>>> len([x for row in results for x in row])
100
Additionally, a custom mutable mapping can be provided as the cache
to further customize the implementation. For example, we can use
the `lru-dict <https://pypi.org/project/lru-dict/>`__ library to
limit the size of the cache:
>>> from lru import LRU
>>> cache = LRU(size=3)
>>> @threadsafe_async_cache(cache=cache)
... async def double(x: int) -> int:
... await aio.sleep(0.1)
... print("Doubling:", x)
... return x * 2
>>> run = aio.get_event_loop().run_until_complete
>>> assert run(double(1)) == 2
Doubling: 1
>>> assert run(double(1)) == 2 # Cached
>>> assert run(double(2)) == 4
Doubling: 2
>>> assert run(double(2)) == 4 # Cached
>>> assert run(double(3)) == 6
Doubling: 3
>>> assert run(double(3)) == 6 # Cached
>>> assert run(double(4)) == 8 # Pushes 1 out of the cache
Doubling: 4
>>> assert run(double(4)) == 8 # Cached
>>> assert run(double(1)) == 2 # No longer cached!
Doubling: 1
.. warning::
This cache has no max size and can very easily be the source of
memory leaks. I typically use this to wrap object methods on a
per-instance basis during initialization so that the cache only
lives as long as the object.
The default cache is a simple dictionary and as such has no max
size and can very easily be the source of memory leaks.
I typically use this to wrap object methods on a per-instance
basis during initialization so that the cache only lives as long
as the object.
"""
cache: Dict[Tuple[Any, ...], T] = {}
if func is None:
return partial( # type: ignore[return-value]
threadsafe_async_cache,
cache=cache,
)

# Avoid type narrowing issues related to:
# https://github.com/python/mypy/issues/13123
_cache: _CacheMap = cache if cache is not None else {}
_func: _AsyncFunc = func
del cache, func

# 1 lock per input key
locks: Dict[Tuple[Any, ...], Lock] = {}
# Ensure thread safety while creating locks
lock_making_lock = Lock()

@wraps(func)
async def _wrapper(*args: Any, **kwargs: Any) -> T:
@wraps(_func)
async def _wrapper(*args: Any, **kwargs: Any) -> Any:
# Get the key from the input arguments
key = args, frozenset(kwargs.items())

while True:
# Avoid locking during this first check for performance
try: # to get the value from the cache
return cache[key]
return _cache[key]
except KeyError:
pass

@@ -347,19 +416,19 @@ async def _wrapper(*args: Any, **kwargs: Any) -> T:
# Make sure another thread didn't finish caching and
# clean up the lock since the last check
try:
return cache[key]
return _cache[key]
except KeyError:
pass
lock = locks[key] = Lock()

if lock.acquire(blocking=False): # First to arrive, cache
try: # to run the function and get the result
result = await func(*args, **kwargs)
result = await _func(*args, **kwargs)
except BaseException:
raise # Don't cache errors, maybe timeouts or such
else: # Successfully got the result
with lock_making_lock:
cache[key] = result
_cache[key] = result
del locks[key] # Allow lock garbage collection
finally: # Ensure lock is always released
lock.release()
Expand All @@ -372,7 +441,7 @@ async def _wrapper(*args: Any, **kwargs: Any) -> T:
await aio.get_running_loop().run_in_executor(None, lock.acquire)
lock.release()

return _wrapper
return _wrapper # type: ignore[return-value]


_BufferFunc = Callable[[Set[T]], Awaitable[None]]
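With the two overloads above, the decorator now works both bare and with arguments. A minimal sketch of the two call forms, for illustration only (slow_square, slow_cube, and shared are hypothetical names, not part of this commit):

import asyncio as aio

from aiuti.asyncio import threadsafe_async_cache

# Bare form: backed by the default (unbounded) dict cache.
@threadsafe_async_cache
async def slow_square(x: int) -> int:
    await aio.sleep(0.1)
    return x * x

# Parameterized form: any MutableMapping keyed by
# (args, frozenset(kwargs.items())) can serve as the cache.
shared: dict = {}

@threadsafe_async_cache(cache=shared)
async def slow_cube(x: int) -> int:
    await aio.sleep(0.1)
    return x ** 3

run = aio.get_event_loop().run_until_complete
assert run(slow_square(3)) == 9
assert run(slow_cube(2)) == 8
assert len(shared) == 1  # one cached entry for slow_cube(2)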
1 change: 1 addition & 0 deletions requirements.txt
@@ -3,3 +3,4 @@ pytest
pytest-asyncio>=0.17
requests
httpx
lru-dict
1 change: 1 addition & 0 deletions setup.cfg
@@ -31,6 +31,7 @@ tests_require =
pytest-asyncio
requests
httpx
lru-dict

[options.packages.find]
exclude =
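The docstring's warning recommends wrapping methods on a per-instance basis so each cache is garbage collected along with its owner. A minimal sketch of that pattern, assuming a hypothetical Client class (not part of this commit):

import asyncio as aio

from aiuti.asyncio import threadsafe_async_cache

class Client:
    def __init__(self) -> None:
        # Rebind the bound method through the decorator so every
        # instance gets its own dict cache, which is freed when the
        # instance is garbage collected.
        self.fetch = threadsafe_async_cache(self.fetch)

    async def fetch(self, url: str) -> str:
        await aio.sleep(0.1)  # stand-in for real I/O
        return url.upper()

run = aio.get_event_loop().run_until_complete
client = Client()
assert run(client.fetch('x')) == 'X'
assert run(client.fetch('x')) == 'X'  # served from the per-instance cache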
