From 48b15583305e692ce997ec6f5a6a2f88f23ace71 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 26 Nov 2023 14:33:50 +0000 Subject: [PATCH] Fix regression with connection upgrade (#7879) Fixes #7867. --- CHANGES/7879.bugfix | 1 + aiohttp/client_reqrep.py | 19 ++++++++----------- aiohttp/connector.py | 4 ++++ tests/test_client_functional.py | 19 +++++++++++++++++++ 4 files changed, 32 insertions(+), 11 deletions(-) create mode 100644 CHANGES/7879.bugfix diff --git a/CHANGES/7879.bugfix b/CHANGES/7879.bugfix new file mode 100644 index 00000000000..08baf85be42 --- /dev/null +++ b/CHANGES/7879.bugfix @@ -0,0 +1 @@ +Fixed a regression where connection may get closed during upgrade. -- by :user:`Dreamsorcerer` diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 21858fc345a..2185aa85738 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -940,19 +940,14 @@ def _response_eof(self) -> None: if self._closed: return - if self._connection is not None: - # websocket, protocol could be None because - # connection could be detached - if ( - self._connection.protocol is not None - and self._connection.protocol.upgraded - ): - return - - self._release_connection() + # protocol could be None because connection could be detached + protocol = self._connection and self._connection.protocol + if protocol is not None and protocol.upgraded: + return self._closed = True self._cleanup_writer() + self._release_connection() @property def closed(self) -> bool: @@ -1048,7 +1043,9 @@ async def read(self) -> bytes: elif self._released: # Response explicitly released raise ClientConnectionError("Connection closed") - await self._wait_released() # Underlying connection released + protocol = self._connection and self._connection.protocol + if protocol is None or not protocol.upgraded: + await self._wait_released() # Underlying connection released return self._body def get_encoding(self) -> str: diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 2460ca46705..295882f65ca 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -107,6 +107,10 @@ def __del__(self, _warnings: Any = warnings) -> None: context["source_traceback"] = self._source_traceback self._loop.call_exception_handler(context) + def __bool__(self) -> Literal[True]: + """Force subclasses to not be falsy, to make checks simpler.""" + return True + @property def transport(self) -> Optional[asyncio.Transport]: if self._protocol is None: diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 4f0d594cb1f..66af22b4494 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -174,6 +174,25 @@ async def handler(request): assert 1 == len(client._session.connector._conns) +async def test_upgrade_connection_not_released_after_read(aiohttp_client: Any) -> None: + async def handler(request: web.Request) -> web.Response: + body = await request.read() + assert b"" == body + return web.Response( + status=101, headers={"Connection": "Upgrade", "Upgrade": "tcp"} + ) + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + + resp = await client.get("/") + await resp.read() + assert resp.connection is not None + assert not resp.closed + + async def test_keepalive_server_force_close_connection(aiohttp_client: Any) -> None: async def handler(request): body = await request.read()