From c7a13ae3f2529ee3f1bd0b5f991de4d06c33a6db Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Tue, 9 Jan 2024 12:47:00 +0200 Subject: [PATCH] Support client side caching with RedisCluster (#3102) * sync * fix mock_node_resp * fix mock_node_resp_func * fix test_handling_cluster_failover_to_a_replica * fix test_handling_cluster_failover_to_a_replica * async cluster and cleanup tests * delete comment --- redis/asyncio/cluster.py | 46 ++++- redis/asyncio/connection.py | 4 + redis/cluster.py | 29 ++- tests/conftest.py | 2 +- tests/test_asyncio/conftest.py | 3 +- tests/test_asyncio/test_cache.py | 291 ++++++++++++++++------------- tests/test_asyncio/test_cluster.py | 2 + tests/test_cache.py | 286 ++++++++++++++++------------ tests/test_cluster.py | 3 + 9 files changed, 398 insertions(+), 268 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 6a1753ad19..486053e1cc 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -18,6 +18,12 @@ Union, ) +from redis._cache import ( + DEFAULT_BLACKLIST, + DEFAULT_EVICTION_POLICY, + DEFAULT_WHITELIST, + _LocalCache, +) from redis._parsers import AsyncCommandsParser, Encoder from redis._parsers.helpers import ( _RedisCallbacks, @@ -267,6 +273,13 @@ def __init__( ssl_keyfile: Optional[str] = None, protocol: Optional[int] = 2, address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, + cache_enable: bool = False, + client_cache: Optional[_LocalCache] = None, + cache_max_size: int = 100, + cache_ttl: int = 0, + cache_eviction_policy: str = DEFAULT_EVICTION_POLICY, + cache_blacklist: List[str] = DEFAULT_BLACKLIST, + cache_whitelist: List[str] = DEFAULT_WHITELIST, ) -> None: if db: raise RedisClusterException( @@ -310,6 +323,14 @@ def __init__( "socket_timeout": socket_timeout, "retry": retry, "protocol": protocol, + # Client cache related kwargs + "cache_enable": cache_enable, + "client_cache": client_cache, + "cache_max_size": cache_max_size, + "cache_ttl": cache_ttl, + "cache_eviction_policy": cache_eviction_policy, + "cache_blacklist": cache_blacklist, + "cache_whitelist": cache_whitelist, } if ssl: @@ -682,7 +703,6 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: :raises RedisClusterException: if target_nodes is not provided & the command can't be mapped to a slot """ - kwargs.pop("keys", None) # the keys are used only for client side caching command = args[0] target_nodes = [] target_nodes_specified = False @@ -1039,16 +1059,24 @@ async def parse_response( async def execute_command(self, *args: Any, **kwargs: Any) -> Any: # Acquire connection connection = self.acquire_connection() + keys = kwargs.pop("keys", None) - # Execute command - await connection.send_packed_command(connection.pack_command(*args), False) - - # Read response - try: - return await self.parse_response(connection, args[0], **kwargs) - finally: - # Release connection + response_from_cache = await connection._get_from_local_cache(args) + if response_from_cache is not None: self._free.append(connection) + return response_from_cache + else: + # Execute command + await connection.send_packed_command(connection.pack_command(*args), False) + + # Read response + try: + response = await self.parse_response(connection, args[0], **kwargs) + connection._add_to_local_cache(args, response, keys) + return response + finally: + # Release connection + self._free.append(connection) async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: # Acquire connection diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 7f1c0b71e4..05a27879a6 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -227,6 +227,10 @@ def __init__( _cache = None self.client_cache = client_cache if client_cache is not None else _cache if self.client_cache is not None: + if self.protocol not in [3, "3"]: + raise RedisError( + "client caching is only supported with protocol version 3 or higher" + ) self.cache_blacklist = cache_blacklist self.cache_whitelist = cache_whitelist diff --git a/redis/cluster.py b/redis/cluster.py index 8032173e66..7bdf4c1951 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -167,6 +167,13 @@ def parse_cluster_myshardid(resp, **options): "ssl_password", "unix_socket_path", "username", + "cache_enable", + "client_cache", + "cache_max_size", + "cache_ttl", + "cache_eviction_policy", + "cache_blacklist", + "cache_whitelist", ) KWARGS_DISABLED_KEYS = ("host", "port") @@ -1060,7 +1067,6 @@ def execute_command(self, *args, **kwargs): list dict """ - kwargs.pop("keys", None) # the keys are used only for client side caching target_nodes_specified = False is_default_node = False target_nodes = None @@ -1119,6 +1125,7 @@ def _execute_command(self, target_node, *args, **kwargs): """ Send a command to a node in the cluster """ + keys = kwargs.pop("keys", None) command = args[0] redis_node = None connection = None @@ -1147,14 +1154,18 @@ def _execute_command(self, target_node, *args, **kwargs): connection.send_command("ASKING") redis_node.parse_response(connection, "ASKING", **kwargs) asking = False - - connection.send_command(*args) - response = redis_node.parse_response(connection, command, **kwargs) - if command in self.cluster_response_callbacks: - response = self.cluster_response_callbacks[command]( - response, **kwargs - ) - return response + response_from_cache = connection._get_from_local_cache(args) + if response_from_cache is not None: + return response_from_cache + else: + connection.send_command(*args) + response = redis_node.parse_response(connection, command, **kwargs) + if command in self.cluster_response_callbacks: + response = self.cluster_response_callbacks[command]( + response, **kwargs + ) + connection._add_to_local_cache(args, response, keys) + return response except AuthenticationError: raise except (ConnectionError, TimeoutError) as e: diff --git a/tests/conftest.py b/tests/conftest.py index e56b5f6aed..8786e2b9f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -275,7 +275,7 @@ def _get_client( redis_url = request.config.getoption("--redis-url") else: redis_url = from_url - if "protocol" not in redis_url: + if "protocol" not in redis_url and kwargs.get("protocol") is None: kwargs["protocol"] = request.config.getoption("--protocol") cluster_mode = REDIS_INFO["cluster_enabled"] diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index c79b706abc..c6afec5af6 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -69,10 +69,9 @@ async def client_factory( url: str = request.config.getoption("--redis-url"), cls=redis.Redis, flushdb=True, - protocol=request.config.getoption("--protocol"), **kwargs, ): - if "protocol" not in url: + if "protocol" not in url and kwargs.get("protocol") is None: kwargs["protocol"] = request.config.getoption("--protocol") cluster_mode = REDIS_INFO["cluster_enabled"] diff --git a/tests/test_asyncio/test_cache.py b/tests/test_asyncio/test_cache.py index 92328b8391..098ede8d75 100644 --- a/tests/test_asyncio/test_cache.py +++ b/tests/test_asyncio/test_cache.py @@ -1,140 +1,177 @@ import time import pytest -import redis.asyncio as redis +import pytest_asyncio from redis._cache import _LocalCache from redis.utils import HIREDIS_AVAILABLE -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -async def test_get_from_cache(): - cache = _LocalCache() - r = redis.Redis(protocol=3, client_cache=cache) - r2 = redis.Redis(protocol=3) - # add key to redis - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - await r2.set("foo", "barbar") - # send any command to redis (process invalidation in background) - await r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert await r.get("foo") == b"barbar" - - await r.flushdb() - await r.aclose() - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -async def test_cache_max_size(): - cache = _LocalCache(max_size=3) - r = redis.Redis(client_cache=cache, protocol=3) - # add 3 keys to redis - await r.set("foo", "bar") - await r.set("foo2", "bar2") - await r.set("foo3", "bar3") - # get 3 keys from redis and save in local cache - assert await r.get("foo") == b"bar" - assert await r.get("foo2") == b"bar2" - assert await r.get("foo3") == b"bar3" - # get the 3 keys from local cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) == b"bar2" - assert cache.get(("GET", "foo3")) == b"bar3" - # add 1 more key to redis (exceed the max size) - await r.set("foo4", "bar4") - assert await r.get("foo4") == b"bar4" - # the first key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - - await r.flushdb() - await r.aclose() - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -async def test_cache_ttl(): - cache = _LocalCache(ttl=1) - r = redis.Redis(client_cache=cache, protocol=3) - # add key to redis - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # wait for the key to expire - time.sleep(1) - # the key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - - await r.flushdb() - await r.aclose() - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -async def test_cache_lfu_eviction(): - cache = _LocalCache(max_size=3, eviction_policy="lfu") - r = redis.Redis(client_cache=cache, protocol=3) - # add 3 keys to redis - await r.set("foo", "bar") - await r.set("foo2", "bar2") - await r.set("foo3", "bar3") - # get 3 keys from redis and save in local cache - assert await r.get("foo") == b"bar" - assert await r.get("foo2") == b"bar2" - assert await r.get("foo3") == b"bar3" - # change the order of the keys in the cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo3")) == b"bar3" - # add 1 more key to redis (exceed the max size) - await r.set("foo4", "bar4") - assert await r.get("foo4") == b"bar4" - # test the eviction policy - assert len(cache.cache) == 3 - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) is None - - await r.flushdb() - await r.aclose() +@pytest_asyncio.fixture +async def r(request, create_redis): + cache = request.param.get("cache") + kwargs = request.param.get("kwargs", {}) + r = await create_redis(protocol=3, client_cache=cache, **kwargs) + yield r, cache @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -async def test_cache_decode_response(): - cache = _LocalCache() - r = redis.Redis(decode_responses=True, client_cache=cache, protocol=3) - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == "bar" - # get key from local cache - assert cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - await r.set("foo", "barbar") - # send any command to redis (process invalidation in background) - await r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert await r.get("foo") == "barbar" - - await r.flushdb() - await r.aclose() +class TestLocalCache: + @pytest.mark.onlynoncluster + @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) + async def test_get_from_cache(self, r, r2): + r, cache = r + # add key to redis + await r.set("foo", "bar") + # get key from redis and save in local cache + assert await r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + await r2.set("foo", "barbar") + # send any command to redis (process invalidation in background) + await r.ping() + # the command is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + # get key from redis + assert await r.get("foo") == b"barbar" + + @pytest.mark.parametrize("r", [{"cache": _LocalCache(max_size=3)}], indirect=True) + async def test_cache_max_size(self, r): + r, cache = r + # add 3 keys to redis + await r.set("foo", "bar") + await r.set("foo2", "bar2") + await r.set("foo3", "bar3") + # get 3 keys from redis and save in local cache + assert await r.get("foo") == b"bar" + assert await r.get("foo2") == b"bar2" + assert await r.get("foo3") == b"bar3" + # get the 3 keys from local cache + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo2")) == b"bar2" + assert cache.get(("GET", "foo3")) == b"bar3" + # add 1 more key to redis (exceed the max size) + await r.set("foo4", "bar4") + assert await r.get("foo4") == b"bar4" + # the first key is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + + @pytest.mark.parametrize("r", [{"cache": _LocalCache(ttl=1)}], indirect=True) + async def test_cache_ttl(self, r): + r, cache = r + # add key to redis + await r.set("foo", "bar") + # get key from redis and save in local cache + assert await r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # wait for the key to expire + time.sleep(1) + # the key is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + + @pytest.mark.parametrize( + "r", [{"cache": _LocalCache(max_size=3, eviction_policy="lfu")}], indirect=True + ) + async def test_cache_lfu_eviction(self, r): + r, cache = r + # add 3 keys to redis + await r.set("foo", "bar") + await r.set("foo2", "bar2") + await r.set("foo3", "bar3") + # get 3 keys from redis and save in local cache + assert await r.get("foo") == b"bar" + assert await r.get("foo2") == b"bar2" + assert await r.get("foo3") == b"bar3" + # change the order of the keys in the cache + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo3")) == b"bar3" + # add 1 more key to redis (exceed the max size) + await r.set("foo4", "bar4") + assert await r.get("foo4") == b"bar4" + # test the eviction policy + assert len(cache.cache) == 3 + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo2")) is None + + @pytest.mark.onlynoncluster + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + async def test_cache_decode_response(self, r): + r, cache = r + await r.set("foo", "bar") + # get key from redis and save in local cache + assert await r.get("foo") == "bar" + # get key from local cache + assert cache.get(("GET", "foo")) == "bar" + # change key in redis (cause invalidation) + await r.set("foo", "barbar") + # send any command to redis (process invalidation in background) + await r.ping() + # the command is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + # get key from redis + assert await r.get("foo") == "barbar" + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"cache_blacklist": ["LLEN"]}}], + indirect=True, + ) + async def test_cache_blacklist(self, r): + r, cache = r + # add list to redis + await r.lpush("mylist", "foo", "bar", "baz") + assert await r.llen("mylist") == 3 + assert await r.lindex("mylist", 1) == b"bar" + assert cache.get(("LLEN", "mylist")) is None + assert cache.get(("LINDEX", "mylist", 1)) == b"bar" @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -async def test_cache_blacklist(): - cache = _LocalCache() - r = redis.Redis(client_cache=cache, cache_blacklist=["LLEN"], protocol=3) - # add list to redis - await r.lpush("mylist", "foo", "bar", "baz") - assert await r.llen("mylist") == 3 - assert await r.lindex("mylist", 1) == b"bar" - assert cache.get(("LLEN", "mylist")) is None - assert cache.get(("LINDEX", "mylist", 1)) == b"bar" - - await r.flushdb() - await r.aclose() +@pytest.mark.onlycluster +class TestClusterLocalCache: + @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) + async def test_get_from_cache(self, r, r2): + r, cache = r + # add key to redis + await r.set("foo", "bar") + # get key from redis and save in local cache + assert await r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + await r2.set("foo", "barbar") + # send any command to redis (process invalidation in background) + node = r.get_node_from_key("foo") + await r.ping(target_nodes=node) + # the command is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + # get key from redis + assert await r.get("foo") == b"barbar" + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + async def test_cache_decode_response(self, r): + r, cache = r + await r.set("foo", "bar") + # get key from redis and save in local cache + assert await r.get("foo") == "bar" + # get key from local cache + assert cache.get(("GET", "foo")) == "bar" + # change key in redis (cause invalidation) + await r.set("foo", "barbar") + # send any command to redis (process invalidation in background) + node = r.get_node_from_key("foo") + await r.ping(target_nodes=node) + # the command is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + # get key from redis + assert await r.get("foo") == "barbar" diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index e6cf2e4ce7..a57d32f5d2 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -178,6 +178,7 @@ def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode: connection = mock.AsyncMock(spec=Connection) connection.is_connected = True connection.read_response.return_value = response + connection._get_from_local_cache.return_value = None while node._free: node._free.pop() node._free.append(connection) @@ -188,6 +189,7 @@ def mock_node_resp_exc(node: ClusterNode, exc: Exception) -> ClusterNode: connection = mock.AsyncMock(spec=Connection) connection.is_connected = True connection.read_response.side_effect = exc + connection._get_from_local_cache.return_value = None while node._free: node._free.pop() node._free.append(connection) diff --git a/tests/test_cache.py b/tests/test_cache.py index 85df8b1a22..570385a4b5 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -4,131 +4,177 @@ import redis from redis._cache import _LocalCache from redis.utils import HIREDIS_AVAILABLE +from tests.conftest import _get_client -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -def test_get_from_cache(): - cache = _LocalCache() - r = redis.Redis(protocol=3, client_cache=cache) - r2 = redis.Redis(protocol=3) - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - r2.set("foo", "barbar") - # send any command to redis (process invalidation in background) - r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == b"barbar" - - r.flushdb() - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -def test_cache_max_size(): - cache = _LocalCache(max_size=3) - r = redis.Redis(client_cache=cache, protocol=3) - # add 3 keys to redis - r.set("foo", "bar") - r.set("foo2", "bar2") - r.set("foo3", "bar3") - # get 3 keys from redis and save in local cache - assert r.get("foo") == b"bar" - assert r.get("foo2") == b"bar2" - assert r.get("foo3") == b"bar3" - # get the 3 keys from local cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) == b"bar2" - assert cache.get(("GET", "foo3")) == b"bar3" - # add 1 more key to redis (exceed the max size) - r.set("foo4", "bar4") - assert r.get("foo4") == b"bar4" - # the first key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - - r.flushdb() - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -def test_cache_ttl(): - cache = _LocalCache(ttl=1) - r = redis.Redis(client_cache=cache, protocol=3) - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # wait for the key to expire - time.sleep(1) - # the key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - - r.flushdb() - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -def test_cache_lfu_eviction(): - cache = _LocalCache(max_size=3, eviction_policy="lfu") - r = redis.Redis(client_cache=cache, protocol=3) - # add 3 keys to redis - r.set("foo", "bar") - r.set("foo2", "bar2") - r.set("foo3", "bar3") - # get 3 keys from redis and save in local cache - assert r.get("foo") == b"bar" - assert r.get("foo2") == b"bar2" - assert r.get("foo3") == b"bar3" - # change the order of the keys in the cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo3")) == b"bar3" - # add 1 more key to redis (exceed the max size) - r.set("foo4", "bar4") - assert r.get("foo4") == b"bar4" - # test the eviction policy - assert len(cache.cache) == 3 - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) is None - - r.flushdb() +@pytest.fixture() +def r(request): + cache = request.param.get("cache") + kwargs = request.param.get("kwargs", {}) + with _get_client( + redis.Redis, request, protocol=3, client_cache=cache, **kwargs + ) as client: + yield client, cache @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -def test_cache_decode_response(): - cache = _LocalCache() - r = redis.Redis(decode_responses=True, client_cache=cache, protocol=3) - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == "bar" - # get key from local cache - assert cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - r.set("foo", "barbar") - # send any command to redis (process invalidation in background) - r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == "barbar" - - r.flushdb() +class TestLocalCache: + @pytest.mark.onlynoncluster + @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) + def test_get_from_cache(self, r, r2): + r, cache = r + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # send any command to redis (process invalidation in background) + r.ping() + # the command is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + # get key from redis + assert r.get("foo") == b"barbar" + + @pytest.mark.parametrize("r", [{"cache": _LocalCache(max_size=3)}], indirect=True) + def test_cache_max_size(self, r): + r, cache = r + # add 3 keys to redis + r.set("foo", "bar") + r.set("foo2", "bar2") + r.set("foo3", "bar3") + # get 3 keys from redis and save in local cache + assert r.get("foo") == b"bar" + assert r.get("foo2") == b"bar2" + assert r.get("foo3") == b"bar3" + # get the 3 keys from local cache + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo2")) == b"bar2" + assert cache.get(("GET", "foo3")) == b"bar3" + # add 1 more key to redis (exceed the max size) + r.set("foo4", "bar4") + assert r.get("foo4") == b"bar4" + # the first key is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + + @pytest.mark.parametrize("r", [{"cache": _LocalCache(ttl=1)}], indirect=True) + def test_cache_ttl(self, r): + r, cache = r + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # wait for the key to expire + time.sleep(1) + # the key is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + + @pytest.mark.parametrize( + "r", [{"cache": _LocalCache(max_size=3, eviction_policy="lfu")}], indirect=True + ) + def test_cache_lfu_eviction(self, r): + r, cache = r + # add 3 keys to redis + r.set("foo", "bar") + r.set("foo2", "bar2") + r.set("foo3", "bar3") + # get 3 keys from redis and save in local cache + assert r.get("foo") == b"bar" + assert r.get("foo2") == b"bar2" + assert r.get("foo3") == b"bar3" + # change the order of the keys in the cache + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo3")) == b"bar3" + # add 1 more key to redis (exceed the max size) + r.set("foo4", "bar4") + assert r.get("foo4") == b"bar4" + # test the eviction policy + assert len(cache.cache) == 3 + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo2")) is None + + @pytest.mark.onlynoncluster + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + def test_cache_decode_response(self, r): + r, cache = r + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == "bar" + # get key from local cache + assert cache.get(("GET", "foo")) == "bar" + # change key in redis (cause invalidation) + r.set("foo", "barbar") + # send any command to redis (process invalidation in background) + r.ping() + # the command is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + # get key from redis + assert r.get("foo") == "barbar" + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"cache_blacklist": ["LLEN"]}}], + indirect=True, + ) + def test_cache_blacklist(self, r): + r, cache = r + # add list to redis + r.lpush("mylist", "foo", "bar", "baz") + assert r.llen("mylist") == 3 + assert r.lindex("mylist", 1) == b"bar" + assert cache.get(("LLEN", "mylist")) is None + assert cache.get(("LINDEX", "mylist", 1)) == b"bar" @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -def test_cache_blacklist(): - cache = _LocalCache() - r = redis.Redis(client_cache=cache, cache_blacklist=["LLEN"], protocol=3) - # add list to redis - r.lpush("mylist", "foo", "bar", "baz") - assert r.llen("mylist") == 3 - assert r.lindex("mylist", 1) == b"bar" - assert cache.get(("LLEN", "mylist")) is None - assert cache.get(("LINDEX", "mylist", 1)) == b"bar" - - r.flushdb() +@pytest.mark.onlycluster +class TestClusterLocalCache: + @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) + def test_get_from_cache(self, r, r2): + r, cache = r + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # send any command to redis (process invalidation in background) + node = r.get_node_from_key("foo") + r.ping(target_nodes=node) + # the command is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + # get key from redis + assert r.get("foo") == b"barbar" + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + def test_cache_decode_response(self, r): + r, cache = r + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == "bar" + # get key from local cache + assert cache.get(("GET", "foo")) == "bar" + # change key in redis (cause invalidation) + r.set("foo", "barbar") + # send any command to redis (process invalidation in background) + node = r.get_node_from_key("foo") + r.ping(target_nodes=node) + # the command is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + # get key from redis + assert r.get("foo") == "barbar" diff --git a/tests/test_cluster.py b/tests/test_cluster.py index ae194db3a2..854b64c563 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -202,6 +202,7 @@ def cmd_init_mock(self, r): def mock_node_resp(node, response): connection = Mock() connection.read_response.return_value = response + connection._get_from_local_cache.return_value = None node.redis_connection.connection = connection return node @@ -209,6 +210,7 @@ def mock_node_resp(node, response): def mock_node_resp_func(node, func): connection = Mock() connection.read_response.side_effect = func + connection._get_from_local_cache.return_value = None node.redis_connection.connection = connection return node @@ -477,6 +479,7 @@ def mock_execute_command(*_args, **_kwargs): redis_mock_node.execute_command.side_effect = mock_execute_command # Mock response value for all other commands redis_mock_node.parse_response.return_value = "MOCK_OK" + redis_mock_node.connection._get_from_local_cache.return_value = None for node in r.get_nodes(): if node.port != primary.port: node.redis_connection = redis_mock_node