From 5113cd3afe7cbf5f740d80db8449670c27c90101 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 7 Mar 2023 11:56:19 -1000 Subject: [PATCH] Use asyncio.timeout instead of asyncio.wait_for. asyncio.wait_for creates a task whereas asyncio.timeout doesn't. Fallback to a vendored version of async_timeout on Python < 3.11. async.timeout will become the underlying implementation for async.wait_for in Python 3.12: https://github.com/python/cpython/pull/98518 --- pyproject.toml | 1 + src/websockets/legacy/async_timeout.py | 225 +++++++++++++++++++++++++ src/websockets/legacy/compatibility.py | 9 + src/websockets/legacy/protocol.py | 34 ++-- 4 files changed, 246 insertions(+), 23 deletions(-) create mode 100644 src/websockets/legacy/async_timeout.py diff --git a/pyproject.toml b/pyproject.toml index 989b6b5e..0707c644 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ branch = true omit = [ # */websockets matches src/websockets and .tox/**/site-packages/websockets "*/websockets/__main__.py", + "*/websockets/legacy/async_timeout.py", "*/websockets/legacy/compatibility.py", "tests/maxi_cov.py", ] diff --git a/src/websockets/legacy/async_timeout.py b/src/websockets/legacy/async_timeout.py new file mode 100644 index 00000000..0a220892 --- /dev/null +++ b/src/websockets/legacy/async_timeout.py @@ -0,0 +1,225 @@ +# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py +# Licensed under the Apache License, Version 2.0. + +import asyncio +import enum +import sys +import warnings +from types import TracebackType +from typing import Optional, Type + + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final + + +__version__ = "4.0.2" + + +__all__ = ("timeout", "timeout_at", "Timeout") + + +def timeout(delay: Optional[float]) -> "Timeout": + """timeout context manager. + + Useful in cases when you want to apply timeout logic around block + of code or in cases when asyncio.wait_for is not suitable. For example: + + >>> async with timeout(0.001): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + + + delay - value in seconds or None to disable timeout logic + """ + loop = asyncio.get_running_loop() + if delay is not None: + deadline = loop.time() + delay # type: Optional[float] + else: + deadline = None + return Timeout(deadline, loop) + + +def timeout_at(deadline: Optional[float]) -> "Timeout": + """Schedule the timeout at absolute time. + + deadline argument points on the time in the same clock system + as loop.time(). + + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + + >>> async with timeout_at(loop.time() + 10): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + + + """ + loop = asyncio.get_running_loop() + return Timeout(deadline, loop) + + +class _State(enum.Enum): + INIT = "INIT" + ENTER = "ENTER" + TIMEOUT = "TIMEOUT" + EXIT = "EXIT" + + +@final +class Timeout: + # Internal class, please don't instantiate it directly + # Use timeout() and timeout_at() public factories instead. + # + # Implementation note: `async with timeout()` is preferred + # over `with timeout()`. + # While technically the Timeout class implementation + # doesn't need to be async at all, + # the `async with` statement explicitly points that + # the context manager should be used from async function context. + # + # This design allows to avoid many silly misusages. + # + # TimeoutError is raised immediately when scheduled + # if the deadline is passed. + # The purpose is to time out as soon as possible + # without waiting for the next await expression. + + __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler") + + def __init__( + self, deadline: Optional[float], loop: asyncio.AbstractEventLoop + ) -> None: + self._loop = loop + self._state = _State.INIT + + self._timeout_handler = None # type: Optional[asyncio.Handle] + if deadline is None: + self._deadline = None # type: Optional[float] + else: + self.update(deadline) + + def __enter__(self) -> "Timeout": + warnings.warn( + "with timeout() is deprecated, use async with timeout() instead", + DeprecationWarning, + stacklevel=2, + ) + self._do_enter() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + async def __aenter__(self) -> "Timeout": + self._do_enter() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + @property + def expired(self) -> bool: + """Is timeout expired during execution?""" + return self._state == _State.TIMEOUT + + @property + def deadline(self) -> Optional[float]: + return self._deadline + + def reject(self) -> None: + """Reject scheduled timeout if any.""" + # cancel is maybe better name but + # task.cancel() raises CancelledError in asyncio world. + if self._state not in (_State.INIT, _State.ENTER): + raise RuntimeError(f"invalid state {self._state.value}") + self._reject() + + def _reject(self) -> None: + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._timeout_handler = None + + def shift(self, delay: float) -> None: + """Advance timeout on delay seconds. + + The delay can be negative. + + Raise RuntimeError if shift is called when deadline is not scheduled + """ + deadline = self._deadline + if deadline is None: + raise RuntimeError("cannot shift timeout if deadline is not scheduled") + self.update(deadline + delay) + + def update(self, deadline: float) -> None: + """Set deadline to absolute value. + + deadline argument points on the time in the same clock system + as loop.time(). + + If new deadline is in the past the timeout is raised immediately. + + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + """ + if self._state == _State.EXIT: + raise RuntimeError("cannot reschedule after exit from context manager") + if self._state == _State.TIMEOUT: + raise RuntimeError("cannot reschedule expired timeout") + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._deadline = deadline + if self._state != _State.INIT: + self._reschedule() + + def _reschedule(self) -> None: + assert self._state == _State.ENTER + deadline = self._deadline + if deadline is None: + return + + now = self._loop.time() + if self._timeout_handler is not None: + self._timeout_handler.cancel() + + task = asyncio.current_task() + if deadline <= now: + self._timeout_handler = self._loop.call_soon(self._on_timeout, task) + else: + self._timeout_handler = self._loop.call_at(deadline, self._on_timeout, task) + + def _do_enter(self) -> None: + if self._state != _State.INIT: + raise RuntimeError(f"invalid state {self._state.value}") + self._state = _State.ENTER + self._reschedule() + + def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: + if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT: + self._timeout_handler = None + raise asyncio.TimeoutError + # timeout has not expired + self._state = _State.EXIT + self._reject() + return None + + def _on_timeout(self, task: "asyncio.Task[None]") -> None: + task.cancel() + self._state = _State.TIMEOUT + # drop the reference early + self._timeout_handler = None diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/legacy/compatibility.py index 303e203b..cb9b02c8 100644 --- a/src/websockets/legacy/compatibility.py +++ b/src/websockets/legacy/compatibility.py @@ -5,6 +5,9 @@ from typing import Any, Dict +__all__ = ["asyncio_timeout", "loop_if_py_lt_38"] + + if sys.version_info[:2] >= (3, 8): def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: @@ -22,3 +25,9 @@ def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: """ return {"loop": loop} + + +if sys.version_info[:2] >= (3, 11): + from asyncio import timeout as asyncio_timeout # noqa: F401 +else: + from .async_timeout import timeout as asyncio_timeout # noqa: F401 diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 7f9ab2bd..78b59ee8 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -53,7 +53,7 @@ ) from ..protocol import State from ..typing import Data, LoggerLike, Subprotocol -from .compatibility import loop_if_py_lt_38 +from .compatibility import asyncio_timeout, loop_if_py_lt_38 from .framing import Frame @@ -761,19 +761,16 @@ async def close(self, code: int = 1000, reason: str = "") -> None: """ try: - await asyncio.wait_for( - self.write_close_frame(Close(code, reason)), - self.close_timeout, - **loop_if_py_lt_38(self.loop), - ) + async with asyncio_timeout(self.close_timeout): + await self.write_close_frame(Close(code, reason)) except asyncio.TimeoutError: # If the close frame cannot be sent because the send buffers # are full, the closing handshake won't complete anyway. # Fail the connection to shut down faster. self.fail_connection() - # If no close frame is received within the timeout, wait_for() cancels - # the data transfer task and raises TimeoutError. + # If no close frame is received within the timeout, asyncio_timeout() + # cancels the data transfer task and raises TimeoutError. # If close() is called multiple times concurrently and one of these # calls hits the timeout, the data transfer task will be canceled. @@ -782,11 +779,8 @@ async def close(self, code: int = 1000, reason: str = "") -> None: try: # If close() is canceled during the wait, self.transfer_data_task # is canceled before the timeout elapses. - await asyncio.wait_for( - self.transfer_data_task, - self.close_timeout, - **loop_if_py_lt_38(self.loop), - ) + async with asyncio_timeout(self.close_timeout): + await self.transfer_data_task except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -1268,11 +1262,8 @@ async def keepalive_ping(self) -> None: if self.ping_timeout is not None: try: - await asyncio.wait_for( - pong_waiter, - self.ping_timeout, - **loop_if_py_lt_38(self.loop), - ) + async with asyncio_timeout(self.ping_timeout): + await pong_waiter self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: @@ -1384,11 +1375,8 @@ async def wait_for_connection_lost(self) -> bool: """ if not self.connection_lost_waiter.done(): try: - await asyncio.wait_for( - asyncio.shield(self.connection_lost_waiter), - self.close_timeout, - **loop_if_py_lt_38(self.loop), - ) + async with asyncio_timeout(self.close_timeout): + await asyncio.shield(self.connection_lost_waiter) except asyncio.TimeoutError: pass # Re-check self.connection_lost_waiter.done() synchronously because