Skip to content

Commit

Permalink
Fix regression with connection upgrade (#7879)
Browse files Browse the repository at this point in the history
Fixes #7867.

(cherry picked from commit 48b1558)
  • Loading branch information
Dreamsorcerer committed Nov 26, 2023
1 parent 946523d commit 5875b17
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGES/7879.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed a regression where connection may get closed during upgrade. -- by :user:`Dreamsorcerer`
19 changes: 8 additions & 11 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,19 +1006,14 @@ def _response_eof(self) -> None:
if self._closed:
return

if self._connection is not None:
# websocket, protocol could be None because
# connection could be detached
if (
self._connection.protocol is not None
and self._connection.protocol.upgraded
):
return

self._release_connection()
# protocol could be None because connection could be detached
protocol = self._connection and self._connection.protocol
if protocol is not None and protocol.upgraded:
return

self._closed = True
self._cleanup_writer()
self._release_connection()

@property
def closed(self) -> bool:
Expand Down Expand Up @@ -1113,7 +1108,9 @@ async def read(self) -> bytes:
elif self._released: # Response explicitly released
raise ClientConnectionError("Connection closed")

await self._wait_released() # Underlying connection released
protocol = self._connection and self._connection.protocol
if protocol is None or not protocol.upgraded:
await self._wait_released() # Underlying connection released
return self._body # type: ignore[no-any-return]

def get_encoding(self) -> str:
Expand Down
4 changes: 4 additions & 0 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ def __del__(self, _warnings: Any = warnings) -> None:
context["source_traceback"] = self._source_traceback
self._loop.call_exception_handler(context)

def __bool__(self) -> Literal[True]:
"""Force subclasses to not be falsy, to make checks simpler."""
return True

@property
def loop(self) -> asyncio.AbstractEventLoop:
warnings.warn(
Expand Down
19 changes: 19 additions & 0 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,25 @@ async def handler(request):
assert 1 == len(client._session.connector._conns)


async def test_upgrade_connection_not_released_after_read(aiohttp_client) -> None:
async def handler(request: web.Request) -> web.Response:
body = await request.read()
assert b"" == body
return web.Response(
status=101, headers={"Connection": "Upgrade", "Upgrade": "tcp"}
)

app = web.Application()
app.router.add_route("GET", "/", handler)

client = await aiohttp_client(app)

resp = await client.get("/")
await resp.read()
assert resp.connection is not None
assert not resp.closed


async def test_keepalive_server_force_close_connection(aiohttp_client) -> None:
async def handler(request):
body = await request.read()
Expand Down

0 comments on commit 5875b17

Please sign in to comment.