From 9a70e5bfb3b94ea04c2b94559b56f13e09501cf6 Mon Sep 17 00:00:00 2001 From: kaiix Date: Thu, 27 Jun 2024 18:24:46 +0800 Subject: [PATCH] fix: websocket provider receives nothing forever when the peer is closed or gone --- .../module_testing/module_testing_utils.py | 21 +++++++++++-------- web3/providers/websocket/websocket_v2.py | 5 ++--- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/web3/_utils/module_testing/module_testing_utils.py b/web3/_utils/module_testing/module_testing_utils.py index 5147185a07..0cac3c6fec 100644 --- a/web3/_utils/module_testing/module_testing_utils.py +++ b/web3/_utils/module_testing/module_testing_utils.py @@ -1,6 +1,3 @@ -from collections import ( - deque, -) import pytest import time from typing import ( @@ -9,6 +6,7 @@ Collection, Dict, Generator, + Optional, Sequence, Union, ) @@ -35,6 +33,7 @@ ) from web3._utils.request import ( async_cache_and_return_session, + asyncio, cache_and_return_session, ) from web3.types import ( @@ -188,9 +187,13 @@ class WebsocketMessageStreamMock: closed: bool = False def __init__( - self, messages: Collection[bytes] = None, raise_exception: Exception = None + self, + messages: Optional[Collection[bytes]] = None, + raise_exception: Optional[Exception] = None, ) -> None: - self.messages = deque(messages) if messages else deque() + self.queue = asyncio.Queue[bytes]() + for msg in messages or []: + self.queue.put_nowait(msg) self.raise_exception = raise_exception def __await__(self) -> Generator[Any, Any, "Self"]: @@ -203,13 +206,13 @@ def __aiter__(self) -> "Self": return self async def __anext__(self) -> bytes: + return await self.recv() + + async def recv(self) -> bytes: if self.raise_exception: raise self.raise_exception - elif len(self.messages) == 0: - raise StopAsyncIteration - - return self.messages.popleft() + return await self.queue.get() @staticmethod async def pong() -> Literal[False]: diff --git a/web3/providers/websocket/websocket_v2.py b/web3/providers/websocket/websocket_v2.py index b7d42ba595..d2dd28b0ca 100644 --- a/web3/providers/websocket/websocket_v2.py +++ b/web3/providers/websocket/websocket_v2.py @@ -136,9 +136,8 @@ async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: return response async def _provider_specific_message_listener(self) -> None: - async for raw_message in self._ws: - await asyncio.sleep(0) - + while True: + raw_message = await self._ws.recv() response = json.loads(raw_message) subscription = response.get("method") == "eth_subscription" await self._request_processor.cache_raw_response(