Skip to content

Commit

Permalink
Fix exceptions from WebSocket ping task not being consumed (#8685)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Aug 17, 2024
1 parent 490fca6 commit e7c02ca
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 13 deletions.
3 changes: 3 additions & 0 deletions CHANGES/8685.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fixed unconsumed exceptions raised by the WebSocket heartbeat -- by :user:`bdraco`.

If the heartbeat ping raised an exception, it would not be consumed and would be logged as an warning.
25 changes: 16 additions & 9 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,21 +151,28 @@ def _send_heartbeat(self) -> None:
if not ping_task.done():
self._ping_task = ping_task
ping_task.add_done_callback(self._ping_task_done)
else:
self._ping_task_done(ping_task)

def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
"""Callback for when the ping task completes."""
if not task.cancelled() and (exc := task.exception()):
self._handle_ping_pong_exception(exc)
self._ping_task = None

def _pong_not_received(self) -> None:
if not self._closed:
self._set_closed()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = ServerTimeoutError()
self._response.close()
if self._waiting and not self._closing:
self._reader.feed_data(
WSMessage(WSMsgType.ERROR, self._exception, None)
)
self._handle_ping_pong_exception(ServerTimeoutError())

def _handle_ping_pong_exception(self, exc: BaseException) -> None:
"""Handle exceptions raised during ping/pong processing."""
if self._closed:
return
self._set_closed()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
self._response.close()
if self._waiting and not self._closing:
self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None))

def _set_closed(self) -> None:
"""Set the connection to closed.
Expand Down
18 changes: 15 additions & 3 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,16 +189,28 @@ def _send_heartbeat(self) -> None:
if not ping_task.done():
self._ping_task = ping_task
ping_task.add_done_callback(self._ping_task_done)
else:
self._ping_task_done(ping_task)

def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
"""Callback for when the ping task completes."""
if not task.cancelled() and (exc := task.exception()):
self._handle_ping_pong_exception(exc)
self._ping_task = None

def _pong_not_received(self) -> None:
if self._req is not None and self._req.transport is not None:
self._set_closed()
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
self._exception = asyncio.TimeoutError()
self._handle_ping_pong_exception(asyncio.TimeoutError())

def _handle_ping_pong_exception(self, exc: BaseException) -> None:
"""Handle exceptions raised during ping/pong processing."""
if self._closed:
return
self._set_closed()
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
self._exception = exc
if self._waiting and not self._closing and self._reader is not None:
self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None))

def _set_closed(self) -> None:
"""Set the connection to closed.
Expand Down
30 changes: 30 additions & 0 deletions tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,36 @@ async def handler(request: web.Request) -> NoReturn:
assert ping_received


async def test_heartbeat_connection_closed(aiohttp_client: AiohttpClient) -> None:
"""Test that the connection is closed while ping is in progress."""

async def handler(request: web.Request) -> NoReturn:
ws = web.WebSocketResponse(autoping=False)
await ws.prepare(request)
await ws.receive()
assert False

app = web.Application()
app.router.add_route("GET", "/", handler)

client = await aiohttp_client(app)
resp = await client.ws_connect("/", heartbeat=0.1)
ping_count = 0
# We patch write here to simulate a connection reset error
# since if we closed the connection normally, the client would
# would cancel the heartbeat task and we wouldn't get a ping
assert resp._conn is not None
with mock.patch.object(
resp._conn.transport, "write", side_effect=ConnectionResetError
), mock.patch.object(resp._writer, "ping", wraps=resp._writer.ping) as ping:
await resp.receive()
ping_count = ping.call_count
# Connection should be closed roughly after 1.5x heartbeat.
await asyncio.sleep(0.2)
assert ping_count == 1
assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE


async def test_heartbeat_no_pong(aiohttp_client: AiohttpClient) -> None:
"""Test that the connection is closed if no pong is received without sending messages."""
ping_received = False
Expand Down
73 changes: 72 additions & 1 deletion tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import contextlib
import sys
import weakref
from typing import Any, Optional
from typing import Any, NoReturn, Optional
from unittest import mock

import pytest

Expand Down Expand Up @@ -717,6 +718,76 @@ async def handler(request):
await ws.close()


async def test_heartbeat_connection_closed(
loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
"""Test that the connection is closed while ping is in progress."""
ping_count = 0

async def handler(request: web.Request) -> NoReturn:
nonlocal ping_count
ws_server = web.WebSocketResponse(heartbeat=0.05)
await ws_server.prepare(request)
# We patch write here to simulate a connection reset error
# since if we closed the connection normally, the server would
# would cancel the heartbeat task and we wouldn't get a ping
with mock.patch.object(
ws_server._req.transport, "write", side_effect=ConnectionResetError
), mock.patch.object(
ws_server._writer, "ping", wraps=ws_server._writer.ping
) as ping:
try:
await ws_server.receive()
finally:
ping_count = ping.call_count
assert False

app = web.Application()
app.router.add_get("/", handler)

client = await aiohttp_client(app)
ws = await client.ws_connect("/", autoping=False)
msg = await ws.receive()
assert msg.type is aiohttp.WSMsgType.CLOSED
assert msg.extra is None
assert ws.close_code == WSCloseCode.ABNORMAL_CLOSURE
assert ping_count == 1
await ws.close()


async def test_heartbeat_failure_ends_receive(
loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
"""Test that no heartbeat response to the server ends the receive call."""
ws_server_close_code = None
ws_server_exception = None

async def handler(request: web.Request) -> NoReturn:
nonlocal ws_server_close_code, ws_server_exception
ws_server = web.WebSocketResponse(heartbeat=0.05)
await ws_server.prepare(request)
try:
await ws_server.receive()
finally:
ws_server_close_code = ws_server.close_code
ws_server_exception = ws_server.exception()
assert False

app = web.Application()
app.router.add_get("/", handler)

client = await aiohttp_client(app)
ws = await client.ws_connect("/", autoping=False)
msg = await ws.receive()
assert msg.type is aiohttp.WSMsgType.PING
msg = await ws.receive()
assert msg.type is aiohttp.WSMsgType.CLOSED
assert ws.close_code == WSCloseCode.ABNORMAL_CLOSURE
assert ws_server_close_code == WSCloseCode.ABNORMAL_CLOSURE
assert isinstance(ws_server_exception, asyncio.TimeoutError)
await ws.close()


async def test_heartbeat_no_pong_send_many_messages(
loop: Any, aiohttp_client: Any
) -> None:
Expand Down

0 comments on commit e7c02ca

Please sign in to comment.