
Commit

Refactor threadsafe_async_cache() to guard against unfinalized coroutines
aiudirog committed Oct 17, 2023
1 parent 02fad19 commit 6f53be8
Showing 2 changed files with 71 additions and 41 deletions.
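
For orientation, threadsafe_async_cache memoizes an async function once and shares the result across tasks, including callers running on different event loops in different threads; the diff below hardens the path where one loop waits on another loop that is doing the caching. A minimal usage sketch, modelled on the test changed in this commit (slow_value and worker are illustrative names, not part of the commit; the import path assumes the decorator is exported from aiuti.asyncio as in this repository):

    import asyncio as aio
    from concurrent.futures import ThreadPoolExecutor

    from aiuti.asyncio import threadsafe_async_cache

    @threadsafe_async_cache
    async def slow_value() -> int:
        # Stand-in for expensive work that should only ever run once
        await aio.sleep(0.1)
        return 42

    def worker() -> int:
        # Each thread drives its own event loop but shares the cache
        return aio.run(slow_value())

    with ThreadPoolExecutor(4) as pool:
        futures = [pool.submit(worker) for _ in range(8)]
        assert all(f.result() == 42 for f in futures)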
aiuti/asyncio.py (103 changes: 65 additions & 38 deletions)
@@ -30,6 +30,8 @@
 from asyncio import (
     QueueEmpty as AioQueueEmpty,
     TimeoutError as AioTimeoutError,
+    AbstractEventLoop as Loop,
+    run_coroutine_threadsafe as run_coro_ts,
 )
 from itertools import islice
 from threading import Lock
@@ -52,8 +54,6 @@
 A_contra = TypeVar('A_contra', contravariant=True)
 R_co = TypeVar('R_co', covariant=True)

-Loop = aio.AbstractEventLoop
-
 logger = logging.getLogger(__name__)

 _DONE = object()
@@ -395,6 +395,8 @@ async def _wrapper(*args: Any, **kwargs: Any) -> Any:
         # Get the key from the input arguments
         key = args, frozenset(kwargs.items())

+        running_loop = aio.get_running_loop()
+
         while True:
             # Avoid locking during this first check for performance
             try:  # to get the value from the cache
@@ -405,52 +407,77 @@
             # Need to calculate and cache the value once

             with event_making_lock:
-                try:  # verify nothing cached while waiting for lock
-                    return _cache[key]
-                except KeyError:
-                    pass
-
                 try:
                     # Try to get the loop + event of the loop currently
                     # caching the value
-                    loop, event = events[key]
+                    caching_loop, event = events[key]
+                    if (caching_loop.is_closed()
+                            or not caching_loop.is_running()):
+                        raise KeyError  # Invalidate loop
                 except KeyError:
+                    try:  # verify nothing cached while waiting for lock
+                        return _cache[key]
+                    except KeyError:
+                        pass
                     # No existing event -> this task is going to cache
                     # the value and provide an event for others to wait
-                    loop = aio.get_running_loop()
+                    caching_loop = aio.get_running_loop()
                     event = aio.Event()
-                    events[key] = loop, event
-                    waiting = False
+                    events[key] = caching_loop, event
+                    do_caching = True
                 else:
-                    waiting = True  # Need to wait for other loop
+                    do_caching = False  # Need to wait for other loop

-            if waiting:  # Wait for other task, maybe across threads
-                if aio.get_running_loop() is loop:
-                    await event.wait()
+            if do_caching:  # No other task to wait for, cache the value
+                try:
+                    result = await _func(*args, **kwargs)
+                except Exception:
+                    raise  # Bubble any errors without caching
                 else:
+                    _cache[key] = result  # Cache for other tasks
+                finally:
+                    with event_making_lock:
+                        # Wake up any waiting tasks
+                        event.set()
+                        # Allow garbage collection and/or another loop
+                        # to take over caching if this failed
+                        del events[key]
+                return result
+
+            # Need to wait for another task, possibly across threads
+
+            wait_event: Awaitable[bool] = event.wait()
+            if running_loop is not caching_loop:
+                try:
+                    wait_fut = run_coro_ts(wait_event, caching_loop)
+                except RuntimeError:  # caching loop most likely closed
+                    continue  # loop around and try again
+                wait_event = aio.wrap_future(wait_fut)
+
+            # Wrap anything waiting for the event in a long timeout just
+            # to ensure nothing hangs completely if the original task is
+            # lost and somehow never sets the event
+            waiter = aio.create_task(aio.wait_for(wait_event, 60))
+
+            # wait_for will swallow CancelledErrors if the wrapped task
+            # is already finished. This can cause confusion and missed
+            # timeouts when there are multiple nested wait_fors.
+            # Shielding and handling the cancellation directly avoids
+            # this confusion by preventing this wait_for from seeing
+            # the outer CancelledError directly.
+            try:
+                await aio.shield(waiter)
+            except aio.TimeoutError:  # Possible original task lost?
+                pass  # Need to loop around and check
+            except (Exception, aio.CancelledError):
+                if not waiter.done():
+                    waiter.cancel()
                     try:
-                        await aio.wrap_future(
-                            aio.run_coroutine_threadsafe(event.wait(), loop),
-                        )
-                    except RuntimeError:  # Target loop most likely closed
+                        await waiter
+                    except aio.CancelledError:
                         pass
-                continue
-
-            # First to arrive, cache the value
-            try:
-                result = await _func(*args, **kwargs)
-            except BaseException:
-                raise  # Bubble any encountered errors without caching
-            else:
-                _cache[key] = result  # Cache for other tasks
-            finally:
-                with event_making_lock:
-                    # Wake up any waiting tasks
-                    event.set()
-                    # Allow garbage collection and/or another loop to
-                    # take over caching if this failed
-                    del events[key]
-
-            return result
+                raise

     return _wrapper  # type: ignore[return-value]

@@ -1241,8 +1268,8 @@ async def run_aw_threadsafe(aw: Awaitable[T], loop: Loop) -> T:
     This does not handle event loop conflicts.
     Use :func:`ensure_aw` for that.
     """
-    fut = aio.run_coroutine_threadsafe(_aw_to_coro(aw), loop)
-    return await aio.wrap_future(fut)
+    coro = aw if aio.iscoroutine(aw) else _aw_to_coro(aw)
+    return await aio.wrap_future(run_coro_ts(coro, loop))


 async def _aw_to_coro(aw: Awaitable[T]) -> T:
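The cross-loop wait added above relies on a standard asyncio pattern: the coroutine that waits on the Event is scheduled onto the loop that owns it with run_coroutine_threadsafe, and the local loop awaits the returned concurrent.futures.Future via wrap_future. A standalone sketch of just that pattern, independent of aiuti and with hypothetical names (start_owner_loop, wait_from_other_loop):

    import asyncio as aio
    import threading

    def start_owner_loop() -> aio.AbstractEventLoop:
        # Background loop standing in for the "caching" loop in the diff above
        loop = aio.new_event_loop()
        threading.Thread(target=loop.run_forever, daemon=True).start()
        return loop

    async def wait_from_other_loop(event: aio.Event,
                                   owner: aio.AbstractEventLoop) -> None:
        # asyncio primitives are not thread-safe, so the wait itself runs on
        # the owner loop; this loop only awaits the wrapped future.
        fut = aio.run_coroutine_threadsafe(event.wait(), owner)
        await aio.wrap_future(fut)

Setting the event would likewise have to happen on the owner loop, for example via owner.call_soon_threadsafe(event.set).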
tests/test_threadsafe_async_cache.py (9 changes: 6 additions & 3 deletions)
@@ -12,11 +12,14 @@ def test_canceling() -> None:

     @threadsafe_async_cache
     async def func() -> None:
-        await aio.sleep(0.1)
+        await aio.sleep(1)

-    async def main() -> None:
+    async def impatient():
         with pytest.raises(aio.TimeoutError):
-            await aio.wait_for(aio.create_task(func()), 0.0001)
+            await aio.wait_for(func(), 0.0001)
+
+    async def main() -> None:
+        await aio.gather(*(impatient() for _ in range(10)))

     with ThreadPoolExecutor(4) as pool:
         for t in [pool.submit(aio.run, main()) for _ in range(10)]:
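To exercise just this regression locally, something along these lines should work, assuming a standard pytest setup:

    pytest tests/test_threadsafe_async_cache.py -k test_canceling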
