From a0cc560f7fffd79d2416ce232a3d9c4411328650 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 13 Jul 2024 22:30:18 -0500 Subject: [PATCH 01/11] Avoid creating a future on every websocket receive Only create a future on close if we need to wait for the receive to finish --- aiohttp/client_ws.py | 19 +++++++++++-------- aiohttp/web_ws.py | 20 ++++++++++++-------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 76a4e237a48..79e5de8e4e3 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -76,7 +76,8 @@ def __init__( self._pong_heartbeat = heartbeat / 2.0 self._pong_response_cb: Optional[asyncio.TimerHandle] = None self._loop = loop - self._waiting: Optional[asyncio.Future[bool]] = None + self._waiting: bool = False + self._close_wait: Optional[asyncio.Future[None]] = None self._exception: Optional[BaseException] = None self._compress = compress self._client_notakeover = client_notakeover @@ -195,10 +196,12 @@ async def send_json( async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool: # we need to break `receive()` cycle first, # `close()` may be called from different task - if self._waiting is not None and not self._closing: + if self._waiting and not self._closing: + assert self._loop is not None + self._close_wait = self._loop.create_future() self._closing = True self._reader.feed_data(WS_CLOSING_MESSAGE) - await self._waiting + await self._close_wait if not self._closed: self._cancel_heartbeat() @@ -242,7 +245,7 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo async def receive(self, timeout: Optional[float] = None) -> WSMessage: while True: - if self._waiting is not None: + if self._waiting: raise RuntimeError("Concurrent call to receive() is not allowed") if self._closed: @@ -252,7 +255,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: return WS_CLOSED_MESSAGE try: - self._waiting = self._loop.create_future() + self._waiting = True try: async with async_timeout.timeout( timeout or self._timeout.ws_receive @@ -260,9 +263,9 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: msg = await self._reader.read() self._reset_heartbeat() finally: - waiter = self._waiting - self._waiting = None - set_result(waiter, True) + self._waiting = False + if close_wait := self._close_wait: + set_result(close_wait, None) except (asyncio.CancelledError, asyncio.TimeoutError): self._close_code = WSCloseCode.ABNORMAL_CLOSURE raise diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 392bf432f0b..52c0a44cc62 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -67,6 +67,7 @@ class WebSocketResponse(StreamResponse): "_close_code", "_loop", "_waiting", + "_close_wait", "_exception", "_timeout", "_receive_timeout", @@ -103,7 +104,8 @@ def __init__( self._conn_lost = 0 self._close_code: Optional[int] = None self._loop: Optional[asyncio.AbstractEventLoop] = None - self._waiting: Optional[asyncio.Future[bool]] = None + self._waiting: bool = False + self._close_wait: Optional[asyncio.Future[None]] = None self._exception: Optional[BaseException] = None self._timeout = timeout self._receive_timeout = receive_timeout @@ -398,9 +400,11 @@ async def close( # we need to break `receive()` cycle first, # `close()` may be called from different task - if self._waiting is not None and not self._closed: + if self._waiting and not self._closed: + assert self._loop is not None + self._close_wait = self._loop.create_future() reader.feed_data(WS_CLOSING_MESSAGE) - await self._waiting + await self._close_wait if self._closed: return False @@ -467,7 +471,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: loop = self._loop assert loop is not None while True: - if self._waiting is not None: + if self._waiting: raise RuntimeError("Concurrent call to receive() is not allowed") if self._closed: @@ -479,15 +483,15 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: return WS_CLOSING_MESSAGE try: - self._waiting = loop.create_future() + self._waiting = True try: async with async_timeout.timeout(timeout or self._receive_timeout): msg = await self._reader.read() self._reset_heartbeat() finally: - waiter = self._waiting - set_result(waiter, True) - self._waiting = None + self._waiting = False + if close_wait := self._close_wait: + set_result(close_wait, None) except asyncio.TimeoutError: raise except EofStream: From 235fb059a1facf73c19a33ec55d6345aa2dd7a6c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 14 Jul 2024 07:23:35 -0500 Subject: [PATCH 02/11] safer --- aiohttp/client_ws.py | 5 +++-- aiohttp/web_ws.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 79e5de8e4e3..e3a457f0e94 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -198,8 +198,9 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo # `close()` may be called from different task if self._waiting and not self._closing: assert self._loop is not None - self._close_wait = self._loop.create_future() - self._closing = True + if not self._close_wait: + self._close_wait = self._loop.create_future() + self._closing = True self._reader.feed_data(WS_CLOSING_MESSAGE) await self._close_wait diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 52c0a44cc62..4566dff9603 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -401,8 +401,9 @@ async def close( # we need to break `receive()` cycle first, # `close()` may be called from different task if self._waiting and not self._closed: - assert self._loop is not None - self._close_wait = self._loop.create_future() + if not self._close_wait: + assert self._loop is not None + self._close_wait = self._loop.create_future() reader.feed_data(WS_CLOSING_MESSAGE) await self._close_wait From 0c7d65e8288530bd6c2ae623d958a17846da1816 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 14 Jul 2024 07:26:33 -0500 Subject: [PATCH 03/11] safer --- aiohttp/client_ws.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index e3a457f0e94..d2014527cb2 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -197,10 +197,10 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo # we need to break `receive()` cycle first, # `close()` may be called from different task if self._waiting and not self._closing: - assert self._loop is not None if not self._close_wait: + assert self._loop is not None self._close_wait = self._loop.create_future() - self._closing = True + self._closing = True self._reader.feed_data(WS_CLOSING_MESSAGE) await self._close_wait From 4935501d237509c1518636e0186b4f3993548d24 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 14 Jul 2024 10:37:12 -0500 Subject: [PATCH 04/11] not needed on client side since self._closing = True is always set synchronously --- aiohttp/client_ws.py | 5 ++-- tests/test_client_ws_functional.py | 32 ++++++++++++++++++++ tests/test_web_websocket_functional.py | 41 ++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 3 deletions(-) diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index d2014527cb2..79e5de8e4e3 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -197,9 +197,8 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo # we need to break `receive()` cycle first, # `close()` may be called from different task if self._waiting and not self._closing: - if not self._close_wait: - assert self._loop is not None - self._close_wait = self._loop.create_future() + assert self._loop is not None + self._close_wait = self._loop.create_future() self._closing = True self._reader.feed_data(WS_CLOSING_MESSAGE) await self._close_wait diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index f2bddbed792..cc1998ef9e4 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -294,6 +294,38 @@ async def handler(request): assert msg.type == aiohttp.WSMsgType.CLOSED +async def test_concurrent_close_multiple_tasks(aiohttp_client: Any) -> None: + client_ws = None + + async def handler(request): + nonlocal client_ws + ws = web.WebSocketResponse() + await ws.prepare(request) + + await ws.receive_bytes() + await ws.send_str("test") + + msg = await ws.receive() + assert msg.type == aiohttp.WSMsgType.CLOSE + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + ws = client_ws = await client.ws_connect("/") + + await ws.send_bytes(b"ask") + + task1 = asyncio.create_task(ws.close()) + task2 = asyncio.create_task(ws.close()) + + msg = await ws.receive() + assert msg.type == aiohttp.WSMsgType.CLOSED + + await task1 + await task2 + + async def test_close_from_server(aiohttp_client: Any) -> None: loop = asyncio.get_event_loop() closed = loop.create_future() diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index faba41c3ea1..7d990294840 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -311,6 +311,47 @@ async def handler(request): assert msg.type == WSMsgType.CLOSED +async def test_concurrent_close_multiple_tasks(loop: Any, aiohttp_client: Any) -> None: + srv_ws = None + + async def handler(request): + nonlocal srv_ws + ws = srv_ws = web.WebSocketResponse(autoclose=False, protocols=("foo", "bar")) + await ws.prepare(request) + + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSING + + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSING + + await asyncio.sleep(0) + + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSED + + return ws + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + ws = await client.ws_connect("/", autoclose=False, protocols=("eggs", "bar")) + + task1 = asyncio.create_task(srv_ws.close(code=WSCloseCode.INVALID_TEXT)) + task2 = asyncio.create_task(srv_ws.close(code=WSCloseCode.INVALID_TEXT)) + + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSE + + await task1 + await task2 + + await asyncio.sleep(0) + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSED + + async def test_close_op_code_from_client(loop: Any, aiohttp_client: Any) -> None: srv_ws: Optional[web.WebSocketResponse] = None From 4cd115859970793516ad01b788f3a433ec6e037a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 14 Jul 2024 10:39:53 -0500 Subject: [PATCH 05/11] better coverage --- tests/test_client_ws_functional.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index cc1998ef9e4..d3570e89d2c 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -325,6 +325,9 @@ async def handler(request): await task1 await task2 + msg = await ws.receive() + assert msg.type == aiohttp.WSMsgType.CLOSED + async def test_close_from_server(aiohttp_client: Any) -> None: loop = asyncio.get_event_loop() From 0250e58f6f298c31d7c2828ef67cf76ea2016ebc Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 14 Jul 2024 10:52:15 -0500 Subject: [PATCH 06/11] remove unused code from copied code --- tests/test_client_ws_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index d3570e89d2c..fbb603f309d 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -312,7 +312,7 @@ async def handler(request): app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) - ws = client_ws = await client.ws_connect("/") + ws = await client.ws_connect("/") await ws.send_bytes(b"ask") From f9bbc88a5686fdf36a0f381fe2a51677eacea198 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 14 Jul 2024 10:54:12 -0500 Subject: [PATCH 07/11] changelog --- CHANGES/8498.misc.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 CHANGES/8498.misc.rst diff --git a/CHANGES/8498.misc.rst b/CHANGES/8498.misc.rst new file mode 100644 index 00000000000..5fcf3efd884 --- /dev/null +++ b/CHANGES/8498.misc.rst @@ -0,0 +1 @@ +Avoid creating a future on every websocket receive -- by :user:`bdraco`. From c8b724644f4139cd9c8ecfa6ed33e1ded4e88c51 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 14 Jul 2024 10:55:05 -0500 Subject: [PATCH 08/11] enums are singletons, use is --- tests/test_client_ws_functional.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index fbb603f309d..7bdce6c0d2f 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -276,7 +276,7 @@ async def handler(request): await client_ws.close() msg = await ws.receive() - assert msg.type == aiohttp.WSMsgType.CLOSE + assert msg.type is aiohttp.WSMsgType.CLOSE return ws app = web.Application() @@ -287,11 +287,11 @@ async def handler(request): await ws.send_bytes(b"ask") msg = await ws.receive() - assert msg.type == aiohttp.WSMsgType.CLOSING + assert msg.type is aiohttp.WSMsgType.CLOSING await asyncio.sleep(0.01) msg = await ws.receive() - assert msg.type == aiohttp.WSMsgType.CLOSED + assert msg.type is aiohttp.WSMsgType.CLOSED async def test_concurrent_close_multiple_tasks(aiohttp_client: Any) -> None: @@ -306,7 +306,7 @@ async def handler(request): await ws.send_str("test") msg = await ws.receive() - assert msg.type == aiohttp.WSMsgType.CLOSE + assert msg.type is aiohttp.WSMsgType.CLOSE return ws app = web.Application() @@ -320,13 +320,13 @@ async def handler(request): task2 = asyncio.create_task(ws.close()) msg = await ws.receive() - assert msg.type == aiohttp.WSMsgType.CLOSED + assert msg.type is aiohttp.WSMsgType.CLOSED await task1 await task2 msg = await ws.receive() - assert msg.type == aiohttp.WSMsgType.CLOSED + assert msg.type is aiohttp.WSMsgType.CLOSED async def test_close_from_server(aiohttp_client: Any) -> None: From de2f6bf2c7e3d37ab2ba9c14d2f9816f23684599 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 14 Jul 2024 11:07:58 -0500 Subject: [PATCH 09/11] remove unused code --- tests/test_client_ws_functional.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 7bdce6c0d2f..2d65b317c6d 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -295,10 +295,7 @@ async def handler(request): async def test_concurrent_close_multiple_tasks(aiohttp_client: Any) -> None: - client_ws = None - async def handler(request): - nonlocal client_ws ws = web.WebSocketResponse() await ws.prepare(request) From ca846ab7c7058fa81a595285b1ac06680a2ddcaf Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 14 Jul 2024 11:10:06 -0500 Subject: [PATCH 10/11] Update aiohttp/client_ws.py Co-authored-by: Sam Bull --- aiohttp/client_ws.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 79e5de8e4e3..981b0efff7a 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -264,8 +264,8 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: self._reset_heartbeat() finally: self._waiting = False - if close_wait := self._close_wait: - set_result(close_wait, None) + if self._close_wait: + set_result(self._close_wait, None) except (asyncio.CancelledError, asyncio.TimeoutError): self._close_code = WSCloseCode.ABNORMAL_CLOSURE raise From fa9b82010748fd2ee9df9cb20258a9575b2d5250 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 14 Jul 2024 11:10:32 -0500 Subject: [PATCH 11/11] Update aiohttp/web_ws.py --- aiohttp/web_ws.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 4566dff9603..52a737b2964 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -491,8 +491,8 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: self._reset_heartbeat() finally: self._waiting = False - if close_wait := self._close_wait: - set_result(close_wait, None) + if self._close_wait: + set_result(self._close_wait, None) except asyncio.TimeoutError: raise except EofStream: