Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PR #8641/0a88bab backport][3.10] Fix WebSocket ping tasks being prematurely garbage collected #8646

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES/8641.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fixed WebSocket ping tasks being prematurely garbage collected -- by :user:`bdraco`.

There was a small risk that WebSocket ping tasks would be prematurely garbage collected because the event loop only holds a weak reference to the task. The garbage collection risk has been fixed by holding a strong reference to the task. Additionally, the task is now scheduled eagerly with Python 3.12+ to increase the chance it can be completed immediately and avoid having to hold any references to the task.
25 changes: 20 additions & 5 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(
self._exception: Optional[BaseException] = None
self._compress = compress
self._client_notakeover = client_notakeover
self._ping_task: Optional[asyncio.Task[None]] = None

self._reset_heartbeat()

Expand All @@ -80,6 +81,9 @@ def _cancel_heartbeat(self) -> None:
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
if self._ping_task is not None:
self._ping_task.cancel()
self._ping_task = None

def _cancel_pong_response_cb(self) -> None:
if self._pong_response_cb is not None:
Expand Down Expand Up @@ -118,11 +122,6 @@ def _send_heartbeat(self) -> None:
)
return

# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable]

conn = self._conn
timeout_ceil_threshold = (
conn._connector._timeout_ceil_threshold if conn is not None else 5
Expand All @@ -131,6 +130,22 @@ def _send_heartbeat(self) -> None:
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)

if sys.version_info >= (3, 12):
# Optimization for Python 3.12, try to send the ping
# immediately to avoid having to schedule
# the task on the event loop.
ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True)
else:
ping_task = loop.create_task(self._writer.ping())

if not ping_task.done():
self._ping_task = ping_task
ping_task.add_done_callback(self._ping_task_done)

def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
"""Callback for when the ping task completes."""
self._ping_task = None

def _pong_not_received(self) -> None:
if not self._closed:
self._set_closed()
Expand Down
25 changes: 20 additions & 5 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,16 @@ def __init__(
self._pong_response_cb: Optional[asyncio.TimerHandle] = None
self._compress = compress
self._max_msg_size = max_msg_size
self._ping_task: Optional[asyncio.Task[None]] = None

def _cancel_heartbeat(self) -> None:
self._cancel_pong_response_cb()
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
if self._ping_task is not None:
self._ping_task.cancel()
self._ping_task = None

def _cancel_pong_response_cb(self) -> None:
if self._pong_response_cb is not None:
Expand Down Expand Up @@ -141,11 +145,6 @@ def _send_heartbeat(self) -> None:
)
return

# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable]

req = self._req
timeout_ceil_threshold = (
req._protocol._timeout_ceil_threshold if req is not None else 5
Expand All @@ -154,6 +153,22 @@ def _send_heartbeat(self) -> None:
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)

if sys.version_info >= (3, 12):
# Optimization for Python 3.12, try to send the ping
# immediately to avoid having to schedule
# the task on the event loop.
ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True)
else:
ping_task = loop.create_task(self._writer.ping())

if not ping_task.done():
self._ping_task = ping_task
ping_task.add_done_callback(self._ping_task_done)

def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
"""Callback for when the ping task completes."""
self._ping_task = None

def _pong_not_received(self) -> None:
if self._req is not None and self._req.transport is not None:
self._set_closed()
Expand Down
50 changes: 48 additions & 2 deletions tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import sys
from typing import Any, NoReturn
from unittest import mock

import pytest

Expand Down Expand Up @@ -727,8 +728,53 @@ async def handler(request):
assert isinstance(msg.data, ServerTimeoutError)


async def test_send_recv_compress(aiohttp_client: Any) -> None:
async def handler(request):
async def test_close_websocket_while_ping_inflight(
aiohttp_client: AiohttpClient,
) -> None:
"""Test closing the websocket while a ping is in-flight."""
ping_received = False

async def handler(request: web.Request) -> NoReturn:
nonlocal ping_received
ws = web.WebSocketResponse(autoping=False)
await ws.prepare(request)
msg = await ws.receive()
assert msg.type is aiohttp.WSMsgType.BINARY
msg = await ws.receive()
ping_received = msg.type is aiohttp.WSMsgType.PING
await ws.receive()
assert False

app = web.Application()
app.router.add_route("GET", "/", handler)

client = await aiohttp_client(app)
resp = await client.ws_connect("/", heartbeat=0.1)
await resp.send_bytes(b"ask")

cancelled = False
ping_stated = False

async def delayed_ping() -> None:
nonlocal cancelled, ping_stated
ping_stated = True
try:
await asyncio.sleep(1)
except asyncio.CancelledError:
cancelled = True
raise

with mock.patch.object(resp._writer, "ping", delayed_ping):
await asyncio.sleep(0.1)

await resp.close()
await asyncio.sleep(0)
assert ping_stated is True
assert cancelled is True


async def test_send_recv_compress(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)

Expand Down
Loading