diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 1845b7252f..abc141e8c9 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -651,6 +651,10 @@ async def execute_command(self, *args, **options): finally: if not self.connection: await pool.release(conn) + # Do additional cleanup if this is part of a SCAN ITER family command. + # It's possible that this is just a pure SCAN family command though. + if "SCAN" in command_name.upper(): + pool.cleanup(iter_req_id=options.get("iter_req_id", None)) async def parse_response( self, connection: Connection, command_name: Union[str, bytes], **options diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 2ac6637986..35ef22272a 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1057,6 +1057,14 @@ class ConnectionPool: ``connection_class``. """ + def cleanup(self, **options): + """ + Additional cleanup operations that the connection pool might need to do. + See SentinelManagedConnection for an example cleanup operation that + might need to be done. + """ + pass + @classmethod def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: """ @@ -1118,7 +1126,7 @@ def __init__( self.connection_kwargs = connection_kwargs self.max_connections = max_connections - self._available_connections: List[AbstractConnection] = [] + self._available_connections = self.reset_available_connections() self._in_use_connections: Set[AbstractConnection] = set() self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder) @@ -1129,9 +1137,12 @@ def __repr__(self): ) def reset(self): - self._available_connections = [] + self._available_connections = self.reset_available_connections() self._in_use_connections = weakref.WeakSet() + def reset_available_connections(self): + return [] + def can_get_connection(self) -> bool: """Return True if a connection can be retrieved from the pool.""" return ( @@ -1324,3 +1335,6 @@ async def release(self, connection: AbstractConnection): async with self._condition: await super().release(connection) self._condition.notify() + + def cleanup(self, **options): + pass diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index 6fd233adc8..dea7b5f7fe 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -1,7 +1,16 @@ import asyncio import random import weakref -from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type +from typing import ( + Any, + AsyncIterator, + Iterable, + Mapping, + Optional, + Sequence, + Tuple, + Type, +) from redis.asyncio.client import Redis from redis.asyncio.connection import ( @@ -12,6 +21,7 @@ ) from redis.commands import AsyncSentinelCommands from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError +from redis.sentinel import ConnectionsIndexer from redis.utils import str_if_bytes @@ -26,6 +36,10 @@ class SlaveNotFoundError(ConnectionError): class SentinelManagedConnection(Connection): def __init__(self, **kwargs): self.connection_pool = kwargs.pop("connection_pool") + # To be set to True if we want to prevent + # the connection to connect to the most relevant sentinel + # in the pool and just connect to the current host and port + self._is_address_set = False super().__init__(**kwargs) def __repr__(self): @@ -39,6 +53,14 @@ def __repr__(self): s += host_info return s + ")>" + def set_address(self, address): + """ + By setting the address, the connection will just connect + to the current host and port the next time connect is called. + """ + self.host, self.port = address + self._is_address_set = True + async def connect_to(self, address): self.host, self.port = address await super().connect() @@ -50,6 +72,14 @@ async def connect_to(self, address): async def _connect_retry(self): if self._reader: return # already connected + # If address is fixed, it means that the connection + # just connect to the current host and port + if self._is_address_set: + await self.connect_to((self.host, self.port)) + return + await self._connect_to_sentinel() + + async def _connect_to_sentinel(self): if self.connection_pool.is_master: await self.connect_to(await self.connection_pool.get_master_address()) else: @@ -122,6 +152,7 @@ def __init__(self, service_name, sentinel_manager, **kwargs): self.sentinel_manager = sentinel_manager self.master_address = None self.slave_rr_counter = None + self._iter_req_id_to_replica_address = {} def __repr__(self): return ( @@ -134,6 +165,9 @@ def reset(self): self.master_address = None self.slave_rr_counter = None + def reset_available_connections(self): + return ConnectionsIndexer() + def owns_connection(self, connection: Connection): check = not self.is_master or ( self.is_master and self.master_address == (connection.host, connection.port) @@ -167,6 +201,81 @@ async def rotate_slaves(self) -> AsyncIterator: pass raise SlaveNotFoundError(f"No slave found for {self.service_name!r}") + def cleanup(self, **options): + """ + Remove the SCAN ITER family command's request id from the dictionary + """ + self._iter_req_id_to_replica_address.pop(options.get("iter_req_id", None), None) + + async def get_connection( + self, command_name: str, *keys: Any, **options: Any + ) -> SentinelManagedConnection: + """ + Get a connection from the pool. + 'xxxscan_iter' ('scan_iter', 'hscan_iter', 'sscan_iter', 'zscan_iter') + commands needs to be handled specially. + If the client is created using a connection pool, in replica mode, + all 'scan' command-equivalent of the 'xxx_scan_iter' commands needs + to be issued to the same Redis replica. + + The way each server positions each key is different with one another, + and the cursor acts as the offset of the scan. + Hence, all scans coming from a single 'xxx_scan_iter_channel' command + should go to the same replica. + """ + # If not an iter command or in master mode, call superclass' implementation + if not (iter_req_id := options.get("iter_req_id", None)) or self.is_master: + return await super().get_connection(command_name, *keys, **options) + + # Check if this iter request has already been directed to a particular server + ( + server_host, + server_port, + ) = self._iter_req_id_to_replica_address.get(iter_req_id, (None, None)) + connection = None + # If this is the first scan request of the iter command, + # get a connection from the pool + if server_host is None or server_port is None: + try: + connection = self._available_connections.pop() + except IndexError: + connection = self.make_connection() + # If this is not the first scan request of the iter command + else: + # Get the connection that has the same host and port + connection = self._available_connections.get_connection( + host=server_host, port=server_port + ) + # If not, make a new dummy connection object, and set its host and + # port to the one that we want later in the call to ``set_address`` + if not connection: + connection = self.make_connection() + assert connection + self._in_use_connections.add(connection) + try: + # Ensure this connection is connected to Redis + # If this is the first scan request, it will + # call rotate_slaves and connect to a random replica + if server_port is None or server_port is None: + await connection.connect() + # If this is not the first scan request, + # connect to the previous replica. + # This will connect to the host and port of the replica + else: + connection.set_address((server_host, server_port)) + await self.ensure_connection(connection) + except BaseException: + # Release the connection back to the pool so that we don't + # leak it + await self.release(connection) + raise + # Store the connection to the dictionary + self._iter_req_id_to_replica_address[iter_req_id] = ( + connection.host, + connection.port, + ) + return connection + class Sentinel(AsyncSentinelCommands): """ diff --git a/redis/client.py b/redis/client.py index b7a1f88d92..c29a1d226c 100755 --- a/redis/client.py +++ b/redis/client.py @@ -581,6 +581,10 @@ def execute_command(self, *args, **options): finally: if not self.connection: pool.release(conn) + # Do additional cleanup if this is part of a SCAN ITER family command. + # It's possible that this is just a pure SCAN family command though. + if "SCAN" in command_name.upper(): + pool.cleanup(iter_req_id=options.get("iter_req_id", None)) def parse_response(self, connection, command_name, **options): """Parses a response from the Redis server""" diff --git a/redis/commands/core.py b/redis/commands/core.py index 26af9eb99c..55fe355d34 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -2,6 +2,7 @@ import datetime import hashlib +import uuid import warnings from typing import ( TYPE_CHECKING, @@ -3040,9 +3041,15 @@ def scan_iter( Additionally, Redis modules can expose other types as well. """ cursor = "0" + iter_req_id = uuid.uuid4() while cursor != 0: cursor, data = self.scan( - cursor=cursor, match=match, count=count, _type=_type, **kwargs + cursor=cursor, + match=match, + count=count, + _type=_type, + iter_req_id=iter_req_id, + **kwargs, ) yield from data @@ -3052,6 +3059,7 @@ def sscan( cursor: int = 0, match: Union[PatternT, None] = None, count: Union[int, None] = None, + **kwargs, ) -> ResponseT: """ Incrementally return lists of elements in a set. Also return a cursor @@ -3068,7 +3076,7 @@ def sscan( pieces.extend([b"MATCH", match]) if count is not None: pieces.extend([b"COUNT", count]) - return self.execute_command("SSCAN", *pieces) + return self.execute_command("SSCAN", *pieces, **kwargs) def sscan_iter( self, @@ -3085,8 +3093,11 @@ def sscan_iter( ``count`` allows for hint the minimum number of returns """ cursor = "0" + iter_req_id = uuid.uuid4() while cursor != 0: - cursor, data = self.sscan(name, cursor=cursor, match=match, count=count) + cursor, data = self.sscan( + name, cursor=cursor, match=match, count=count, iter_req_id=iter_req_id + ) yield from data def hscan( @@ -3096,6 +3107,7 @@ def hscan( match: Union[PatternT, None] = None, count: Union[int, None] = None, no_values: Union[bool, None] = None, + **kwargs, ) -> ResponseT: """ Incrementally return key/value slices in a hash. Also return a cursor @@ -3116,7 +3128,7 @@ def hscan( pieces.extend([b"COUNT", count]) if no_values is not None: pieces.extend([b"NOVALUES"]) - return self.execute_command("HSCAN", *pieces, no_values=no_values) + return self.execute_command("HSCAN", *pieces, no_values=no_values, **kwargs) def hscan_iter( self, @@ -3136,9 +3148,15 @@ def hscan_iter( ``no_values`` indicates to return only the keys, without values """ cursor = "0" + iter_req_id = uuid.uuid4() while cursor != 0: cursor, data = self.hscan( - name, cursor=cursor, match=match, count=count, no_values=no_values + name, + cursor=cursor, + match=match, + count=count, + no_values=no_values, + iter_req_id=iter_req_id, ) if no_values: yield from data @@ -3152,6 +3170,7 @@ def zscan( match: Union[PatternT, None] = None, count: Union[int, None] = None, score_cast_func: Union[type, Callable] = float, + **kwargs, ) -> ResponseT: """ Incrementally return lists of elements in a sorted set. Also return a @@ -3171,7 +3190,7 @@ def zscan( if count is not None: pieces.extend([b"COUNT", count]) options = {"score_cast_func": score_cast_func} - return self.execute_command("ZSCAN", *pieces, **options) + return self.execute_command("ZSCAN", *pieces, **options, **kwargs) def zscan_iter( self, @@ -3191,6 +3210,7 @@ def zscan_iter( ``score_cast_func`` a callable used to cast the score return value """ cursor = "0" + iter_req_id = uuid.uuid4() while cursor != 0: cursor, data = self.zscan( name, @@ -3198,6 +3218,7 @@ def zscan_iter( match=match, count=count, score_cast_func=score_cast_func, + iter_req_id=iter_req_id, ) yield from data @@ -3224,10 +3245,19 @@ async def scan_iter( HASH, LIST, SET, STREAM, STRING, ZSET Additionally, Redis modules can expose other types as well. """ + # DO NOT inline this statement to the scan call + # Each iter command should have an ID to maintain + # connection to the same replica + iter_req_id = uuid.uuid4() cursor = "0" while cursor != 0: cursor, data = await self.scan( - cursor=cursor, match=match, count=count, _type=_type, **kwargs + cursor=cursor, + match=match, + count=count, + _type=_type, + iter_req_id=iter_req_id, + **kwargs, ) for d in data: yield d @@ -3246,10 +3276,14 @@ async def sscan_iter( ``count`` allows for hint the minimum number of returns """ + # DO NOT inline this statement to the scan call + # Each iter command should have an ID to maintain + # connection to the same replica + iter_req_id = uuid.uuid4() cursor = "0" while cursor != 0: cursor, data = await self.sscan( - name, cursor=cursor, match=match, count=count + name, cursor=cursor, match=match, count=count, iter_req_id=iter_req_id ) for d in data: yield d @@ -3271,10 +3305,19 @@ async def hscan_iter( ``no_values`` indicates to return only the keys, without values """ + # DO NOT inline this statement to the scan call + # Each iter command should have an ID to maintain + # connection to the same replica + iter_req_id = uuid.uuid4() cursor = "0" while cursor != 0: cursor, data = await self.hscan( - name, cursor=cursor, match=match, count=count, no_values=no_values + name, + cursor=cursor, + match=match, + count=count, + no_values=no_values, + iter_req_id=iter_req_id, ) if no_values: for it in data: @@ -3300,6 +3343,10 @@ async def zscan_iter( ``score_cast_func`` a callable used to cast the score return value """ + # DO NOT inline this statement to the scan call + # Each iter command should have an ID to maintain + # connection to the same replica + iter_req_id = uuid.uuid4() cursor = "0" while cursor != 0: cursor, data = await self.zscan( @@ -3308,6 +3355,7 @@ async def zscan_iter( match=match, count=count, score_cast_func=score_cast_func, + iter_req_id=iter_req_id, ) for d in data: yield d diff --git a/redis/connection.py b/redis/connection.py index 19263376d2..9d501e51b7 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1104,10 +1104,18 @@ def __repr__(self) -> (str, str): f"({repr(self.connection_class(**self.connection_kwargs))})>" ) + def cleanup(self, **options): + """ + Additional cleanup operations that the connection pool might need to do. + See SentinelManagedConnection for an example cleanup operation that + might need to be done. + """ + pass + def reset(self) -> None: self._lock = threading.Lock() self._created_connections = 0 - self._available_connections = [] + self._available_connections = self.reset_available_connections() self._in_use_connections = set() # this must be the last operation in this method. while reset() is @@ -1121,6 +1129,9 @@ def reset(self) -> None: # reset() and they will immediately release _fork_lock and continue on. self.pid = os.getpid() + def reset_available_connections(self): + return [] + def _checkpid(self) -> None: # _checkpid() attempts to keep ConnectionPool fork-safe on modern # systems. this is called by all ConnectionPool methods that @@ -1168,6 +1179,25 @@ def _checkpid(self) -> None: finally: self._fork_lock.release() + def ensure_connection(self, connection: AbstractConnection): + # ensure this connection is connected to Redis + connection.connect() + # if client caching is not enabled connections that the pool + # provides should be ready to send a command. + # if not, the connection was either returned to the + # pool before all data has been read or the socket has been + # closed. either way, reconnect and verify everything is good. + # (if caching enabled the connection will not always be ready + # to send a command because it may contain invalidation messages) + try: + if connection.can_read() and connection.client_cache is None: + raise ConnectionError("Connection has data") + except (ConnectionError, OSError): + connection.disconnect() + connection.connect() + if connection.can_read(): + raise ConnectionError("Connection not ready") + def get_connection(self, command_name: str, *keys, **options) -> "Connection": "Get a connection from the pool" self._checkpid() @@ -1179,23 +1209,7 @@ def get_connection(self, command_name: str, *keys, **options) -> "Connection": self._in_use_connections.add(connection) try: - # ensure this connection is connected to Redis - connection.connect() - # if client caching is not enabled connections that the pool - # provides should be ready to send a command. - # if not, the connection was either returned to the - # pool before all data has been read or the socket has been - # closed. either way, reconnect and verify everything is good. - # (if caching enabled the connection will not always be ready - # to send a command because it may contain invalidation messages) - try: - if connection.can_read() and connection.client_cache is None: - raise ConnectionError("Connection has data") - except (ConnectionError, OSError): - connection.disconnect() - connection.connect() - if connection.can_read(): - raise ConnectionError("Connection not ready") + self.ensure_connection(connection) except BaseException: # release the connection back to the pool so that we don't # leak it diff --git a/redis/sentinel.py b/redis/sentinel.py index 72b5bef548..768c7c245a 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -1,6 +1,7 @@ import random import weakref -from typing import Optional +from collections import defaultdict +from typing import Any, Iterable, Optional from redis.client import Redis from redis.commands import SentinelCommands @@ -20,6 +21,11 @@ class SlaveNotFoundError(ConnectionError): class SentinelManagedConnection(Connection): def __init__(self, **kwargs): self.connection_pool = kwargs.pop("connection_pool") + # To be set to True if we want to prevent + # the sentinel managed connection to connect + # to the most relevant sentinel in the pool and just + # connect to the current self.host and self.port + self._is_address_set = False super().__init__(**kwargs) def __repr__(self): @@ -33,6 +39,14 @@ def __repr__(self): s = s % host_info return s + def set_address(self, address): + """ + By setting the address, the connection will just connect + to the current host and port the next time connect is called. + """ + self.host, self.port = address + self._is_address_set = True + def connect_to(self, address): self.host, self.port = address super().connect() @@ -44,6 +58,14 @@ def connect_to(self, address): def _connect_retry(self): if self._sock: return # already connected + # If address is set, it means that the connection + # will just connect to the current host and port. + if self._is_address_set: + self.connect_to((self.host, self.port)) + return + self._connect_to_sentinel() + + def _connect_to_sentinel(self): if self.connection_pool.is_master: self.connect_to(self.connection_pool.get_master_address()) else: @@ -55,7 +77,9 @@ def _connect_retry(self): raise SlaveNotFoundError # Never be here def connect(self): - return self.retry.call_with_retry(self._connect_retry, lambda error: None) + return self.retry.call_with_retry( + lambda: self._connect_retry(), lambda error: None + ) def read_response( self, @@ -134,6 +158,57 @@ def rotate_slaves(self): raise SlaveNotFoundError(f"No slave found for {self.service_name!r}") +class ConnectionsIndexer(Iterable): + """ + Data structure that stores available connections in a pool. + Instead of list, we keep 2 additional DS to support O(1) operations + on all of the class' methods. + The first DS is indexed on the connection object's ID. + The second DS is indexed on the address (ip and port) of the connection. + """ + + def __init__(self): + # Map the id to the connection object + self._id_to_connection = {} + # Map the address to a dictionary of connections + # The inner dictionary is a map between the object id to the object itself + # Both of these DS support O(1) operations on all of the class' methods + self._address_to_connections = defaultdict(dict) + + def pop(self): + try: + _, connection = self._id_to_connection.popitem() + del self._address_to_connections[(connection.host, connection.port)][ + id(connection) + ] + except KeyError: + # We are simulating a list, hence we raise IndexError + # when there's no item in the dictionary + raise IndexError() + return connection + + def append(self, connection: Connection): + self._id_to_connection[id(connection)] = connection + self._address_to_connections[(connection.host, connection.port)][ + id(connection) + ] = connection + + def get_connection(self, host: str, port: int): + try: + _, connection = self._address_to_connections[(host, port)].popitem() + del self._id_to_connection[id(connection)] + except KeyError: + return None + return connection + + def __iter__(self): + # This is an O(1) operation in python3.7 and later + return iter(self._id_to_connection.values()) + + def __len__(self): + return len(self._id_to_connection) + + class SentinelConnectionPool(ConnectionPool): """ Sentinel backed connection pool. @@ -164,6 +239,7 @@ def __init__(self, service_name, sentinel_manager, **kwargs): self.connection_kwargs["connection_pool"] = self.proxy self.service_name = service_name self.sentinel_manager = sentinel_manager + self._iter_req_id_to_replica_address = {} def __repr__(self): role = "master" if self.is_master else "slave" @@ -176,6 +252,9 @@ def reset(self): super().reset() self.proxy.reset() + def reset_available_connections(self): + return ConnectionsIndexer() + @property def master_address(self): return self.proxy.master_address @@ -194,6 +273,81 @@ def rotate_slaves(self): "Round-robin slave balancer" return self.proxy.rotate_slaves() + def cleanup(self, **options): + """ + Remove the SCAN ITER family command's request id from the dictionary + """ + self._iter_req_id_to_replica_address.pop(options.get("iter_req_id", None), None) + + def get_connection( + self, command_name: str, *keys: Any, **options: Any + ) -> SentinelManagedConnection: + """ + Get a connection from the pool. + 'xxxscan_iter' ('scan_iter', 'hscan_iter', 'sscan_iter', 'zscan_iter') + commands needs to be handled specially. + If the client is created using a connection pool, in replica mode, + all 'scan' command-equivalent of the 'xxx_scan_iter' commands needs + to be issued to the same Redis replica. + + The way each server positions each key is different with one another, + and the cursor acts as the offset of the scan. + Hence, all scans coming from a single 'xxx_scan_iter_channel' command + should go to the same replica. + """ + # If not an iter command or in master mode, call superclass' implementation + if not (iter_req_id := options.get("iter_req_id", None)) or self.is_master: + return super().get_connection(command_name, *keys, **options) + + # Check if this iter request has already been directed to a particular server + ( + server_host, + server_port, + ) = self._iter_req_id_to_replica_address.get(iter_req_id, (None, None)) + connection = None + # If this is the first scan request of the iter command, + # get a connection from the pool + if server_host is None or server_port is None: + try: + connection = self._available_connections.pop() + except IndexError: + connection = self.make_connection() + # If this is not the first scan request of the iter command + else: + # Get the connection that has the same host and port + connection = self._available_connections.get_connection( + host=server_host, port=server_port + ) + # If not, make a new dummy connection object, and set its host and port + # to the one that we want later in the call to ``set_address`` + if not connection: + connection = self.make_connection() + assert connection + self._in_use_connections.add(connection) + try: + # Ensure this connection is connected to Redis + # If this is the first scan request, it will + # call rotate_slaves and connect to a random replica + if server_port is None or server_port is None: + connection.connect() + # If this is not the first scan request, + # connect to the previous replica. + # This will connect to the host and port of the replica + else: + connection.set_address((server_host, server_port)) + self.ensure_connection(connection) + except BaseException: + # Release the connection back to the pool so that we don't + # leak it + self.release(connection) + raise + # Store the connection to the dictionary + self._iter_req_id_to_replica_address[iter_req_id] = ( + connection.host, + connection.port, + ) + return connection + class Sentinel(SentinelCommands): """ diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 61c00541cb..52967e26e0 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -1402,6 +1402,39 @@ async def test_zscan_iter(self, r: redis.Redis): pairs = [k async for k in r.zscan_iter("a", match="a")] assert set(pairs) == {(b"a", 1)} + async def test_scan_iter_family_executes_commands_with_same_iter_req_id(self): + """Assert that all calls to execute_command receives the iter_req_id kwarg""" + import uuid + + from redis.commands.core import AsyncScanCommands + + from .compat import mock + + with mock.patch.object( + AsyncScanCommands, "execute_command", mock.AsyncMock(return_value=(0, [])) + ) as mock_execute_command, mock.patch.object( + uuid, "uuid4", return_value="uuid" + ): + [a async for a in AsyncScanCommands().scan_iter()] + mock_execute_command.assert_called_with("SCAN", "0", iter_req_id="uuid") + [a async for a in AsyncScanCommands().sscan_iter("")] + mock_execute_command.assert_called_with( + "SSCAN", "", "0", iter_req_id="uuid" + ) + with mock.patch.object( + AsyncScanCommands, "execute_command", mock.AsyncMock(return_value=(0, {})) + ) as mock_execute_command, mock.patch.object( + uuid, "uuid4", return_value="uuid" + ): + [a async for a in AsyncScanCommands().hscan_iter("")] + mock_execute_command.assert_called_with( + "HSCAN", "", "0", no_values=None, iter_req_id="uuid" + ) + [a async for a in AsyncScanCommands().zscan_iter("")] + mock_execute_command.assert_called_with( + "ZSCAN", "", "0", score_cast_func=mock.ANY, iter_req_id="uuid" + ) + # SET COMMANDS async def test_sadd(self, r: redis.Redis): members = {b"1", b"2", b"3"} diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index 5e4d3f206f..97b65f152f 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -1,5 +1,6 @@ import asyncio import re +from itertools import chain import pytest import pytest_asyncio @@ -35,7 +36,7 @@ async def create_two_conn(r: redis.Redis): def has_no_connected_connections(pool: redis.ConnectionPool): return not any( x.is_connected - for x in pool._available_connections + list(pool._in_use_connections) + for x in chain(pool._available_connections, pool._in_use_connections) ) async def test_auto_disconnect_redis_created_pool(self, r: redis.Redis): @@ -56,7 +57,7 @@ async def test_do_not_auto_disconnect_redis_created_pool(self, r2: redis.Redis): assert r2.connection_pool._in_use_connections == {new_conn} assert new_conn.is_connected assert len(r2.connection_pool._available_connections) == 1 - assert r2.connection_pool._available_connections[0].is_connected + assert list(r2.connection_pool._available_connections)[0].is_connected async def test_auto_release_override_true_manual_created_pool(self, r: redis.Redis): assert r.auto_close_connection_pool is True, "This is from the class fixture" @@ -84,7 +85,7 @@ async def test_negate_auto_close_client_pool( await r.aclose(close_connection_pool=False) assert not self.has_no_connected_connections(r.connection_pool) assert r.connection_pool._in_use_connections == {new_conn} - assert r.connection_pool._available_connections[0].is_connected + assert list(r.connection_pool._available_connections)[0].is_connected assert self.get_total_connected_connections(r.connection_pool) == 2 @@ -93,6 +94,8 @@ class DummyConnection(Connection): def __init__(self, **kwargs): self.kwargs = kwargs + self.host = kwargs.get("host", None) + self.port = kwargs.get("port", None) def repr_pieces(self): return [("id", id(self)), ("kwargs", self.kwargs)] @@ -578,7 +581,7 @@ async def test_on_connect_error(self): await bad_connection.info() pool = bad_connection.connection_pool assert len(pool._available_connections) == 1 - assert not pool._available_connections[0]._reader + assert not list(pool._available_connections)[0]._reader @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.8") @@ -609,7 +612,7 @@ async def test_busy_loading_from_pipeline_immediate_command(self, r): pool = r.connection_pool assert not pipe.connection assert len(pool._available_connections) == 1 - assert not pool._available_connections[0]._reader + assert not list(pool._available_connections)[0]._reader @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.8") @@ -626,7 +629,7 @@ async def test_busy_loading_from_pipeline(self, r): pool = r.connection_pool assert not pipe.connection assert len(pool._available_connections) == 1 - assert not pool._available_connections[0]._reader + assert not list(pool._available_connections)[0]._reader @skip_if_server_version_lt("2.8.8") @skip_if_redis_enterprise() diff --git a/tests/test_asyncio/test_sentinel_managed_connection.py b/tests/test_asyncio/test_sentinel_managed_connection.py index cae4b9581f..822756b81e 100644 --- a/tests/test_asyncio/test_sentinel_managed_connection.py +++ b/tests/test_asyncio/test_sentinel_managed_connection.py @@ -1,9 +1,15 @@ import socket +from typing import Iterator, Tuple import pytest from redis.asyncio.retry import Retry -from redis.asyncio.sentinel import SentinelManagedConnection +from redis.asyncio.sentinel import ( + Sentinel, + SentinelConnectionPool, + SentinelManagedConnection, +) from redis.backoff import NoBackoff +from redis.utils import HIREDIS_AVAILABLE from .compat import mock @@ -35,3 +41,216 @@ async def mock_connect(): await conn.connect() assert conn._connect.call_count == 3 await conn.disconnect() + + +class SentinelManagedConnectionMock(SentinelManagedConnection): + async def _connect_to_sentinel(self) -> None: + """ + This simulates the behavior of _connect_to_sentinel when + :py:class:`~redis.asyncio.sentinel.SentinelConnectionPool`. + In master mode, it'll connect to the master. + In non-master mode, it'll call rotate_slaves and connect to the next replica. + """ + if self.connection_pool.is_master: + self.host, self.port = ("master", 1) + else: + import random + import time + + self.host = f"host-{random.randint(0, 10)}" + self.port = time.time() + + async def connect_to(self, address: Tuple[str, int]) -> None: + """ + Do nothing, just mock so it won't try to make a connection to the + dummy address. + """ + + +@pytest.fixture() +def connection_pool_replica_mock() -> Iterator[SentinelConnectionPool]: + sentinel_manager = Sentinel([["master", 400]]) + # Give a random slave + sentinel_manager.discover_slaves = mock.AsyncMock(return_value=["replica", 5000]) + with mock.patch( + "redis._parsers._AsyncRESP2Parser.can_read_destructive", return_value=False + ): + # Create connection pool with our mock connection object + connection_pool = SentinelConnectionPool( + "usasm", + sentinel_manager, + is_master=False, + connection_class=SentinelManagedConnectionMock, + ) + yield connection_pool + + +@pytest.fixture() +def connection_pool_master_mock() -> Iterator[SentinelConnectionPool]: + sentinel_manager = Sentinel([["master", 400]]) + # Give a random slave + sentinel_manager.discover_master = mock.AsyncMock(return_value=["replica", 5000]) + with mock.patch( + "redis._parsers._AsyncRESP2Parser.can_read_destructive", return_value=False + ): + # Create connection pool with our mock connection object + connection_pool = SentinelConnectionPool( + "usasm", + sentinel_manager, + is_master=True, + connection_class=SentinelManagedConnectionMock, + ) + yield connection_pool + + +def same_address( + connection_1: SentinelManagedConnection, + connection_2: SentinelManagedConnection, +) -> bool: + return bool( + connection_1.host == connection_2.host + and connection_1.port == connection_2.port + ) + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_connects_to_same_address_if_same_id_replica( + connection_pool_replica_mock: SentinelConnectionPool, +) -> None: + """ + Assert that the connection address is the same if the ``iter_req_id`` is the same + when we are in replica mode using a + :py:class:`~redis.asyncio.sentinel.SentinelConnectionPool`. + """ + connection_for_req_1 = await connection_pool_replica_mock.get_connection( + "ANY", iter_req_id=1 + ) + assert same_address( + await connection_pool_replica_mock.get_connection("ANY", iter_req_id=1), + connection_for_req_1, + ) + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_connects_to_same_conn_object_if_same_id_and_conn_released_replica( + connection_pool_replica_mock: SentinelConnectionPool, +) -> None: + """ + Assert that the connection object is the same if the ``iter_req_id`` is the same + when we are in replica mode using a + :py:class:`~redis.asyncio.sentinel.SentinelConnectionPool` + and if we release the connection back to the connection pool before + trying to connect again. + """ + connection_for_req_1 = await connection_pool_replica_mock.get_connection( + "ANY", iter_req_id=1 + ) + await connection_pool_replica_mock.release(connection_for_req_1) + assert ( + await connection_pool_replica_mock.get_connection("ANY", iter_req_id=1) + == connection_for_req_1 + ) + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_connects_to_diff_address_if_no_iter_req_id_replica( + connection_pool_replica_mock: SentinelConnectionPool, +) -> None: + """ + Assert that the connection object is different if no iter_req_id is supplied. + In reality, they can be the same, but in this case, we're not + releasing the connection to the pool so they should always be different. + """ + connection_for_req_1 = await connection_pool_replica_mock.get_connection( + "ANY", iter_req_id=1 + ) + connection_for_random_req = await connection_pool_replica_mock.get_connection( + "ANYY" + ) + assert not same_address(connection_for_random_req, connection_for_req_1) + assert not same_address( + await connection_pool_replica_mock.get_connection("ANY_COMMAND"), + connection_for_random_req, + ) + assert not same_address( + await connection_pool_replica_mock.get_connection("ANY_COMMAND"), + connection_for_req_1, + ) + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_connects_to_same_address_if_same_iter_req_id_master( + connection_pool_master_mock: SentinelConnectionPool, +) -> None: + """ + Assert that the connection address is the same if the ``iter_req_id`` is the same + when we are in master mode using a + :py:class:`~redis.asyncio.sentinel.SentinelConnectionPool`. + """ + connection_for_req_1 = await connection_pool_master_mock.get_connection( + "ANY", iter_req_id=1 + ) + assert same_address( + await connection_pool_master_mock.get_connection("ANY", iter_req_id=1), + connection_for_req_1, + ) + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_connects_to_same_conn_object_if_same_iter_req_id_and_released_master( + connection_pool_master_mock: SentinelConnectionPool, +) -> None: + """ + Assert that the connection address is the same if the ``iter_req_id`` is the same + when we are in master mode using a + :py:class:`~redis.asyncio.sentinel.SentinelConnectionPool` + and if we release the connection back to the connection pool before + trying to connect again. + """ + connection_for_req_1 = await connection_pool_master_mock.get_connection( + "ANY", iter_req_id=1 + ) + assert same_address( + await connection_pool_master_mock.get_connection("ANY", iter_req_id=1), + connection_for_req_1, + ) + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_connects_to_same_address_if_no_iter_req_id_master( + connection_pool_master_mock: SentinelConnectionPool, +) -> None: + """ + Assert that connection address is always the same regardless if + there's an ``iter_req_id`` or not + when we are in master mode using a + :py:class:`~redis.asyncio.sentinel.SentinelConnectionPool`. + """ + connection_for_req_1 = await connection_pool_master_mock.get_connection( + "ANY", iter_req_id=1 + ) + connection_for_random_req = await connection_pool_master_mock.get_connection("ANYY") + assert same_address(connection_for_random_req, connection_for_req_1) + assert same_address( + await connection_pool_master_mock.get_connection("ANY_COMMAND"), + connection_for_random_req, + ) + + assert same_address( + await connection_pool_master_mock.get_connection("ANY_COMMAND"), + connection_for_req_1, + ) + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_scan_iter_family_cleans_up( + connection_pool_replica_mock: SentinelConnectionPool, +): + """Test that connection pool is correctly cleaned up""" + from redis.asyncio import Redis + + r = Redis(connection_pool=connection_pool_replica_mock) + + with mock.patch.object(r, "_send_command_parse_response", return_value=(0, [])): + [k async for k in r.scan_iter("a")] + assert not connection_pool_replica_mock._iter_req_id_to_replica_address diff --git a/tests/test_commands.py b/tests/test_commands.py index 42376b50d8..c26734bf42 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -2235,6 +2235,37 @@ def test_zscan_iter(self, r): pairs = list(r.zscan_iter("a", match="a")) assert set(pairs) == {(b"a", 1)} + def test_scan_iter_family_executes_commands_with_same_iter_req_id(self): + """Assert that all calls to execute_command receives the iter_req_id kwarg""" + import uuid + + from redis.commands.core import ScanCommands + + with mock.patch.object( + ScanCommands, "execute_command", mock.Mock(return_value=(0, [])) + ) as mock_execute_command, mock.patch.object( + uuid, "uuid4", return_value="uuid" + ): + [a for a in ScanCommands().scan_iter()] + mock_execute_command.assert_called_with("SCAN", "0", iter_req_id="uuid") + [a for a in ScanCommands().sscan_iter("")] + mock_execute_command.assert_called_with( + "SSCAN", "", "0", iter_req_id="uuid" + ) + with mock.patch.object( + ScanCommands, "execute_command", mock.Mock(return_value=(0, {})) + ) as mock_execute_command, mock.patch.object( + uuid, "uuid4", return_value="uuid" + ): + [a for a in ScanCommands().hscan_iter("")] + mock_execute_command.assert_called_with( + "HSCAN", "", "0", no_values=None, iter_req_id="uuid" + ) + [a for a in ScanCommands().zscan_iter("")] + mock_execute_command.assert_called_with( + "ZSCAN", "", "0", score_cast_func=mock.ANY, iter_req_id="uuid" + ) + # SET COMMANDS def test_sadd(self, r): members = {b"1", b"2", b"3"} diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index dee7c554d3..313dd96287 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -20,6 +20,8 @@ class DummyConnection: def __init__(self, **kwargs): self.kwargs = kwargs self.pid = os.getpid() + self.host = kwargs.get("host", None) + self.port = kwargs.get("port", None) def connect(self): pass @@ -502,7 +504,7 @@ def test_on_connect_error(self): bad_connection.info() pool = bad_connection.connection_pool assert len(pool._available_connections) == 1 - assert not pool._available_connections[0]._sock + assert not list(pool._available_connections)[0]._sock @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.8") @@ -530,7 +532,7 @@ def test_busy_loading_from_pipeline_immediate_command(self, r): pool = r.connection_pool assert not pipe.connection assert len(pool._available_connections) == 1 - assert not pool._available_connections[0]._sock + assert not list(pool._available_connections)[0]._sock @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.8") @@ -547,7 +549,7 @@ def test_busy_loading_from_pipeline(self, r): pool = r.connection_pool assert not pipe.connection assert len(pool._available_connections) == 1 - assert not pool._available_connections[0]._sock + assert not list(pool._available_connections)[0]._sock @skip_if_server_version_lt("2.8.8") @skip_if_redis_enterprise() diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py index 54b9647098..9b5bce0fea 100644 --- a/tests/test_sentinel.py +++ b/tests/test_sentinel.py @@ -4,7 +4,9 @@ import pytest import redis.sentinel from redis import exceptions +from redis.connection import Connection from redis.sentinel import ( + ConnectionsIndexer, MasterNotFoundError, Sentinel, SentinelConnectionPool, @@ -266,3 +268,18 @@ def mock_disconnect(): assert calls == 1 pool.disconnect() + + +def test_connections_indexer_operations(): + ci = ConnectionsIndexer() + c1 = Connection(host="1", port=2) + ci.append(c1) + assert list(ci) == [c1] + assert ci.pop() == c1 + + c2 = Connection(host="3", port=4) + ci.append(c1) + ci.append(c2) + assert ci.get_connection("3", 4) == c2 + assert not ci.get_connection("5", 6) + assert list(ci) == [c1] diff --git a/tests/test_sentinel_managed_connection.py b/tests/test_sentinel_managed_connection.py new file mode 100644 index 0000000000..f59ff3f3c2 --- /dev/null +++ b/tests/test_sentinel_managed_connection.py @@ -0,0 +1,209 @@ +from typing import Tuple +from unittest import mock + +import pytest +from redis import Redis +from redis.sentinel import Sentinel, SentinelConnectionPool, SentinelManagedConnection +from redis.utils import HIREDIS_AVAILABLE + +pytestmark = pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") + + +class SentinelManagedConnectionMock(SentinelManagedConnection): + def _connect_to_sentinel(self) -> None: + """ + This simulates the behavior of _connect_to_sentinel when + :py:class:`~redis.SentinelConnectionPool`. + In master mode, it'll connect to the master. + In non-master mode, it'll call rotate_slaves and connect to the next replica. + """ + if self.connection_pool.is_master: + self.host, self.port = ("master", 1) + else: + import random + import time + + self.host = f"host-{random.randint(0, 10)}" + self.port = time.time() + + def connect_to(self, address: Tuple[str, int]) -> None: + """ + Do nothing, just mock so it won't try to make a connection to the + dummy address. + """ + pass + + +@pytest.fixture() +def connection_pool_replica_mock() -> SentinelConnectionPool: + sentinel_manager = Sentinel([["master", 400]]) + # Give a random slave + sentinel_manager.discover_slaves = mock.Mock(return_value=["replica", 5000]) + # Create connection pool with our mock connection object + connection_pool = SentinelConnectionPool( + "usasm", + sentinel_manager, + is_master=False, + connection_class=SentinelManagedConnectionMock, + ) + return connection_pool + + +@pytest.fixture() +def connection_pool_master_mock() -> SentinelConnectionPool: + sentinel_manager = Sentinel([["master", 400]]) + # Give a random slave + sentinel_manager.discover_master = mock.Mock(return_value=["replica", 5000]) + # Create connection pool with our mock connection object + connection_pool = SentinelConnectionPool( + "usasm", + sentinel_manager, + is_master=True, + connection_class=SentinelManagedConnectionMock, + ) + return connection_pool + + +def same_address( + connection_1: SentinelManagedConnection, + connection_2: SentinelManagedConnection, +) -> bool: + return bool( + connection_1.host == connection_2.host + and connection_1.port == connection_2.port + ) + + +def test_connects_to_same_address_if_same_id_replica( + connection_pool_replica_mock: SentinelConnectionPool, +) -> None: + """ + Assert that the connection address is the same if the ``iter_req_id`` is the same + when we are in replica mode using a + :py:class:`~redis.sentinel.SentinelConnectionPool`. + """ + connection_for_req_1 = connection_pool_replica_mock.get_connection( + "ANY", iter_req_id=1 + ) + assert same_address( + connection_pool_replica_mock.get_connection("ANY", iter_req_id=1), + connection_for_req_1, + ) + + +def test_connects_to_same_conn_object_if_same_id_and_conn_released_replica( + connection_pool_replica_mock: SentinelConnectionPool, +) -> None: + """ + Assert that the connection object is the same if the ``iter_req_id`` is the same + when we are in replica mode using a + :py:class:`~redis.sentinel.SentinelConnectionPool` + and if we release the connection back to the connection pool before + trying to connect again. + """ + connection_for_req_1 = connection_pool_replica_mock.get_connection( + "ANY", iter_req_id=1 + ) + connection_pool_replica_mock.release(connection_for_req_1) + assert ( + connection_pool_replica_mock.get_connection("ANY", iter_req_id=1) + == connection_for_req_1 + ) + + +def test_connects_to_diff_address_if_no_iter_req_id_replica( + connection_pool_replica_mock: SentinelConnectionPool, +) -> None: + """ + Assert that the connection object is different if no iter_req_id is supplied. + In reality, they can be the same, but in this case, we're not + releasing the connection to the pool so they should always be different. + """ + connection_for_req_1 = connection_pool_replica_mock.get_connection( + "ANY", iter_req_id=1 + ) + connection_for_random_req = connection_pool_replica_mock.get_connection("ANYY") + assert not same_address(connection_for_random_req, connection_for_req_1) + assert not same_address( + connection_pool_replica_mock.get_connection("ANY_COMMAND"), + connection_for_random_req, + ) + assert not same_address( + connection_pool_replica_mock.get_connection("ANY_COMMAND"), + connection_for_req_1, + ) + + +def test_connects_to_same_address_if_same_iter_req_id_master( + connection_pool_master_mock: SentinelConnectionPool, +) -> None: + """ + Assert that the connection address is the same if the ``_iter_req_id`` is the same + when we are in master mode using a + :py:class:`~redis.sentinel.SentinelConnectionPool`. + """ + connection_for_req_1 = connection_pool_master_mock.get_connection( + "ANY", _iter_req_id=1 + ) + assert same_address( + connection_pool_master_mock.get_connection("ANY", iter_req_id=1), + connection_for_req_1, + ) + + +def test_connects_to_same_conn_object_if_same_iter_req_id_and_released_master( + connection_pool_master_mock: SentinelConnectionPool, +) -> None: + """ + Assert that the connection address is the same if the ``_iter_req_id`` is the same + when we are in master mode using a + :py:class:`~redis.sentinel.SentinelConnectionPool` + and if we release the connection back to the connection pool before + trying to connect again. + """ + connection_for_req_1 = connection_pool_master_mock.get_connection( + "ANY", iter_req_id=1 + ) + assert same_address( + connection_pool_master_mock.get_connection("ANY", iter_req_id=1), + connection_for_req_1, + ) + + +def test_connects_to_same_address_if_no_iter_req_id_master( + connection_pool_master_mock: SentinelConnectionPool, +) -> None: + """ + Assert that connection address is always the same regardless if + there's an ``iter_req_id`` or not + when we are in master mode using a + :py:class:`~redis.sentinel.SentinelConnectionPool`. + """ + connection_for_req_1 = connection_pool_master_mock.get_connection( + "ANY", iter_req_id=1 + ) + connection_for_random_req = connection_pool_master_mock.get_connection("ANYY") + assert same_address(connection_for_random_req, connection_for_req_1) + assert same_address( + connection_pool_master_mock.get_connection("ANY_COMMAND"), + connection_for_random_req, + ) + + assert same_address( + connection_pool_master_mock.get_connection("ANY_COMMAND"), + connection_for_req_1, + ) + + +def test_scan_iter_in_redis_cleans_up( + connection_pool_replica_mock: SentinelConnectionPool, +): + """Test that connection pool is correctly cleaned up""" + r = Redis(connection_pool=connection_pool_replica_mock) + # Patch the actual sending and parsing response from the Connection object + # but still let the connection pool does all the necessary work + with mock.patch.object(r, "_send_command_parse_response", return_value=(0, [])): + [k for k in r.scan_iter("a")] + # Test that the iter_req_id for the scan command is cleared at the + # end of the SCAN ITER command + assert not connection_pool_replica_mock._iter_req_id_to_replica_address