diff --git a/CHANGES b/CHANGES index b0744c6038..107fa5e5b5 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Revert #2104, add `disconnect_on_error` option to `read_response()` (#2506) * Allow data to drain from async PythonParser when reading during a disconnect() * Use asyncio.timeout() instead of async_timeout.timeout() for python >= 3.11 (#2602) * Add test and fix async HiredisParser when reading during a disconnect() (#2349) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 3e6626aedf..b7aca5793d 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -512,8 +512,6 @@ async def _try_send_command_parse_response(self, conn, *args, **options): await conn.disconnect(nowait=True) raise finally: - if self.single_connection_client: - self._single_conn_lock.release() if not self.connection: await self.connection_pool.release(conn) @@ -525,12 +523,20 @@ async def execute_command(self, *args, **options): command_name = args[0] conn = self.connection or await pool.get_connection(command_name, **options) + # locking / unlocking must be handled in same task, otherwise we will deadlock + # if parent task is cancelled. + # Parent task is preferable because it always gets cancelled, child task won´t + # get cancelled if parent is cancelled and the lock may be held indefinitely. if self.single_connection_client: await self._single_conn_lock.acquire() - return await asyncio.shield( - self._try_send_command_parse_response(conn, *args, **options) - ) + try: + return await asyncio.shield( + self._try_send_command_parse_response(conn, *args, **options) + ) + finally: + if self.single_connection_client: + self._single_conn_lock.release() async def parse_response( self, connection: Connection, command_name: Union[str, bytes], **options @@ -816,7 +822,9 @@ async def parse_response(self, block: bool = True, timeout: float = 0): await conn.connect() read_timeout = None if block else timeout - response = await self._execute(conn, conn.read_response, timeout=read_timeout) + response = await self._execute( + conn, conn.read_response, timeout=read_timeout, disconnect_on_error=False + ) if conn.health_check_interval and response == self.health_check_response: # ignore the health check message as user might not expect it diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 58dcd66efb..f329d87dd3 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -803,7 +803,11 @@ async def send_packed_command( raise ConnectionError( f"Error {err_no} while writing to socket. {errmsg}." ) from e - except Exception: + except BaseException: + # BaseExceptions can be raised when a socket send operation is not + # finished, e.g. due to a timeout. Ideally, a caller could then re-try + # to send un-sent data. However, the send_packed_command() API + # does not support it so there is no point in keeping the connection open. await self.disconnect(nowait=True) raise @@ -826,7 +830,9 @@ async def can_read_destructive(self): async def read_response( self, disable_decoding: bool = False, + *, timeout: Optional[float] = None, + disconnect_on_error: bool = True, ): """Read the response from a previously sent command""" read_timeout = timeout if timeout is not None else self.socket_timeout @@ -842,22 +848,24 @@ async def read_response( ) except asyncio.TimeoutError: if timeout is not None: - # user requested timeout, return None + # user requested timeout, return None. Operation can be retried return None # it was a self.socket_timeout error. - await self.disconnect(nowait=True) + if disconnect_on_error: + await self.disconnect(nowait=True) raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") except OSError as e: - await self.disconnect(nowait=True) + if disconnect_on_error: + await self.disconnect(nowait=True) raise ConnectionError( f"Error while reading from {self.host}:{self.port} : {e.args}" ) - except asyncio.CancelledError: - # need this check for 3.7, where CancelledError - # is subclass of Exception, not BaseException - raise - except Exception: - await self.disconnect(nowait=True) + except BaseException: + # Also by default close in case of BaseException. A lot of code + # relies on this behaviour when doing Command/Response pairs. + # See #1128. + if disconnect_on_error: + await self.disconnect(nowait=True) raise if self.health_check_interval: diff --git a/redis/client.py b/redis/client.py index 1a9b96b83d..711a7d3085 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1526,7 +1526,7 @@ def try_read(): return None else: conn.connect() - return conn.read_response() + return conn.read_response(disconnect_on_error=False) response = self._execute(conn, try_read) diff --git a/redis/connection.py b/redis/connection.py index 162a4c3215..5097a3166d 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -831,7 +831,11 @@ def send_packed_command(self, command, check_health=True): errno = e.args[0] errmsg = e.args[1] raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.") - except Exception: + except BaseException: + # BaseExceptions can be raised when a socket send operation is not + # finished, e.g. due to a timeout. Ideally, a caller could then re-try + # to send un-sent data. However, the send_packed_command() API + # does not support it so there is no point in keeping the connection open. self.disconnect() raise @@ -856,7 +860,9 @@ def can_read(self, timeout=0): self.disconnect() raise ConnectionError(f"Error while reading from {host_error}: {e.args}") - def read_response(self, disable_decoding=False): + def read_response( + self, disable_decoding=False, *, disconnect_on_error: bool = True + ): """Read the response from a previously sent command""" host_error = self._host_error() @@ -864,15 +870,21 @@ def read_response(self, disable_decoding=False): try: response = self._parser.read_response(disable_decoding=disable_decoding) except socket.timeout: - self.disconnect() + if disconnect_on_error: + self.disconnect() raise TimeoutError(f"Timeout reading from {host_error}") except OSError as e: - self.disconnect() + if disconnect_on_error: + self.disconnect() raise ConnectionError( f"Error while reading from {host_error}" f" : {e.args}" ) - except Exception: - self.disconnect() + except BaseException: + # Also by default close in case of BaseException. A lot of code + # relies on this behaviour when doing Command/Response pairs. + # See #1128. + if disconnect_on_error: + self.disconnect() raise if self.health_check_interval: diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 7c6fd45ab9..82c051e373 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -1,9 +1,11 @@ """ Tests async overrides of commands from their mixins """ +import asyncio import binascii import datetime import re +import sys from string import ascii_letters import pytest @@ -18,6 +20,11 @@ skip_unless_arch_bits, ) +if sys.version_info.major >= 3 and sys.version_info.minor >= 11: + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + REDIS_6_VERSION = "5.9.0" @@ -2999,6 +3006,37 @@ async def test_module_list(self, r: redis.Redis): for x in await r.module_list(): assert isinstance(x, dict) + @pytest.mark.onlynoncluster + async def test_interrupted_command(self, r: redis.Redis): + """ + Regression test for issue #1128: An Un-handled BaseException + will leave the socket with un-read response to a previous + command. + """ + ready = asyncio.Event() + + async def helper(): + with pytest.raises(asyncio.CancelledError): + # blocking pop + ready.set() + await r.brpop(["nonexist"]) + # If the following is not done, further Timout operations will fail, + # because the timeout won't catch its Cancelled Error if the task + # has a pending cancel. Python documentation probably should reflect this. + if sys.version_info.major >= 3 and sys.version_info.minor >= 11: + asyncio.current_task().uncancel() + # if all is well, we can continue. The following should not hang. + await r.set("status", "down") + + task = asyncio.create_task(helper()) + await ready.wait() + await asyncio.sleep(0.01) + # the task is now sleeping, lets send it an exception + task.cancel() + # If all is well, the task should finish right away, otherwise fail with Timeout + async with async_timeout(0.1): + await task + @pytest.mark.onlynoncluster class TestBinarySave: diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index e2d77fc1c3..e49dd42204 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -184,7 +184,7 @@ async def test_connection_parse_response_resume(r: redis.Redis): conn._parser._stream = MockStream(message, interrupt_every=2) for i in range(100): try: - response = await conn.read_response() + response = await conn.read_response(disconnect_on_error=False) break except MockStream.TestError: pass diff --git a/tests/test_commands.py b/tests/test_commands.py index 94249e9419..b08a554917 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1,9 +1,12 @@ import binascii import datetime import re +import threading import time +from asyncio import CancelledError from string import ascii_letters from unittest import mock +from unittest.mock import patch import pytest @@ -4726,6 +4729,38 @@ def test_psync(self, r): res = r2.psync(r2.client_id(), 1) assert b"FULLRESYNC" in res + @pytest.mark.onlynoncluster + def test_interrupted_command(self, r: redis.Redis): + """ + Regression test for issue #1128: An Un-handled BaseException + will leave the socket with un-read response to a previous + command. + """ + + ok = False + + def helper(): + with pytest.raises(CancelledError): + # blocking pop + with patch.object( + r.connection._parser, "read_response", side_effect=CancelledError + ): + r.brpop(["nonexist"]) + # if all is well, we can continue. + r.set("status", "down") # should not hang + nonlocal ok + ok = True + + thread = threading.Thread(target=helper) + thread.start() + thread.join(0.1) + try: + assert not thread.is_alive() + assert ok + finally: + # disconnect here so that fixture cleanup can proceed + r.connection.disconnect() + @pytest.mark.onlynoncluster class TestBinarySave: diff --git a/tests/test_connection.py b/tests/test_connection.py index 25b4118b2c..75ba738047 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -160,7 +160,7 @@ def test_connection_parse_response_resume(r: redis.Redis, parser_class): conn._parser._sock = mock_socket for i in range(100): try: - response = conn.read_response() + response = conn.read_response(disconnect_on_error=False) break except MockSocket.TestError: pass