Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid creating a future on every websocket receive #8498

Merged
merged 11 commits into from
Jul 14, 2024
1 change: 1 addition & 0 deletions CHANGES/8498.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Avoid creating a future on every websocket receive -- by :user:`bdraco`.
19 changes: 11 additions & 8 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def __init__(
self._pong_heartbeat = heartbeat / 2.0
self._pong_response_cb: Optional[asyncio.TimerHandle] = None
self._loop = loop
self._waiting: Optional[asyncio.Future[bool]] = None
self._waiting: bool = False
self._close_wait: Optional[asyncio.Future[None]] = None
self._exception: Optional[BaseException] = None
self._compress = compress
self._client_notakeover = client_notakeover
Expand Down Expand Up @@ -195,10 +196,12 @@ async def send_json(
async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
# we need to break `receive()` cycle first,
# `close()` may be called from different task
if self._waiting is not None and not self._closing:
if self._waiting and not self._closing:
assert self._loop is not None
self._close_wait = self._loop.create_future()
self._closing = True
self._reader.feed_data(WS_CLOSING_MESSAGE)
await self._waiting
await self._close_wait

if not self._closed:
self._cancel_heartbeat()
Expand Down Expand Up @@ -242,7 +245,7 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo

async def receive(self, timeout: Optional[float] = None) -> WSMessage:
while True:
if self._waiting is not None:
if self._waiting:
raise RuntimeError("Concurrent call to receive() is not allowed")

if self._closed:
Expand All @@ -252,17 +255,17 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
return WS_CLOSED_MESSAGE

try:
self._waiting = self._loop.create_future()
self._waiting = True
try:
async with async_timeout.timeout(
timeout or self._timeout.ws_receive
):
msg = await self._reader.read()
self._reset_heartbeat()
finally:
waiter = self._waiting
self._waiting = None
set_result(waiter, True)
self._waiting = False
if close_wait := self._close_wait:
set_result(close_wait, None)
bdraco marked this conversation as resolved.
Show resolved Hide resolved
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
raise
Expand Down
21 changes: 13 additions & 8 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class WebSocketResponse(StreamResponse):
"_close_code",
"_loop",
"_waiting",
"_close_wait",
"_exception",
"_timeout",
"_receive_timeout",
Expand Down Expand Up @@ -103,7 +104,8 @@ def __init__(
self._conn_lost = 0
self._close_code: Optional[int] = None
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._waiting: Optional[asyncio.Future[bool]] = None
self._waiting: bool = False
self._close_wait: Optional[asyncio.Future[None]] = None
self._exception: Optional[BaseException] = None
self._timeout = timeout
self._receive_timeout = receive_timeout
Expand Down Expand Up @@ -398,9 +400,12 @@ async def close(

# we need to break `receive()` cycle first,
# `close()` may be called from different task
if self._waiting is not None and not self._closed:
if self._waiting and not self._closed:
if not self._close_wait:
assert self._loop is not None
self._close_wait = self._loop.create_future()
reader.feed_data(WS_CLOSING_MESSAGE)
await self._waiting
await self._close_wait
Dismissed Show dismissed Hide dismissed

if self._closed:
return False
Expand Down Expand Up @@ -467,7 +472,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
loop = self._loop
assert loop is not None
while True:
if self._waiting is not None:
if self._waiting:
raise RuntimeError("Concurrent call to receive() is not allowed")

if self._closed:
Expand All @@ -479,15 +484,15 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
return WS_CLOSING_MESSAGE

try:
self._waiting = loop.create_future()
self._waiting = True
try:
async with async_timeout.timeout(timeout or self._receive_timeout):
msg = await self._reader.read()
self._reset_heartbeat()
finally:
waiter = self._waiting
set_result(waiter, True)
self._waiting = None
self._waiting = False
if close_wait := self._close_wait:
set_result(close_wait, None)
bdraco marked this conversation as resolved.
Show resolved Hide resolved
except asyncio.TimeoutError:
raise
except EofStream:
Expand Down
41 changes: 38 additions & 3 deletions tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@
await client_ws.close()

msg = await ws.receive()
assert msg.type == aiohttp.WSMsgType.CLOSE
assert msg.type is aiohttp.WSMsgType.CLOSE
return ws

app = web.Application()
Expand All @@ -287,11 +287,46 @@
await ws.send_bytes(b"ask")

msg = await ws.receive()
assert msg.type == aiohttp.WSMsgType.CLOSING
assert msg.type is aiohttp.WSMsgType.CLOSING

await asyncio.sleep(0.01)
msg = await ws.receive()
assert msg.type == aiohttp.WSMsgType.CLOSED
assert msg.type is aiohttp.WSMsgType.CLOSED


async def test_concurrent_close_multiple_tasks(aiohttp_client: Any) -> None:
client_ws = None

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable client_ws is not used.

async def handler(request):
nonlocal client_ws
ws = web.WebSocketResponse()
await ws.prepare(request)

await ws.receive_bytes()
await ws.send_str("test")

msg = await ws.receive()
assert msg.type is aiohttp.WSMsgType.CLOSE
return ws

app = web.Application()
app.router.add_route("GET", "/", handler)
client = await aiohttp_client(app)
ws = await client.ws_connect("/")

await ws.send_bytes(b"ask")

task1 = asyncio.create_task(ws.close())
task2 = asyncio.create_task(ws.close())

msg = await ws.receive()
assert msg.type is aiohttp.WSMsgType.CLOSED

await task1
Dismissed Show dismissed Hide dismissed
Dismissed Show dismissed Hide dismissed
await task2
Dismissed Show dismissed Hide dismissed
Dismissed Show dismissed Hide dismissed

msg = await ws.receive()
assert msg.type is aiohttp.WSMsgType.CLOSED


async def test_close_from_server(aiohttp_client: Any) -> None:
Expand Down
41 changes: 41 additions & 0 deletions tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,47 @@ async def handler(request):
assert msg.type == WSMsgType.CLOSED


async def test_concurrent_close_multiple_tasks(loop: Any, aiohttp_client: Any) -> None:
srv_ws = None

async def handler(request):
nonlocal srv_ws
ws = srv_ws = web.WebSocketResponse(autoclose=False, protocols=("foo", "bar"))
await ws.prepare(request)

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSING

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSING

await asyncio.sleep(0)

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSED

return ws

app = web.Application()
app.router.add_get("/", handler)
client = await aiohttp_client(app)

ws = await client.ws_connect("/", autoclose=False, protocols=("eggs", "bar"))

task1 = asyncio.create_task(srv_ws.close(code=WSCloseCode.INVALID_TEXT))
task2 = asyncio.create_task(srv_ws.close(code=WSCloseCode.INVALID_TEXT))

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSE

await task1
Dismissed Show dismissed Hide dismissed
await task2
Dismissed Show dismissed Hide dismissed

await asyncio.sleep(0)
msg = await ws.receive()
assert msg.type == WSMsgType.CLOSED


async def test_close_op_code_from_client(loop: Any, aiohttp_client: Any) -> None:
srv_ws: Optional[web.WebSocketResponse] = None

Expand Down
Loading