Skip to content

Commit

Permalink
Add keepalive to the threading implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
aaugustin committed Jan 22, 2025
1 parent fc7b151 commit 8f12d8f
Show file tree
Hide file tree
Showing 11 changed files with 267 additions and 31 deletions.
3 changes: 2 additions & 1 deletion docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ notice.
New features
............

* Added latency measurement to the :mod:`threading` implementation.
* Added :doc:`keepalive and latency measurement <../topics/keepalive>` to the
:mod:`threading` implementation.

.. _14.2:

Expand Down
4 changes: 2 additions & 2 deletions docs/reference/features.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ Both sides
+------------------------------------+--------+--------+--------+--------+
| Send a pong |||||
+------------------------------------+--------+--------+--------+--------+
| Keepalive || |||
| Keepalive || |||
+------------------------------------+--------+--------+--------+--------+
| Heartbeat || |||
| Heartbeat || |||
+------------------------------------+--------+--------+--------+--------+
| Measure latency |||||
+------------------------------------+--------+--------+--------+--------+
Expand Down
5 changes: 0 additions & 5 deletions docs/topics/keepalive.rst
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
Keepalive and latency
=====================

.. admonition:: This guide applies only to the :mod:`asyncio` implementation.
:class: tip

The :mod:`threading` implementation doesn't provide keepalive yet.

.. currentmodule:: websockets

Long-lived connections
Expand Down
17 changes: 11 additions & 6 deletions src/websockets/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,8 +686,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]:
pong_waiter = self.loop.create_future()
# The event loop's default clock is time.monotonic(). Its resolution
# is a bit low on Windows (~16ms). This is improved in Python 3.13.
ping_timestamp = self.loop.time()
self.pong_waiters[data] = (pong_waiter, ping_timestamp)
self.pong_waiters[data] = (pong_waiter, self.loop.time())
self.protocol.send_ping(data)
return pong_waiter

Expand Down Expand Up @@ -792,13 +791,19 @@ async def keepalive(self) -> None:
latency = 0.0
try:
while True:
# If self.ping_timeout > latency > self.ping_interval, pings
# will be sent immediately after receiving pongs. The period
# will be longer than self.ping_interval.
# If self.ping_timeout > latency > self.ping_interval,
# pings will be sent immediately after receiving pongs.
# The period will be longer than self.ping_interval.
await asyncio.sleep(self.ping_interval - latency)

self.logger.debug("% sending keepalive ping")
# This cannot raise ConnectionClosed when the connection is
# closing because ping(), via send_context(), waits for the
# connection to be closed before raising ConnectionClosed.
# However, connection_lost() cancels keepalive_task before
# it gets a chance to resume excuting.
pong_waiter = await self.ping()
if self.debug:
self.logger.debug("% sent keepalive ping")

if self.ping_timeout is not None:
try:
Expand Down
17 changes: 15 additions & 2 deletions src/websockets/sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class ClientConnection(Connection):
:exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
closed with any other code.
The ``close_timeout`` and ``max_queue`` arguments have the same meaning as
in :func:`connect`.
The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and
``max_queue`` arguments have the same meaning as in :func:`connect`.
Args:
socket: Socket connected to a WebSocket server.
Expand All @@ -54,6 +54,8 @@ def __init__(
socket: socket.socket,
protocol: ClientProtocol,
*,
ping_interval: float | None = 20,
ping_timeout: float | None = 20,
close_timeout: float | None = 10,
max_queue: int | None | tuple[int | None, int | None] = 16,
) -> None:
Expand All @@ -62,6 +64,8 @@ def __init__(
super().__init__(
socket,
protocol,
ping_interval=ping_interval,
ping_timeout=ping_timeout,
close_timeout=close_timeout,
max_queue=max_queue,
)
Expand Down Expand Up @@ -136,6 +140,8 @@ def connect(
compression: str | None = "deflate",
# Timeouts
open_timeout: float | None = 10,
ping_interval: float | None = 20,
ping_timeout: float | None = 20,
close_timeout: float | None = 10,
# Limits
max_size: int | None = 2**20,
Expand Down Expand Up @@ -184,6 +190,10 @@ def connect(
:doc:`compression guide <../../topics/compression>` for details.
open_timeout: Timeout for opening the connection in seconds.
:obj:`None` disables the timeout.
ping_interval: Interval between keepalive pings in seconds.
:obj:`None` disables keepalive.
ping_timeout: Timeout for keepalive pings in seconds.
:obj:`None` disables timeouts.
close_timeout: Timeout for closing the connection in seconds.
:obj:`None` disables the timeout.
max_size: Maximum size of incoming messages in bytes.
Expand Down Expand Up @@ -296,6 +306,8 @@ def connect(
connection = create_connection(
sock,
protocol,
ping_interval=ping_interval,
ping_timeout=ping_timeout,
close_timeout=close_timeout,
max_queue=max_queue,
)
Expand All @@ -315,6 +327,7 @@ def connect(
connection.recv_events_thread.join()
raise

connection.start_keepalive()
return connection


Expand Down
63 changes: 63 additions & 0 deletions src/websockets/sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,15 @@ def __init__(
socket: socket.socket,
protocol: Protocol,
*,
ping_interval: float | None = 20,
ping_timeout: float | None = 20,
close_timeout: float | None = 10,
max_queue: int | None | tuple[int | None, int | None] = 16,
) -> None:
self.socket = socket
self.protocol = protocol
self.ping_interval = ping_interval
self.ping_timeout = ping_timeout
self.close_timeout = close_timeout
if isinstance(max_queue, int) or max_queue is None:
max_queue = (max_queue, None)
Expand Down Expand Up @@ -120,8 +124,15 @@ def __init__(
Latency is defined as the round-trip time of the connection. It is
measured by sending a Ping frame and waiting for a matching Pong frame.
Before the first measurement, :attr:`latency` is ``0``.
By default, websockets enables a :ref:`keepalive <keepalive>` mechanism
that sends Ping frames automatically at regular intervals. You can also
send Ping frames and measure latency with :meth:`ping`.
"""

# Thread that sends keepalive pings. None when ping_interval is None.
self.keepalive_thread: threading.Thread | None = None

# Public attributes

@property
Expand Down Expand Up @@ -700,6 +711,58 @@ def acknowledge_pending_pings(self) -> None:

self.pong_waiters.clear()

def keepalive(self) -> None:
"""
Send a Ping frame and wait for a Pong frame at regular intervals.
"""
assert self.ping_interval is not None
try:
while True:
# If self.ping_timeout > self.latency > self.ping_interval,
# pings will be sent immediately after receiving pongs.
# The period will be longer than self.ping_interval.
self.recv_events_thread.join(self.ping_interval - self.latency)
if not self.recv_events_thread.is_alive():
break

try:
pong_waiter = self.ping(ack_on_close=True)
except ConnectionClosed:
break
if self.debug:
self.logger.debug("% sent keepalive ping")

if self.ping_timeout is not None:
#
if pong_waiter.wait(self.ping_timeout):
if self.debug:
self.logger.debug("% received keepalive pong")
else:
if self.debug:
self.logger.debug("- timed out waiting for keepalive pong")
with self.send_context():
self.protocol.fail(
CloseCode.INTERNAL_ERROR,
"keepalive ping timeout",
)
break
except Exception:
self.logger.error("keepalive ping failed", exc_info=True)

def start_keepalive(self) -> None:
"""
Run :meth:`keepalive` in a thread, unless keepalive is disabled.
"""
if self.ping_interval is not None:
# This thread is marked as daemon like self.recv_events_thread.
self.keepalive_thread = threading.Thread(
target=self.keepalive,
daemon=True,
)
self.keepalive_thread.start()

def recv_events(self) -> None:
"""
Read incoming data from the socket and process events.
Expand Down
17 changes: 15 additions & 2 deletions src/websockets/sync/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class ServerConnection(Connection):
:exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
closed with any other code.
The ``close_timeout`` and ``max_queue`` arguments have the same meaning as
in :func:`serve`.
The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and
``max_queue`` arguments have the same meaning as in :func:`serve`.
Args:
socket: Socket connected to a WebSocket client.
Expand All @@ -66,6 +66,8 @@ def __init__(
socket: socket.socket,
protocol: ServerProtocol,
*,
ping_interval: float | None = 20,
ping_timeout: float | None = 20,
close_timeout: float | None = 10,
max_queue: int | None | tuple[int | None, int | None] = 16,
) -> None:
Expand All @@ -74,6 +76,8 @@ def __init__(
super().__init__(
socket,
protocol,
ping_interval=ping_interval,
ping_timeout=ping_timeout,
close_timeout=close_timeout,
max_queue=max_queue,
)
Expand Down Expand Up @@ -354,6 +358,8 @@ def serve(
compression: str | None = "deflate",
# Timeouts
open_timeout: float | None = 10,
ping_interval: float | None = 20,
ping_timeout: float | None = 20,
close_timeout: float | None = 10,
# Limits
max_size: int | None = 2**20,
Expand Down Expand Up @@ -434,6 +440,10 @@ def handler(websocket):
:doc:`compression guide <../../topics/compression>` for details.
open_timeout: Timeout for opening connections in seconds.
:obj:`None` disables the timeout.
ping_interval: Interval between keepalive pings in seconds.
:obj:`None` disables keepalive.
ping_timeout: Timeout for keepalive pings in seconds.
:obj:`None` disables timeouts.
close_timeout: Timeout for closing connections in seconds.
:obj:`None` disables the timeout.
max_size: Maximum size of incoming messages in bytes.
Expand Down Expand Up @@ -563,6 +573,8 @@ def protocol_select_subprotocol(
connection = create_connection(
sock,
protocol,
ping_interval=ping_interval,
ping_timeout=ping_timeout,
close_timeout=close_timeout,
max_queue=max_queue,
)
Expand Down Expand Up @@ -590,6 +602,7 @@ def protocol_select_subprotocol(

assert connection.protocol.state is OPEN
try:
connection.start_keepalive()
handler(connection)
except Exception:
connection.logger.error("connection handler failed", exc_info=True)
Expand Down
26 changes: 13 additions & 13 deletions tests/asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ async def test_keepalive_times_out(self, getrandbits):
self.connection.start_keepalive()
# 4 ms: keepalive() sends a ping frame.
await asyncio.sleep(4 * MS)
# Exiting the context manager sleeps for MS.
# Exiting the context manager sleeps for 1 ms.
# 4.x ms: a pong frame is dropped.
# 6 ms: no pong frame is received; the connection is closed.
await asyncio.sleep(2 * MS)
Expand All @@ -1026,33 +1026,33 @@ async def test_keepalive_ignores_timeout(self, getrandbits):
getrandbits.return_value = 1918987876
self.connection.start_keepalive()
# 4 ms: keepalive() sends a ping frame.
await asyncio.sleep(4 * MS)
# Exiting the context manager sleeps for MS.
# 4.x ms: a pong frame is dropped.
await asyncio.sleep(4 * MS)
# Exiting the context manager sleeps for 1 ms.
# 6 ms: no pong frame is received; the connection remains open.
await asyncio.sleep(2 * MS)
# 7 ms: check that the connection is still open.
self.assertEqual(self.connection.state, State.OPEN)

async def test_keepalive_terminates_while_sleeping(self):
"""keepalive task terminates while waiting to send a ping."""
self.connection.ping_interval = 2 * MS
self.connection.ping_interval = 3 * MS
self.connection.start_keepalive()
await asyncio.sleep(MS)
await self.connection.close()
self.assertTrue(self.connection.keepalive_task.done())

async def test_keepalive_terminates_while_waiting_for_pong(self):
"""keepalive task terminates while waiting to receive a pong."""
self.connection.ping_interval = 2 * MS
self.connection.ping_timeout = 2 * MS
self.connection.ping_interval = MS
self.connection.ping_timeout = 3 * MS
async with self.drop_frames_rcvd():
self.connection.start_keepalive()
# 2 ms: keepalive() sends a ping frame.
await asyncio.sleep(2 * MS)
# Exiting the context manager sleeps for MS.
# 2.x ms: a pong frame is dropped.
# 3 ms: close the connection before ping_timeout elapses.
# 1 ms: keepalive() sends a ping frame.
# 1.x ms: a pong frame is dropped.
await asyncio.sleep(MS)
# Exiting the context manager sleeps for 1 ms.
# 2 ms: close the connection before ping_timeout elapses.
await self.connection.close()
self.assertTrue(self.connection.keepalive_task.done())

Expand All @@ -1062,9 +1062,9 @@ async def test_keepalive_reports_errors(self):
async with self.drop_frames_rcvd():
self.connection.start_keepalive()
# 2 ms: keepalive() sends a ping frame.
await asyncio.sleep(2 * MS)
# Exiting the context manager sleeps for MS.
# 2.x ms: a pong frame is dropped.
await asyncio.sleep(2 * MS)
# Exiting the context manager sleeps for 1 ms.
# 3 ms: inject a fault: raise an exception in the pending pong waiter.
pong_waiter = next(iter(self.connection.pong_waiters.values()))[0]
with self.assertLogs("websockets", logging.ERROR) as logs:
Expand Down
15 changes: 15 additions & 0 deletions tests/sync/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,21 @@ def test_disable_compression(self):
with connect(get_uri(server), compression=None) as client:
self.assertEqual(client.protocol.extensions, [])

def test_keepalive_is_enabled(self):
"""Client enables keepalive and measures latency by default."""
with run_server() as server:
with connect(get_uri(server), ping_interval=MS) as client:
self.assertEqual(client.latency, 0)
time.sleep(2 * MS)
self.assertGreater(client.latency, 0)

def test_disable_keepalive(self):
"""Client disables keepalive."""
with run_server() as server:
with connect(get_uri(server), ping_interval=None) as client:
time.sleep(2 * MS)
self.assertEqual(client.latency, 0)

def test_logger(self):
"""Client accepts a logger argument."""
logger = logging.getLogger("test")
Expand Down
Loading

0 comments on commit 8f12d8f

Please sign in to comment.