From 96c56568d73882841b82a6a155a1328c6193a32e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Tue, 23 Aug 2022 13:46:03 +0000 Subject: [PATCH 1/3] make can_read() destructive for simplicity, and rename the method. Remove timeout argument, always timeout immediately. --- redis/asyncio/connection.py | 47 ++++++++-------------- tests/test_asyncio/test_cluster.py | 8 ++-- tests/test_asyncio/test_connection_pool.py | 2 +- tests/test_asyncio/test_pubsub.py | 2 +- 4 files changed, 22 insertions(+), 37 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 64848f4f5f..01e7b19a5a 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -208,7 +208,7 @@ def on_disconnect(self): def on_connect(self, connection: "Connection"): raise NotImplementedError() - async def can_read(self, timeout: float) -> bool: + async def can_read_destructive(self) -> bool: raise NotImplementedError() async def read_response( @@ -286,9 +286,9 @@ async def _read_from_socket( return False raise ConnectionError(f"Error while reading from socket: {ex.args}") - async def can_read(self, timeout: float) -> bool: + async def can_read_destructive(self) -> bool: return bool(self.length) or await self._read_from_socket( - timeout=timeout, raise_on_timeout=False + timeout=0, raise_on_timeout=False ) async def read(self, length: int) -> bytes: @@ -386,8 +386,8 @@ def on_disconnect(self): self._buffer = None self.encoder = None - async def can_read(self, timeout: float): - return self._buffer and bool(await self._buffer.can_read(timeout)) + async def can_read_destructive(self): + return self._buffer and bool(await self._buffer.can_read_destructive()) async def read_response( self, disable_decoding: bool = False @@ -444,9 +444,7 @@ async def read_response( class HiredisParser(BaseParser): """Parser class for connections using Hiredis""" - __slots__ = BaseParser.__slots__ + ("_next_response", "_reader", "_socket_timeout") - - _next_response: bool + __slots__ = BaseParser.__slots__ + ("_reader", "_socket_timeout") def __init__(self, socket_read_size: int): if not HIREDIS_AVAILABLE: @@ -466,23 +464,18 @@ def on_connect(self, connection: "Connection"): kwargs["errors"] = connection.encoder.encoding_errors self._reader = hiredis.Reader(**kwargs) - self._next_response = False self._socket_timeout = connection.socket_timeout def on_disconnect(self): self._stream = None self._reader = None - self._next_response = False - async def can_read(self, timeout: float): + async def can_read_destructive(self): if not self._stream or not self._reader: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - - if self._next_response is False: - self._next_response = self._reader.gets() - if self._next_response is False: - return await self.read_from_socket(timeout=timeout, raise_on_timeout=False) - return True + if self._reader.gets(): + return True + return await self.read_from_socket(timeout=0, raise_on_timeout=False) async def read_from_socket( self, @@ -523,12 +516,6 @@ async def read_response( self.on_disconnect() raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - # _next_response might be cached from a can_read() call - if self._next_response is not False: - response = self._next_response - self._next_response = False - return response - response = self._reader.gets() while response is False: await self.read_from_socket() @@ -925,12 +912,10 @@ async def send_command(self, *args: Any, **kwargs: Any) -> None: self.pack_command(*args), check_health=kwargs.get("check_health", True) ) - async def can_read(self, timeout: float = 0): + async def can_read_destructive(self): """Poll the socket to see if there's data that can be read.""" - if not self.is_connected: - await self.connect() try: - return await self._parser.can_read(timeout) + return await self._parser.can_read_destructive() except OSError as e: await self.disconnect(nowait=True) raise ConnectionError( @@ -1498,12 +1483,12 @@ async def get_connection(self, command_name, *keys, **options): # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. try: - if await connection.can_read(): + if await connection.can_read_destructive(): raise ConnectionError("Connection has data") from None except ConnectionError: await connection.disconnect() await connection.connect() - if await connection.can_read(): + if await connection.can_read_destructive(): raise ConnectionError("Connection not ready") from None except BaseException: # release the connection back to the pool so that we don't @@ -1699,12 +1684,12 @@ async def get_connection(self, command_name, *keys, **options): # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. try: - if await connection.can_read(): + if await connection.can_read_destructive(): raise ConnectionError("Connection has data") from None except ConnectionError: await connection.disconnect() await connection.connect() - if await connection.can_read(): + if await connection.can_read_destructive(): raise ConnectionError("Connection not ready") from None except BaseException: # release the connection back to the pool so that we don't leak it diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 88cfb1fcdf..f1bbe42267 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -433,7 +433,7 @@ async def test_refresh_using_specific_nodes( Connection, send_packed_command=mock.DEFAULT, connect=mock.DEFAULT, - can_read=mock.DEFAULT, + can_read_destructive=mock.DEFAULT, ) as mocks: # simulate 7006 as a failed node def execute_command_mock(self, *args, **options): @@ -473,7 +473,7 @@ def map_7007(self): execute_command.successful_calls = 0 execute_command.failed_calls = 0 initialize.side_effect = initialize_mock - mocks["can_read"].return_value = False + mocks["can_read_destructive"].return_value = False mocks["send_packed_command"].return_value = "MOCK_OK" mocks["connect"].return_value = None with mock.patch.object( @@ -514,7 +514,7 @@ async def test_reading_from_replicas_in_round_robin(self) -> None: send_command=mock.DEFAULT, read_response=mock.DEFAULT, _connect=mock.DEFAULT, - can_read=mock.DEFAULT, + can_read_destructive=mock.DEFAULT, on_connect=mock.DEFAULT, ) as mocks: with mock.patch.object( @@ -546,7 +546,7 @@ def execute_command_mock_third(self, *args, **options): mocks["send_command"].return_value = True mocks["read_response"].return_value = "OK" mocks["_connect"].return_value = True - mocks["can_read"].return_value = False + mocks["can_read_destructive"].return_value = False mocks["on_connect"].return_value = True # Create a cluster with reading from replications diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index d281608e51..35f23f44cc 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -103,7 +103,7 @@ async def connect(self): async def disconnect(self): pass - async def can_read(self, timeout: float = 0): + async def can_read_destructive(self, timeout: float = 0): return False diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 32268fe5e3..86584e4715 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -847,7 +847,7 @@ async def test_reconnect_socket_error(self, r: redis.Redis, method): self.state = 1 with mock.patch.object(self.pubsub.connection, "_parser") as m: m.read_response.side_effect = socket.error - m.can_read.side_effect = socket.error + m.can_read_destructive.side_effect = socket.error # wait until task noticies the disconnect until we # undo the patch await self.cond.wait_for(lambda: self.state >= 2) From 31215259d43bf9103d225716b9b1d86fdca94cd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Tue, 23 Aug 2022 14:02:25 +0000 Subject: [PATCH 2/3] don't use can_read in pubsub --- redis/asyncio/client.py | 12 +++++++++--- redis/asyncio/connection.py | 4 ++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 9c8caae0f9..6a01d5eb6b 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -24,6 +24,8 @@ cast, ) +import async_timeout + from redis.asyncio.connection import ( Connection, ConnectionPool, @@ -755,12 +757,16 @@ async def parse_response(self, block: bool = True, timeout: float = 0): await self.check_health() async def try_read(): + if not conn.is_connected: + await conn.connect() if not block: - if not await conn.can_read(timeout=timeout): + try: + async with async_timeout.timeout(timeout): + return await conn.read_response() + except asyncio.TimeoutError: return None else: - await conn.connect() - return await conn.read_response() + return await conn.read_response() response = await self._execute(conn, try_read) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 01e7b19a5a..53b41af7f8 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -942,6 +942,10 @@ async def read_response(self, disable_decoding: bool = False): raise ConnectionError( f"Error while reading from {self.host}:{self.port} : {e.args}" ) + except asyncio.CancelledError: + # need this check for 3.7, where CancelledError + # is subclass of Exception, not BaseException + raise except Exception: await self.disconnect(nowait=True) raise From 882881fc16f95dd2d24cd8c2bcfdf632227a74be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 19 Sep 2022 13:48:29 +0000 Subject: [PATCH 3/3] connection.connect() now has its own retry, don't need it inside a retry loop --- redis/asyncio/client.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 6a01d5eb6b..c13054b227 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -756,19 +756,21 @@ async def parse_response(self, block: bool = True, timeout: float = 0): await self.check_health() - async def try_read(): - if not conn.is_connected: - await conn.connect() - if not block: + if not conn.is_connected: + await conn.connect() + + if not block: + + async def read_with_timeout(): try: async with async_timeout.timeout(timeout): return await conn.read_response() except asyncio.TimeoutError: return None - else: - return await conn.read_response() - response = await self._execute(conn, try_read) + response = await self._execute(conn, read_with_timeout) + else: + response = await self._execute(conn, conn.read_response) if conn.health_check_interval and response == self.health_check_response: # ignore the health check message as user might not expect it