diff --git a/README.rst b/README.rst index aca5de5..3d956b7 100644 --- a/README.rst +++ b/README.rst @@ -24,7 +24,19 @@ Usage # it must be done before any import asyncio statement, once per project # best place is __init__.py of You'r application from asyncio_monkey import patch_all # noqa isort:skip - patch_all() # noqa isort:skip + patch_all() + +or call the one you need + +.. code-block:: python + + # it must be done before any import asyncio statement, once per project + # best place is __init__.py of You'r application + import asyncio_monkey # noqa isort:skip + + asyncio_monkey.patch_log_destroy_pending() + asyncio_monkey.patch_get_event_loop() + asyncio_monkey.patch_lock() Features -------- @@ -32,3 +44,5 @@ Features - Disables `get_event_loop` returns currently running loop, even if `MainThread` loop is `None`, useful for Python 3.6.0+ `docs `_ - Disables silent destroying futures inside `asyncio.gather` `source `_ + +- Prevents `asyncio.Lock` deadlock after cancellation `source `_ diff --git a/asyncio_monkey.py b/asyncio_monkey.py index 036a450..dce8529 100644 --- a/asyncio_monkey.py +++ b/asyncio_monkey.py @@ -2,9 +2,23 @@ PY_360 = sys.version_info >= (3, 6, 0) +PY_362 = sys.version_info >= (3, 6, 2) + __version__ = '0.0.4' +def _create_future(*, loop=None): + import asyncio + + if loop is None: + loop = asyncio.get_event_loop() + + try: + return loop.create_future() + except AttributeError: + return asyncio.Future(loop=loop) + + def patch_log_destroy_pending(): import asyncio @@ -49,6 +63,60 @@ def get_event_loop(): asyncio.get_event_loop = asyncio.events.get_event_loop +def patch_lock(): + import asyncio + + if PY_362: + return + + if hasattr(asyncio.locks.Lock, 'patched'): + return + + # Fixes an issue with all Python versions that leaves pending waiters + # without being awakened when the first waiter is canceled. + # Code adapted from the PR https://github.com/python/cpython/pull/1031 + # Waiting once it is merged to make a proper condition to relay on + # the stdlib implementation or this one patched + + class Lock(asyncio.locks.Lock): + patched = True + + @asyncio.coroutine + def acquire(self): + """Acquire a lock. + This method blocks until the lock is unlocked, then sets it to + locked and returns True. + """ + if not self._locked and all(w.cancelled() for w in self._waiters): + self._locked = True + return True + + fut = _create_future(loop=self._loop) + + self._waiters.append(fut) + try: + yield from fut + self._locked = True + return True + except asyncio.futures.CancelledError: + if not self._locked: # pragma: no cover + self._wake_up_first() + raise + finally: + self._waiters.remove(fut) + + def _wake_up_first(self): + """Wake up the first waiter who isn't cancelled.""" + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + break + + asyncio.locks.Lock = Lock + asyncio.Lock = Lock + + def patch_all(): patch_log_destroy_pending() patch_get_event_loop() + patch_lock() diff --git a/tests.py b/tests.py index f00fd35..2a5cbe7 100644 --- a/tests.py +++ b/tests.py @@ -1,11 +1,13 @@ import asyncio + +from asyncio import test_utils from functools import partial from unittest import mock import pytest - from asyncio_monkey import ( - PY_360, patch_all, patch_log_destroy_pending, patch_get_event_loop, + PY_360, PY_362, patch_all, patch_get_event_loop, + patch_lock, patch_log_destroy_pending, ) @@ -104,10 +106,83 @@ def coro(): loop.close() +def test_no_patch_lock(): + if PY_362: + return + + loop = asyncio.new_event_loop() + + assert not hasattr(asyncio.Lock, 'patched') + assert not hasattr(asyncio.locks.Lock, 'patched') + + lock = asyncio.Lock(loop=loop) + + ta = asyncio.Task(lock.acquire(), loop=loop) + test_utils.run_briefly(loop) + assert lock.locked() + + tb = asyncio.Task(lock.acquire(), loop=loop) + test_utils.run_briefly(loop) + assert len(lock._waiters) == 1 + + # Create a second waiter, wake up the first, and cancel it. + # Without the fix, the second was not woken up. + tc = asyncio.Task(lock.acquire(), loop=loop) + lock.release() + tb.cancel() + test_utils.run_briefly(loop) + + assert not lock.locked() + assert ta.done() + assert tb.cancelled() + + loop.close() + + +def test_patch_lock(): + loop = asyncio.new_event_loop() + + assert not hasattr(asyncio.Lock, 'patched') + assert not hasattr(asyncio.locks.Lock, 'patched') + + patch_lock() + patch_lock() + + assert hasattr(asyncio.Lock, 'patched') + assert hasattr(asyncio.locks.Lock, 'patched') + + lock = asyncio.Lock(loop=loop) + + ta = asyncio.Task(lock.acquire(), loop=loop) + test_utils.run_briefly(loop) + assert lock.locked() + + tb = asyncio.Task(lock.acquire(), loop=loop) + test_utils.run_briefly(loop) + assert len(lock._waiters) == 1 + + # Create a second waiter, wake up the first, and cancel it. + # Without the fix, the second was not woken up. + tc = asyncio.Task(lock.acquire(), loop=loop) + lock.release() + tb.cancel() + test_utils.run_briefly(loop) + + # tc waiter acquired lock + assert lock.locked() + assert ta.done() + assert tb.cancelled() + + loop.close() + + def test_patch_all(): - with mock.patch('asyncio_monkey.patch_get_event_loop') as mocked_patch_get_event_loop, mock.patch('asyncio_monkey.patch_log_destroy_pending') as mocked_patch_log_destroy_pending: # noqa + with mock.patch('asyncio_monkey.patch_get_event_loop') as mocked_patch_get_event_loop, \ + mock.patch('asyncio_monkey.patch_log_destroy_pending') as mocked_patch_log_destroy_pending, \ + mock.patch('asyncio_monkey.patch_lock') as mocked_patch_lock: # noqa patch_all() assert mocked_patch_get_event_loop.called_once() assert mocked_patch_log_destroy_pending.called_once() + assert mocked_patch_lock.called_once()