From 1f1abbd0c245b7d29284f21a230ddf98387b7f2a Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Mon, 23 Sep 2024 01:39:31 +0100 Subject: [PATCH 1/2] Fix cancellations being swallowed (#9030) Co-authored-by: J. Nick Koston (cherry picked from commit 1a77ad933f07ab0e7ba0c16f7ca8f02fa8ab044e) --- CHANGES/9030.bugfix.rst | 1 + aiohttp/client_reqrep.py | 37 ++++++++++++++++++++++++++-------- aiohttp/web_protocol.py | 38 ++++++++++++++++++++++++++--------- tests/test_client_request.py | 19 +++++++++++++++++- tests/test_web_functional.py | 39 ++++++++++++++++++++++++++++++++++-- 5 files changed, 114 insertions(+), 20 deletions(-) create mode 100644 CHANGES/9030.bugfix.rst diff --git a/CHANGES/9030.bugfix.rst b/CHANGES/9030.bugfix.rst new file mode 100644 index 00000000000..2e9d48f5359 --- /dev/null +++ b/CHANGES/9030.bugfix.rst @@ -0,0 +1 @@ +Fixed (on Python 3.11+) some edge cases where a task cancellation may get incorrectly suppressed -- by :user:`Dreamsorcerer`. diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 10144f2a9c4..dfd572d928b 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -625,11 +625,8 @@ async def write_bytes( """Support coroutines that yields bytes objects.""" # 100 response if self._continue is not None: - try: - await writer.drain() - await self._continue - except asyncio.CancelledError: - return + await writer.drain() + await self._continue protocol = conn.protocol assert protocol is not None @@ -658,6 +655,7 @@ async def write_bytes( except asyncio.CancelledError: # Body hasn't been fully sent, so connection can't be reused. conn.close() + raise except Exception as underlying_exc: set_exception( protocol, @@ -764,8 +762,15 @@ async def send(self, conn: "Connection") -> "ClientResponse": async def close(self) -> None: if self._writer is not None: - with contextlib.suppress(asyncio.CancelledError): + try: await self._writer + except asyncio.CancelledError: + if ( + sys.version_info >= (3, 11) + and (task := asyncio.current_task()) + and task.cancelling() + ): + raise def terminate(self) -> None: if self._writer is not None: @@ -1113,7 +1118,15 @@ def _release_connection(self) -> None: async def _wait_released(self) -> None: if self._writer is not None: - await self._writer + try: + await self._writer + except asyncio.CancelledError: + if ( + sys.version_info >= (3, 11) + and (task := asyncio.current_task()) + and task.cancelling() + ): + raise self._release_connection() def _cleanup_writer(self) -> None: @@ -1129,7 +1142,15 @@ def _notify_content(self) -> None: async def wait_for_close(self) -> None: if self._writer is not None: - await self._writer + try: + await self._writer + except asyncio.CancelledError: + if ( + sys.version_info >= (3, 11) + and (task := asyncio.current_task()) + and task.cancelling() + ): + raise self.release() async def read(self) -> bytes: diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 8fa8535b93a..85eb70d5a0b 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -271,17 +271,32 @@ async def shutdown(self, timeout: Optional[float] = 15.0) -> None: # down while the handler is still processing a request # to avoid creating a future for every request. self._handler_waiter = self._loop.create_future() - with suppress(asyncio.CancelledError, asyncio.TimeoutError): + try: async with ceil_timeout(timeout): await self._handler_waiter + except (asyncio.CancelledError, asyncio.TimeoutError): + self._handler_waiter = None + if ( + sys.version_info >= (3, 11) + and (task := asyncio.current_task()) + and task.cancelling() + ): + raise # Then cancel handler and wait - with suppress(asyncio.CancelledError, asyncio.TimeoutError): + try: async with ceil_timeout(timeout): if self._current_request is not None: self._current_request._cancel(asyncio.CancelledError()) if self._task_handler is not None and not self._task_handler.done(): - await self._task_handler + await asyncio.shield(self._task_handler) + except (asyncio.CancelledError, asyncio.TimeoutError): + if ( + sys.version_info >= (3, 11) + and (task := asyncio.current_task()) + and task.cancelling() + ): + raise # force-close non-idle handler if self._task_handler is not None: @@ -517,8 +532,6 @@ async def start(self) -> None: # wait for next request self._waiter = loop.create_future() await self._waiter - except asyncio.CancelledError: - break finally: self._waiter = None @@ -545,7 +558,7 @@ async def start(self) -> None: task = loop.create_task(coro) try: resp, reset = await task - except (asyncio.CancelledError, ConnectionError): + except ConnectionError: self.log_debug("Ignored premature client disconnection") break @@ -569,12 +582,19 @@ async def start(self) -> None: now = loop.time() end_t = now + lingering_time - with suppress(asyncio.TimeoutError, asyncio.CancelledError): + try: while not payload.is_eof() and now < end_t: async with ceil_timeout(end_t - now): # read and ignore await payload.readany() now = loop.time() + except (asyncio.CancelledError, asyncio.TimeoutError): + if ( + sys.version_info >= (3, 11) + and (t := asyncio.current_task()) + and t.cancelling() + ): + raise # if payload still uncompleted if not payload.is_eof() and not self._force_close: @@ -584,8 +604,8 @@ async def start(self) -> None: payload.set_exception(_PAYLOAD_ACCESS_ERROR) except asyncio.CancelledError: - self.log_debug("Ignored premature client disconnection ") - break + self.log_debug("Ignored premature client disconnection") + raise except Exception as exc: self.log_exception("Unhandled exception", exc_info=exc) self.force_close() diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 2d70ebdd4f2..f2eff019504 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -2,6 +2,7 @@ import hashlib import io import pathlib +import sys import urllib.parse import zlib from http.cookies import BaseCookie, Morsel, SimpleCookie @@ -1213,7 +1214,23 @@ async def test_oserror_on_write_bytes(loop, conn) -> None: await req.close() -async def test_terminate(loop, conn) -> None: +@pytest.mark.skipif(sys.version_info < (3, 11), reason="Needs Task.cancelling()") +async def test_cancel_close(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None: + req = ClientRequest("get", URL("http://python.org"), loop=loop) + req._writer = asyncio.Future() # type: ignore[assignment] + + t = asyncio.create_task(req.close()) + + # Start waiting on _writer + await asyncio.sleep(0) + + t.cancel() + # Cancellation should not be suppressed. + with pytest.raises(asyncio.CancelledError): + await t + + +async def test_terminate(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None: req = ClientRequest("get", URL("http://python.org"), loop=loop) async def _mock_write_bytes(*args, **kwargs): diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 5b2e5fe9353..b3d1205855c 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -3,6 +3,7 @@ import json import pathlib import socket +import sys import zlib from typing import Any, NoReturn, Optional from unittest import mock @@ -187,8 +188,42 @@ async def handler(request): await resp.release() -async def test_post_form(aiohttp_client) -> None: - async def handler(request): +@pytest.mark.skipif(sys.version_info < (3, 11), reason="Needs Task.cancelling()") +async def test_cancel_shutdown(aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.Response: + t = asyncio.create_task(request.protocol.shutdown()) + # Ensure it's started waiting + await asyncio.sleep(0) + + t.cancel() + # Cancellation should not be suppressed + with pytest.raises(asyncio.CancelledError): + await t + + # Repeat for second waiter in shutdown() + with mock.patch.object(request.protocol, "_request_in_progress", False): + with mock.patch.object(request.protocol, "_current_request", None): + t = asyncio.create_task(request.protocol.shutdown()) + await asyncio.sleep(0) + + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + return web.Response(body=b"OK") + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + async with client.get("/") as resp: + assert resp.status == 200 + txt = await resp.text() + assert txt == "OK" + + +async def test_post_form(aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.Response: data = await request.post() assert {"a": "1", "b": "2", "c": ""} == data return web.Response(body=b"OK") From 7f59f4eb11a43a0daaaf3267752c8c262e26d04f Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Mon, 23 Sep 2024 01:53:32 +0100 Subject: [PATCH 2/2] Update test_web_functional.py --- tests/test_web_functional.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index b3d1205855c..e46a23c5857 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -23,6 +23,7 @@ web, ) from aiohttp.hdrs import CONTENT_LENGTH, CONTENT_TYPE, TRANSFER_ENCODING +from aiohttp.pytest_plugin import AiohttpClient from aiohttp.test_utils import make_mocked_coro from aiohttp.typedefs import Handler from aiohttp.web_protocol import RequestHandler