From 4bdd8a7e73b45751a4cde02070518a9e7c4246b2 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Mon, 9 Oct 2023 10:07:50 -0700 Subject: [PATCH] Switch to Python 3.12-style `wait_for` (#1086) `wait_for` has been a mess with respect to cancellations consistently in `asyncio`. Hopefully the approach taken in Python 3.12 solves the issues, so adopt that instead of trying to "fix" `wait_for` with wrappers on older Pythons. Use `async_timeout` as a polyfill on pre-3.11 Python. Closes: #1056 Closes: #1052 Fixes: #955 --- asyncpg/_asyncio_compat.py | 87 +++++++++++++++++++++++++++++++++++ asyncpg/compat.py | 20 ++------ asyncpg/protocol/protocol.pyx | 6 +-- pyproject.toml | 3 ++ 4 files changed, 98 insertions(+), 18 deletions(-) create mode 100644 asyncpg/_asyncio_compat.py diff --git a/asyncpg/_asyncio_compat.py b/asyncpg/_asyncio_compat.py new file mode 100644 index 00000000..ad7dfd8c --- /dev/null +++ b/asyncpg/_asyncio_compat.py @@ -0,0 +1,87 @@ +# Backports from Python/Lib/asyncio for older Pythons +# +# Copyright (c) 2001-2023 Python Software Foundation; All Rights Reserved +# +# SPDX-License-Identifier: PSF-2.0 + + +import asyncio +import functools +import sys + +if sys.version_info < (3, 11): + from async_timeout import timeout as timeout_ctx +else: + from asyncio import timeout as timeout_ctx + + +async def wait_for(fut, timeout): + """Wait for the single Future or coroutine to complete, with timeout. + + Coroutine will be wrapped in Task. + + Returns result of the Future or coroutine. When a timeout occurs, + it cancels the task and raises TimeoutError. To avoid the task + cancellation, wrap it in shield(). + + If the wait is cancelled, the task is also cancelled. + + If the task supresses the cancellation and returns a value instead, + that value is returned. + + This function is a coroutine. + """ + # The special case for timeout <= 0 is for the following case: + # + # async def test_waitfor(): + # func_started = False + # + # async def func(): + # nonlocal func_started + # func_started = True + # + # try: + # await asyncio.wait_for(func(), 0) + # except asyncio.TimeoutError: + # assert not func_started + # else: + # assert False + # + # asyncio.run(test_waitfor()) + + if timeout is not None and timeout <= 0: + fut = asyncio.ensure_future(fut) + + if fut.done(): + return fut.result() + + await _cancel_and_wait(fut) + try: + return fut.result() + except asyncio.CancelledError as exc: + raise TimeoutError from exc + + async with timeout_ctx(timeout): + return await fut + + +async def _cancel_and_wait(fut): + """Cancel the *fut* future or task and wait until it completes.""" + + loop = asyncio.get_running_loop() + waiter = loop.create_future() + cb = functools.partial(_release_waiter, waiter) + fut.add_done_callback(cb) + + try: + fut.cancel() + # We cannot wait on *fut* directly to make + # sure _cancel_and_wait itself is reliably cancellable. + await waiter + finally: + fut.remove_done_callback(cb) + + +def _release_waiter(waiter, *args): + if not waiter.done(): + waiter.set_result(None) diff --git a/asyncpg/compat.py b/asyncpg/compat.py index b9b13fa5..532c197a 100644 --- a/asyncpg/compat.py +++ b/asyncpg/compat.py @@ -5,10 +5,10 @@ # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 -import asyncio import pathlib import platform import typing +import sys SYSTEM = platform.uname().system @@ -49,17 +49,7 @@ async def wait_closed(stream): pass -# Workaround for https://bugs.python.org/issue37658 -async def wait_for(fut, timeout): - if timeout is None: - return await fut - - fut = asyncio.ensure_future(fut) - - try: - return await asyncio.wait_for(fut, timeout) - except asyncio.CancelledError: - if fut.done(): - return fut.result() - else: - raise +if sys.version_info < (3, 12): + from ._asyncio_compat import wait_for as wait_for # noqa: F401 +else: + from asyncio import wait_for as wait_for # noqa: F401 diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index f504d9d0..1f739cc2 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -249,7 +249,7 @@ cdef class BaseProtocol(CoreProtocol): while more: with timer: - await asyncio.wait_for( + await compat.wait_for( self.writing_allowed.wait(), timeout=timer.get_remaining_budget()) # On Windows the above event somehow won't allow context @@ -383,7 +383,7 @@ cdef class BaseProtocol(CoreProtocol): if buffer: try: with timer: - await asyncio.wait_for( + await compat.wait_for( sink(buffer), timeout=timer.get_remaining_budget()) except (Exception, asyncio.CancelledError) as ex: @@ -511,7 +511,7 @@ cdef class BaseProtocol(CoreProtocol): with timer: await self.writing_allowed.wait() with timer: - chunk = await asyncio.wait_for( + chunk = await compat.wait_for( iterator.__anext__(), timeout=timer.get_remaining_budget()) self._write_copy_data_msg(chunk) diff --git a/pyproject.toml b/pyproject.toml index 72812da1..ed2340a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,9 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Topic :: Database :: Front-Ends", ] +dependencies = [ + 'async_timeout>=4.0.3; python_version < "3.12.0"' +] [project.urls] github = "https://github.com/MagicStack/asyncpg"