diff --git a/.flake8 b/.flake8 index 0e0ace6a4a..6c663473e4 100644 --- a/.flake8 +++ b/.flake8 @@ -15,7 +15,12 @@ exclude = whitelist.py, tasks.py ignore = + E126 + E203 F405 + N801 + N802 + N803 + N806 + N815 W503 - E203 - E126 \ No newline at end of file diff --git a/redis/_parsers/helpers.py b/redis/_parsers/helpers.py index ab4ede1fd0..fb5da831fe 100644 --- a/redis/_parsers/helpers.py +++ b/redis/_parsers/helpers.py @@ -629,8 +629,7 @@ def parse_client_info(value): "key1=value1 key2=value2 key3=value3" """ client_info = {} - infos = str_if_bytes(value).split(" ") - for info in infos: + for info in str_if_bytes(value).strip().split(): key, value = info.split("=") client_info[key] = value @@ -700,6 +699,7 @@ def string_keys_to_dict(key_string, callback): "CLIENT KILL": parse_client_kill, "CLIENT LIST": parse_client_list, "CLIENT PAUSE": bool_ok, + "CLIENT SETINFO": bool_ok, "CLIENT SETNAME": bool_ok, "CLIENT UNBLOCK": bool, "CLUSTER ADDSLOTS": bool_ok, diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 0e3c879278..f0c1ab7536 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -62,7 +62,13 @@ WatchError, ) from redis.typing import ChannelT, EncodableT, KeyT -from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes +from redis.utils import ( + HIREDIS_AVAILABLE, + _set_info_logger, + get_lib_version, + safe_str, + str_if_bytes, +) PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] _KeyT = TypeVar("_KeyT", bound=KeyT) @@ -190,6 +196,8 @@ def __init__( single_connection_client: bool = False, health_check_interval: int = 0, client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), username: Optional[str] = None, retry: Optional[Retry] = None, auto_close_connection_pool: bool = True, @@ -232,6 +240,8 @@ def __init__( "max_connections": max_connections, "health_check_interval": health_check_interval, "client_name": client_name, + "lib_name": lib_name, + "lib_version": lib_version, "redis_connect_func": redis_connect_func, "protocol": protocol, } diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 9e2a40ce1b..84407116ed 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -62,7 +62,7 @@ TryAgainError, ) from redis.typing import AnyKeyT, EncodableT, KeyT -from redis.utils import dict_merge, safe_str, str_if_bytes +from redis.utils import dict_merge, get_lib_version, safe_str, str_if_bytes TargetNodesT = TypeVar( "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] @@ -237,6 +237,8 @@ def __init__( username: Optional[str] = None, password: Optional[str] = None, client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), # Encoding related kwargs encoding: str = "utf-8", encoding_errors: str = "strict", @@ -288,6 +290,8 @@ def __init__( "username": username, "password": password, "client_name": client_name, + "lib_name": lib_name, + "lib_version": lib_version, # Encoding related kwargs "encoding": encoding, "encoding_errors": encoding_errors, diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index d501989c83..c1cc1d310c 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -49,7 +49,7 @@ TimeoutError, ) from redis.typing import EncodableT -from redis.utils import HIREDIS_AVAILABLE, str_if_bytes +from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes from .._parsers import ( BaseParser, @@ -101,6 +101,8 @@ class AbstractConnection: "db", "username", "client_name", + "lib_name", + "lib_version", "credential_provider", "password", "socket_timeout", @@ -140,6 +142,8 @@ def __init__( socket_read_size: int = 65536, health_check_interval: float = 0, client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), username: Optional[str] = None, retry: Optional[Retry] = None, redis_connect_func: Optional[ConnectCallbackT] = None, @@ -157,6 +161,8 @@ def __init__( self.pid = os.getpid() self.db = db self.client_name = client_name + self.lib_name = lib_name + self.lib_version = lib_version self.credential_provider = credential_provider self.password = password self.username = username @@ -347,9 +353,23 @@ async def on_connect(self) -> None: if str_if_bytes(await self.read_response()) != "OK": raise ConnectionError("Error setting client name") - # if a database is specified, switch to it + # set the library name and version, pipeline for lower startup latency + if self.lib_name: + await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name) + if self.lib_version: + await self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version) + # if a database is specified, switch to it. Also pipeline this if self.db: await self.send_command("SELECT", self.db) + + # read responses from pipeline + for _ in (sent for sent in (self.lib_name, self.lib_version) if sent): + try: + await self.read_response() + except ResponseError: + pass + + if self.db: if str_if_bytes(await self.read_response()) != "OK": raise ConnectionError("Invalid Database") diff --git a/redis/client.py b/redis/client.py index a856ef84ad..f695cef534 100755 --- a/redis/client.py +++ b/redis/client.py @@ -31,7 +31,13 @@ ) from redis.lock import Lock from redis.retry import Retry -from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes +from redis.utils import ( + HIREDIS_AVAILABLE, + _set_info_logger, + get_lib_version, + safe_str, + str_if_bytes, +) SYM_EMPTY = b"" EMPTY_RESPONSE = "EMPTY_RESPONSE" @@ -171,6 +177,8 @@ def __init__( single_connection_client=False, health_check_interval=0, client_name=None, + lib_name="redis-py", + lib_version=get_lib_version(), username=None, retry=None, redis_connect_func=None, @@ -222,6 +230,8 @@ def __init__( "max_connections": max_connections, "health_check_interval": health_check_interval, "client_name": client_name, + "lib_name": lib_name, + "lib_version": lib_version, "redis_connect_func": redis_connect_func, "credential_provider": credential_provider, "protocol": protocol, diff --git a/redis/cluster.py b/redis/cluster.py index 0c33fd2c68..1ffa5ff547 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -143,6 +143,8 @@ def parse_cluster_myshardid(resp, **options): "encoding_errors", "errors", "host", + "lib_name", + "lib_version", "max_connections", "nodes_flag", "redis_connect_func", @@ -225,6 +227,7 @@ class AbstractRedisCluster: "ACL WHOAMI", "AUTH", "CLIENT LIST", + "CLIENT SETINFO", "CLIENT SETNAME", "CLIENT GETNAME", "CONFIG SET", diff --git a/redis/commands/core.py b/redis/commands/core.py index 09ec59f47c..031781d75d 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -708,6 +708,13 @@ def client_setname(self, name: str, **kwargs) -> ResponseT: """ return self.execute_command("CLIENT SETNAME", name, **kwargs) + def client_setinfo(self, attr: str, value: str, **kwargs) -> ResponseT: + """ + Sets the current connection library name or version + For mor information see https://redis.io/commands/client-setinfo + """ + return self.execute_command("CLIENT SETINFO", attr, value, **kwargs) + def client_unblock( self, client_id: int, error: bool = False, **kwargs ) -> ResponseT: diff --git a/redis/connection.py b/redis/connection.py index 66debed2ea..00d293a238 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -31,6 +31,7 @@ HIREDIS_AVAILABLE, HIREDIS_PACK_AVAILABLE, SSL_AVAILABLE, + get_lib_version, str_if_bytes, ) @@ -140,6 +141,8 @@ def __init__( socket_read_size=65536, health_check_interval=0, client_name=None, + lib_name="redis-py", + lib_version=get_lib_version(), username=None, retry=None, redis_connect_func=None, @@ -164,6 +167,8 @@ def __init__( self.pid = os.getpid() self.db = db self.client_name = client_name + self.lib_name = lib_name + self.lib_version = lib_version self.credential_provider = credential_provider self.password = password self.username = username @@ -360,6 +365,21 @@ def on_connect(self): if str_if_bytes(self.read_response()) != "OK": raise ConnectionError("Error setting client name") + try: + # set the library name and version + if self.lib_name: + self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name) + self.read_response() + except ResponseError: + pass + + try: + if self.lib_version: + self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version) + self.read_response() + except ResponseError: + pass + # if a database is specified, switch to it if self.db: self.send_command("SELECT", self.db) diff --git a/redis/utils.py b/redis/utils.py index 148d15246b..01fdfed7a2 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -1,4 +1,5 @@ import logging +import sys from contextlib import contextmanager from functools import wraps from typing import Any, Dict, Mapping, Union @@ -27,6 +28,11 @@ except ImportError: CRYPTOGRAPHY_AVAILABLE = False +if sys.version_info >= (3, 8): + from importlib import metadata +else: + import importlib_metadata as metadata + def from_url(url, **kwargs): """ @@ -131,3 +137,11 @@ def _set_info_logger(): handler = logging.StreamHandler() handler.setLevel(logging.INFO) logger.addHandler(handler) + + +def get_lib_version(): + try: + libver = metadata.version("redis") + except metadata.PackageNotFoundError: + libver = "99.99.99" + return libver diff --git a/tests/conftest.py b/tests/conftest.py index b3c410e51b..16f3fbb9db 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -131,6 +131,7 @@ def pytest_sessionstart(session): enterprise = info["enterprise"] except redis.ConnectionError: # provide optimistic defaults + info = {} version = "10.0.0" arch_bits = 64 cluster_enabled = False @@ -145,9 +146,7 @@ def pytest_sessionstart(session): # module info try: REDIS_INFO["modules"] = info["modules"] - except redis.exceptions.ConnectionError: - pass - except KeyError: + except (KeyError, redis.exceptions.ConnectionError): pass if cluster_enabled: diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index e5da3f8f46..c837f284f7 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -253,3 +253,15 @@ async def __aexit__(self, exc_type, exc_inst, tb): def asynccontextmanager(func): return _asynccontextmanager(func) + + +# helpers to get the connection arguments for this run +@pytest.fixture() +def redis_url(request): + return request.config.getoption("--redis-url") + + +@pytest.fixture() +def connect_args(request): + url = request.config.getoption("--redis-url") + return parse_url(url) diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index eb7aafdf68..1cb1fa5195 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -2294,7 +2294,7 @@ async def test_acl_log( await user_client.hset("{cache}:0", "hkey", "hval") assert isinstance(await r.acl_log(target_nodes=node), list) - assert len(await r.acl_log(target_nodes=node)) == 2 + assert len(await r.acl_log(target_nodes=node)) == 3 assert len(await r.acl_log(count=1, target_nodes=node)) == 1 assert isinstance((await r.acl_log(target_nodes=node))[0], dict) assert "client-info" in (await r.acl_log(count=1, target_nodes=node))[0] diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 08e66b050f..7808d171fa 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -273,7 +273,7 @@ async def test_acl_log(self, r_teardown, create_redis): await user_client.hset("cache:0", "hkey", "hval") assert isinstance(await r.acl_log(), list) - assert len(await r.acl_log()) == 2 + assert len(await r.acl_log()) == 3 assert len(await r.acl_log(count=1)) == 1 assert isinstance((await r.acl_log())[0], dict) expected = (await r.acl_log(count=1))[0] @@ -355,6 +355,26 @@ async def test_client_setname(self, r: redis.Redis): r, await r.client_getname(), "redis_py_test", b"redis_py_test" ) + @skip_if_server_version_lt("7.2.0") + async def test_client_setinfo(self, r: redis.Redis): + await r.ping() + info = await r.client_info() + assert info["lib-name"] == "redis-py" + assert info["lib-ver"] == redis.__version__ + assert await r.client_setinfo("lib-name", "test") + assert await r.client_setinfo("lib-ver", "123") + info = await r.client_info() + assert info["lib-name"] == "test" + assert info["lib-ver"] == "123" + r2 = redis.asyncio.Redis(lib_name="test2", lib_version="1234") + info = await r2.client_info() + assert info["lib-name"] == "test2" + assert info["lib-ver"] == "1234" + r3 = redis.asyncio.Redis(lib_name=None, lib_version=None) + info = await r3.client_info() + assert info["lib-name"] == "" + assert info["lib-ver"] == "" + @skip_if_server_version_lt("2.6.9") @pytest.mark.onlynoncluster async def test_client_kill(self, r: redis.Redis, r2): diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 9a729392b8..d1aad796e7 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -125,9 +125,11 @@ async def test_can_run_concurrent_commands(r): assert all(await asyncio.gather(*(r.ping() for _ in range(10)))) -async def test_connect_retry_on_timeout_error(): +async def test_connect_retry_on_timeout_error(connect_args): """Test that the _connect function is retried in case of a timeout""" - conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 3)) + conn = Connection( + retry_on_timeout=True, retry=Retry(NoBackoff(), 3), **connect_args + ) origin_connect = conn._connect conn._connect = mock.AsyncMock() @@ -200,7 +202,7 @@ async def test_connection_parse_response_resume(r: redis.Redis): [_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser], ids=["AsyncRESP2Parser", "AsyncRESP3Parser", "AsyncHiredisParser"], ) -async def test_connection_disconect_race(parser_class): +async def test_connection_disconect_race(parser_class, connect_args): """ This test reproduces the case in issue #2349 where a connection is closed while the parser is reading to feed the @@ -215,10 +217,9 @@ async def test_connection_disconect_race(parser_class): if parser_class == _AsyncHiredisParser and not HIREDIS_AVAILABLE: pytest.skip("Hiredis not available") - args = {} - args["parser_class"] = parser_class + connect_args["parser_class"] = parser_class - conn = Connection(**args) + conn = Connection(**connect_args) cond = asyncio.Condition() # 0 == initial @@ -267,8 +268,16 @@ async def do_read(): async def open_connection(*args, **kwargs): return reader, writer + async def dummy_method(*args, **kwargs): + pass + + # get dummy stream objects for the connection with patch.object(asyncio, "open_connection", open_connection): - await conn.connect() + # disable the initial version handshake + with patch.multiple( + conn, send_command=dummy_method, read_response=dummy_method + ): + await conn.connect() vals = await asyncio.gather(do_read(), do_close()) assert vals == [b"Hello, World!", None] diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 239927f484..ae194db3a2 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -2425,7 +2425,7 @@ def teardown(): user_client.hset("{cache}:0", "hkey", "hval") assert isinstance(r.acl_log(target_nodes=node), list) - assert len(r.acl_log(target_nodes=node)) == 2 + assert len(r.acl_log(target_nodes=node)) == 3 assert len(r.acl_log(count=1, target_nodes=node)) == 1 assert isinstance(r.acl_log(target_nodes=node)[0], dict) assert "client-info" in r.acl_log(count=1, target_nodes=node)[0] diff --git a/tests/test_commands.py b/tests/test_commands.py index 9540f7f20c..055aa3bf9f 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -396,7 +396,7 @@ def teardown(): user_client.hset("cache:0", "hkey", "hval") assert isinstance(r.acl_log(), list) - assert len(r.acl_log()) == 2 + assert len(r.acl_log()) == 3 assert len(r.acl_log(count=1)) == 1 assert isinstance(r.acl_log()[0], dict) expected = r.acl_log(count=1)[0] @@ -554,6 +554,26 @@ def test_client_setname(self, r): assert r.client_setname("redis_py_test") assert_resp_response(r, r.client_getname(), "redis_py_test", b"redis_py_test") + @skip_if_server_version_lt("7.2.0") + def test_client_setinfo(self, r: redis.Redis): + r.ping() + info = r.client_info() + assert info["lib-name"] == "redis-py" + assert info["lib-ver"] == redis.__version__ + assert r.client_setinfo("lib-name", "test") + assert r.client_setinfo("lib-ver", "123") + info = r.client_info() + assert info["lib-name"] == "test" + assert info["lib-ver"] == "123" + r2 = redis.Redis(lib_name="test2", lib_version="1234") + info = r2.client_info() + assert info["lib-name"] == "test2" + assert info["lib-ver"] == "1234" + r3 = redis.Redis(lib_name=None, lib_version=None) + info = r3.client_info() + assert info["lib-name"] == "" + assert info["lib-ver"] == "" + @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.6.9") def test_client_kill(self, r, r2): @@ -5066,6 +5086,7 @@ def test_shutdown_with_params(self, r: redis.Redis): r.execute_command.assert_called_with("SHUTDOWN", "ABORT") @pytest.mark.replica + @pytest.mark.xfail(strict=False) @skip_if_server_version_lt("2.8.0") @skip_if_redis_enterprise() def test_sync(self, r): diff --git a/tests/test_lock.py b/tests/test_lock.py index b4b9b32917..b34f7f0159 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -101,11 +101,12 @@ def test_blocking_timeout(self, r): assert lock1.acquire(blocking=False) bt = 0.4 sleep = 0.05 + fudge_factor = 0.05 lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt) start = time.monotonic() assert not lock2.acquire() # The elapsed duration should be less than the total blocking_timeout - assert bt > (time.monotonic() - start) > bt - sleep + assert (bt + fudge_factor) > (time.monotonic() - start) > bt - sleep lock1.release() def test_context_manager(self, r): @@ -119,11 +120,12 @@ def test_context_manager_blocking_timeout(self, r): with self.get_lock(r, "foo", blocking=False): bt = 0.4 sleep = 0.05 + fudge_factor = 0.05 lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt) start = time.monotonic() assert not lock2.acquire() # The elapsed duration should be less than the total blocking_timeout - assert bt > (time.monotonic() - start) > bt - sleep + assert (bt + fudge_factor) > (time.monotonic() - start) > bt - sleep def test_context_manager_raises_when_locked_not_acquired(self, r): r.set("foo", "bar")