Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix close race that prevented the close code from reaching the client #8680

Merged
merged 14 commits into from
Aug 12, 2024
1 change: 1 addition & 0 deletions CHANGES/8680.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed a race closing the server-side WebSocket where the close code would not reach the client. -- by :user:`bdraco`.
28 changes: 12 additions & 16 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,23 +431,10 @@ async def close(
if self._writer is None:
raise RuntimeError("Call .prepare() first")

self._cancel_heartbeat()
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
reader = self._reader
assert reader is not None

# we need to break `receive()` cycle first,
# `close()` may be called from different task
if self._waiting and not self._closed:
if not self._close_wait:
assert self._loop is not None
self._close_wait = self._loop.create_future()
reader.feed_data(WS_CLOSING_MESSAGE)
await self._close_wait

if self._closed:
return False

self._set_closed()

try:
await self._writer.close(code, message)
writer = self._payload_writer
Expand All @@ -462,12 +449,21 @@ async def close(
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
return True

reader = self._reader
assert reader is not None
# we need to break `receive()` cycle before we can call
# `reader.read()` as `close()` may be called from different task
if self._waiting:
assert self._loop is not None
assert self._close_wait is None
self._close_wait = self._loop.create_future()
reader.feed_data(WS_CLOSING_MESSAGE)
await self._close_wait

if self._closing:
self._close_transport()
return True

reader = self._reader
assert reader is not None
try:
async with async_timeout.timeout(self._timeout):
msg = await reader.read()
Expand Down
60 changes: 60 additions & 0 deletions tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
import asyncio
import contextlib
import sys
import weakref
from typing import Any, Optional

import pytest

import aiohttp
from aiohttp import WSServerHandshakeError, web
from aiohttp.http import WSCloseCode, WSMsgType
from aiohttp.pytest_plugin import AiohttpClient


async def test_websocket_can_prepare(loop: Any, aiohttp_client: Any) -> None:
Expand Down Expand Up @@ -1019,3 +1021,61 @@ async def handler(request):
await ws.close(code=WSCloseCode.OK, message="exit message")

await closed


async def test_websocket_shutdown(aiohttp_client: AiohttpClient) -> None:
"""Test that the client websocket gets the close message when the server is shutting down."""
url = "/ws"
app = web.Application()
websockets = web.AppKey("websockets", weakref.WeakSet)
app[websockets] = weakref.WeakSet()

# need for send signal shutdown server
shutdown_websockets = web.AppKey("shutdown_websockets", weakref.WeakSet)
app[shutdown_websockets] = weakref.WeakSet()

async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
websocket = web.WebSocketResponse()
await websocket.prepare(request)
request.app[websockets].add(websocket)
request.app[shutdown_websockets].add(websocket)

try:
async for message in websocket:
await websocket.send_json({"ok": True, "message": message.json()})
finally:
request.app[websockets].discard(websocket)

return websocket

async def on_shutdown(app: web.Application) -> None:
while app[shutdown_websockets]:
websocket = app[shutdown_websockets].pop()
await websocket.close(
code=aiohttp.WSCloseCode.GOING_AWAY,
message="Server shutdown",
)

app.router.add_get(url, websocket_handler)
app.on_shutdown.append(on_shutdown)

client = await aiohttp_client(app)

websocket = await client.ws_connect(url)

message = {"message": "hi"}
await websocket.send_json(message)
reply = await websocket.receive_json()
assert reply == {"ok": True, "message": message}

await app.shutdown()

assert websocket.closed is False

reply = await websocket.receive()

assert reply.type is aiohttp.http.WSMsgType.CLOSE
assert reply.data == aiohttp.WSCloseCode.GOING_AWAY
assert reply.extra == "Server shutdown"

assert websocket.closed is True
Loading