diff --git a/redis/cache.py b/redis/_cache.py similarity index 98% rename from redis/cache.py rename to redis/_cache.py index d920702339..4255afb7a4 100644 --- a/redis/cache.py +++ b/redis/_cache.py @@ -178,7 +178,11 @@ class _LocalCache: """ def __init__( - self, max_size: int, ttl: int, eviction_policy: EvictionPolicy, **kwargs + self, + max_size: int = 100, + ttl: int = 0, + eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY, + **kwargs, ): self.max_size = max_size self.ttl = ttl diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 79689fcb5e..143d997757 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -25,6 +25,12 @@ cast, ) +from redis._cache import ( + DEFAULT_BLACKLIST, + DEFAULT_EVICTION_POLICY, + DEFAULT_WHITELIST, + _LocalCache, +) from redis._parsers.helpers import ( _RedisCallbacks, _RedisCallbacksRESP2, @@ -39,12 +45,6 @@ ) from redis.asyncio.lock import Lock from redis.asyncio.retry import Retry -from redis.cache import ( - DEFAULT_BLACKLIST, - DEFAULT_EVICTION_POLICY, - DEFAULT_WHITELIST, - _LocalCache, -) from redis.client import ( EMPTY_RESPONSE, NEVER_DECODE, @@ -67,7 +67,7 @@ TimeoutError, WatchError, ) -from redis.typing import ChannelT, EncodableT, KeysT, KeyT, ResponseT +from redis.typing import ChannelT, EncodableT, KeyT from redis.utils import ( HIREDIS_AVAILABLE, _set_info_logger, @@ -294,6 +294,13 @@ def __init__( "lib_version": lib_version, "redis_connect_func": redis_connect_func, "protocol": protocol, + "cache_enable": cache_enable, + "client_cache": client_cache, + "cache_max_size": cache_max_size, + "cache_ttl": cache_ttl, + "cache_eviction_policy": cache_eviction_policy, + "cache_blacklist": cache_blacklist, + "cache_whitelist": cache_whitelist, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -350,16 +357,6 @@ def __init__( # on a set of redis commands self._single_conn_lock = asyncio.Lock() - self.client_cache = client_cache - if cache_enable: - 
self.client_cache = _LocalCache( - cache_max_size, cache_ttl, cache_eviction_policy - ) - if self.client_cache is not None: - self.cache_blacklist = cache_blacklist - self.cache_whitelist = cache_whitelist - self.client_cache_initialized = False - def __repr__(self): return ( f"<{self.__class__.__module__}.{self.__class__.__name__}" @@ -374,10 +371,6 @@ async def initialize(self: _RedisT) -> _RedisT: async with self._single_conn_lock: if self.connection is None: self.connection = await self.connection_pool.get_connection("_") - if self.client_cache is not None: - self.connection._parser.set_invalidation_push_handler( - self._cache_invalidation_process - ) return self def set_response_callback(self, command: str, callback: ResponseCallbackT): @@ -596,8 +589,6 @@ async def aclose(self, close_connection_pool: Optional[bool] = None) -> None: close_connection_pool is None and self.auto_close_connection_pool ): await self.connection_pool.disconnect() - if self.client_cache: - self.client_cache.flush() @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close") async def close(self, close_connection_pool: Optional[bool] = None) -> None: @@ -626,89 +617,28 @@ async def _disconnect_raise(self, conn: Connection, error: Exception): ): raise error - def _cache_invalidation_process( - self, data: List[Union[str, Optional[List[str]]]] - ) -> None: - """ - Invalidate (delete) all redis commands associated with a specific key. - `data` is a list of strings, where the first string is the invalidation message - and the second string is the list of keys to invalidate. 
- (if the list of keys is None, then all keys are invalidated) - """ - if data[1] is not None: - for key in data[1]: - self.client_cache.invalidate(str_if_bytes(key)) - else: - self.client_cache.flush() - - async def _get_from_local_cache(self, command: str): - """ - If the command is in the local cache, return the response - """ - if ( - self.client_cache is None - or command[0] in self.cache_blacklist - or command[0] not in self.cache_whitelist - ): - return None - while not self.connection._is_socket_empty(): - await self.connection.read_response(push_request=True) - return self.client_cache.get(command) - - def _add_to_local_cache( - self, command: Tuple[str], response: ResponseT, keys: List[KeysT] - ): - """ - Add the command and response to the local cache if the command - is allowed to be cached - """ - if ( - self.client_cache is not None - and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist) - and (self.cache_whitelist == [] or command[0] in self.cache_whitelist) - ): - self.client_cache.set(command, response, keys) - - def delete_from_local_cache(self, command: str): - """ - Delete the command from the local cache - """ - try: - self.client_cache.delete(command) - except AttributeError: - pass - # COMMAND EXECUTION AND PROTOCOL PARSING async def execute_command(self, *args, **options): """Execute a command and return a parsed response""" await self.initialize() command_name = args[0] keys = options.pop("keys", None) # keys are used only for client side caching - response_from_cache = await self._get_from_local_cache(args) + pool = self.connection_pool + conn = self.connection or await pool.get_connection(command_name, **options) + response_from_cache = await conn._get_from_local_cache(args) if response_from_cache is not None: return response_from_cache else: - pool = self.connection_pool - conn = self.connection or await pool.get_connection(command_name, **options) - if self.single_connection_client: await 
self._single_conn_lock.acquire() try: - if self.client_cache is not None and not self.client_cache_initialized: - await conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, "CLIENT", *("CLIENT", "TRACKING", "ON") - ), - lambda error: self._disconnect_raise(conn, error), - ) - self.client_cache_initialized = True response = await conn.retry.call_with_retry( lambda: self._send_command_parse_response( conn, command_name, *args, **options ), lambda error: self._disconnect_raise(conn, error), ) - self._add_to_local_cache(args, response, keys) + conn._add_to_local_cache(args, response, keys) return response finally: if self.single_connection_client: diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index df2bd20f9f..7f1c0b71e4 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -47,9 +47,15 @@ ResponseError, TimeoutError, ) -from redis.typing import EncodableT +from redis.typing import EncodableT, KeysT, ResponseT from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes +from .._cache import ( + DEFAULT_BLACKLIST, + DEFAULT_EVICTION_POLICY, + DEFAULT_WHITELIST, + _LocalCache, +) from .._parsers import ( BaseParser, Encoder, @@ -114,6 +120,9 @@ class AbstractConnection: "encoder", "ssl_context", "protocol", + "client_cache", + "cache_blacklist", + "cache_whitelist", "_reader", "_writer", "_parser", @@ -148,6 +157,13 @@ def __init__( encoder_class: Type[Encoder] = Encoder, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, + cache_enable: bool = False, + client_cache: Optional[_LocalCache] = None, + cache_max_size: int = 100, + cache_ttl: int = 0, + cache_eviction_policy: str = DEFAULT_EVICTION_POLICY, + cache_blacklist: List[str] = DEFAULT_BLACKLIST, + cache_whitelist: List[str] = DEFAULT_WHITELIST, ): if (username or password) and credential_provider is not None: raise DataError( @@ -205,6 +221,14 @@ def __init__( if p < 2 or p > 3: raise 
ConnectionError("protocol must be either 2 or 3") self.protocol = protocol + if cache_enable: + _cache = _LocalCache(cache_max_size, cache_ttl, cache_eviction_policy) + else: + _cache = None + self.client_cache = client_cache if client_cache is not None else _cache + if self.client_cache is not None: + self.cache_blacklist = cache_blacklist + self.cache_whitelist = cache_whitelist def __del__(self, _warnings: Any = warnings): # For some reason, the individual streams don't get properly garbage @@ -395,6 +419,11 @@ async def on_connect(self) -> None: # if a database is specified, switch to it. Also pipeline this if self.db: await self.send_command("SELECT", self.db) + # if client caching is enabled, start tracking + if self.client_cache: + await self.send_command("CLIENT", "TRACKING", "ON") + await self.read_response() + self._parser.set_invalidation_push_handler(self._cache_invalidation_process) # read responses from pipeline for _ in (sent for sent in (self.lib_name, self.lib_version) if sent): @@ -429,6 +458,9 @@ async def disconnect(self, nowait: bool = False) -> None: raise TimeoutError( f"Timed out closing connection after {self.socket_connect_timeout}" ) from None + finally: + if self.client_cache: + self.client_cache.flush() async def _send_ping(self): """Send PING, expect PONG in return""" @@ -646,10 +678,62 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes] output.append(SYM_EMPTY.join(pieces)) return output - def _is_socket_empty(self): + def _socket_is_empty(self): """Check if the socket is empty""" return not self._reader.at_eof() + def _cache_invalidation_process( + self, data: List[Union[str, Optional[List[str]]]] + ) -> None: + """ + Invalidate (delete) all redis commands associated with a specific key. + `data` is a list of strings, where the first string is the invalidation message + and the second string is the list of keys to invalidate. 
+        (if the list of keys is None, then all keys are invalidated)
+        """
+        if data[1] is None:
+            self.client_cache.flush()
+        else:
+            for key in data[1]:
+                self.client_cache.invalidate(str_if_bytes(key))
+
+    async def _get_from_local_cache(self, command: str):
+        """
+        If the command is in the local cache, return the response
+        """
+        if (
+            self.client_cache is None
+            or command[0] in self.cache_blacklist
+            or command[0] not in self.cache_whitelist
+        ):
+            return None
+        while not self._socket_is_empty():
+            await self.read_response(push_request=True)
+        return self.client_cache.get(command)
+
+    def _add_to_local_cache(
+        self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
+    ):
+        """
+        Add the command and response to the local cache if the command
+        is allowed to be cached
+        """
+        if (
+            self.client_cache is not None
+            and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist)
+            and (self.cache_whitelist == [] or command[0] in self.cache_whitelist)
+        ):
+            self.client_cache.set(command, response, keys)
+
+    def delete_from_local_cache(self, command: str):
+        """
+        Delete the command from the local cache
+        """
+        try:
+            self.client_cache.delete(command)
+        except AttributeError:
+            pass
+

 class Connection(AbstractConnection):
     "Manages TCP communication to and from a Redis server"
diff --git a/redis/client.py b/redis/client.py
index 7f2c8d290d..d685145339 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -4,8 +4,14 @@
 import time
 import warnings
 from itertools import chain
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Type, Union
 
+from redis._cache import (
+    DEFAULT_BLACKLIST,
+    DEFAULT_EVICTION_POLICY,
+    DEFAULT_WHITELIST,
+    _LocalCache,
+)
 from redis._parsers.encoders import Encoder
 from redis._parsers.helpers import (
     _RedisCallbacks,
@@ -13,12 +19,6 @@
     _RedisCallbacksRESP3,
     bool_ok,
 )
-from redis.cache import (
-    DEFAULT_BLACKLIST,
-    
DEFAULT_EVICTION_POLICY, - DEFAULT_WHITELIST, - _LocalCache, -) from redis.commands import ( CoreCommands, RedisModuleCommands, @@ -38,7 +38,6 @@ ) from redis.lock import Lock from redis.retry import Retry -from redis.typing import KeysT, ResponseT from redis.utils import ( HIREDIS_AVAILABLE, _set_info_logger, @@ -268,6 +267,13 @@ def __init__( "redis_connect_func": redis_connect_func, "credential_provider": credential_provider, "protocol": protocol, + "cache_enable": cache_enable, + "client_cache": client_cache, + "cache_max_size": cache_max_size, + "cache_ttl": cache_ttl, + "cache_eviction_policy": cache_eviction_policy, + "cache_blacklist": cache_blacklist, + "cache_whitelist": cache_whitelist, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -324,19 +330,6 @@ def __init__( else: self.response_callbacks.update(_RedisCallbacksRESP2) - self.client_cache = client_cache - if cache_enable: - self.client_cache = _LocalCache( - cache_max_size, cache_ttl, cache_eviction_policy - ) - if self.client_cache is not None: - self.cache_blacklist = cache_blacklist - self.cache_whitelist = cache_whitelist - self.client_tracking_on() - self.connection._parser.set_invalidation_push_handler( - self._cache_invalidation_process - ) - def __repr__(self) -> str: return ( f"<{type(self).__module__}.{type(self).__name__}" @@ -362,21 +355,6 @@ def set_response_callback(self, command: str, callback: Callable) -> None: """Set a custom Response Callback""" self.response_callbacks[command] = callback - def _cache_invalidation_process( - self, data: List[Union[str, Optional[List[str]]]] - ) -> None: - """ - Invalidate (delete) all redis commands associated with a specific key. - `data` is a list of strings, where the first string is the invalidation message - and the second string is the list of keys to invalidate. 
- (if the list of keys is None, then all keys are invalidated) - """ - if data[1] is not None: - for key in data[1]: - self.client_cache.invalidate(str_if_bytes(key)) - else: - self.client_cache.flush() - def load_external_module(self, funcname, func) -> None: """ This function can be used to add externally defined redis modules, @@ -549,8 +527,6 @@ def close(self): if self.auto_close_connection_pool: self.connection_pool.disconnect() - if self.client_cache: - self.client_cache.flush() def _send_command_parse_response(self, conn, command_name, *args, **options): """ @@ -572,55 +548,17 @@ def _disconnect_raise(self, conn, error): ): raise error - def _get_from_local_cache(self, command: str): - """ - If the command is in the local cache, return the response - """ - if ( - self.client_cache is None - or command[0] in self.cache_blacklist - or command[0] not in self.cache_whitelist - ): - return None - while not self.connection._is_socket_empty(): - self.connection.read_response(push_request=True) - return self.client_cache.get(command) - - def _add_to_local_cache( - self, command: Tuple[str], response: ResponseT, keys: List[KeysT] - ): - """ - Add the command and response to the local cache if the command - is allowed to be cached - """ - if ( - self.client_cache is not None - and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist) - and (self.cache_whitelist == [] or command[0] in self.cache_whitelist) - ): - self.client_cache.set(command, response, keys) - - def delete_from_local_cache(self, command: str): - """ - Delete the command from the local cache - """ - try: - self.client_cache.delete(command) - except AttributeError: - pass - # COMMAND EXECUTION AND PROTOCOL PARSING def execute_command(self, *args, **options): """Execute a command and return a parsed response""" command_name = args[0] keys = options.pop("keys", None) - response_from_cache = self._get_from_local_cache(args) + pool = self.connection_pool + conn = self.connection or 
pool.get_connection(command_name, **options) + response_from_cache = conn._get_from_local_cache(args) if response_from_cache is not None: return response_from_cache else: - pool = self.connection_pool - conn = self.connection or pool.get_connection(command_name, **options) - try: response = conn.retry.call_with_retry( lambda: self._send_command_parse_response( @@ -628,7 +566,7 @@ def execute_command(self, *args, **options): ), lambda error: self._disconnect_raise(conn, error), ) - self._add_to_local_cache(args, response, keys) + conn._add_to_local_cache(args, response, keys) return response finally: if not self.connection: diff --git a/redis/connection.py b/redis/connection.py index 35a4ff4a37..a09fb3949c 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -10,9 +10,15 @@ from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Any, Callable, List, Optional, Type, Union +from typing import Any, Callable, List, Optional, Tuple, Type, Union from urllib.parse import parse_qs, unquote, urlparse +from ._cache import ( + DEFAULT_BLACKLIST, + DEFAULT_EVICTION_POLICY, + DEFAULT_WHITELIST, + _LocalCache, +) from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .backoff import NoBackoff from .credentials import CredentialProvider, UsernamePasswordCredentialProvider @@ -27,6 +33,7 @@ TimeoutError, ) from .retry import Retry +from .typing import KeysT, ResponseT from .utils import ( CRYPTOGRAPHY_AVAILABLE, HIREDIS_AVAILABLE, @@ -150,6 +157,13 @@ def __init__( credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, command_packer: Optional[Callable[[], None]] = None, + cache_enable: bool = False, + client_cache: Optional[_LocalCache] = None, + cache_max_size: int = 100, + cache_ttl: int = 0, + cache_eviction_policy: str = DEFAULT_EVICTION_POLICY, + cache_blacklist: List[str] = DEFAULT_BLACKLIST, + cache_whitelist: List[str] = DEFAULT_WHITELIST, ): """ 
Initialize a new Connection. @@ -215,6 +229,18 @@ def __init__( # p = DEFAULT_RESP_VERSION self.protocol = p self._command_packer = self._construct_command_packer(command_packer) + if cache_enable: + _cache = _LocalCache(cache_max_size, cache_ttl, cache_eviction_policy) + else: + _cache = None + self.client_cache = client_cache if client_cache is not None else _cache + if self.client_cache is not None: + if self.protocol not in [3, "3"]: + raise RedisError( + "client caching is only supported with protocol version 3 or higher" + ) + self.cache_blacklist = cache_blacklist + self.cache_whitelist = cache_whitelist def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) @@ -406,6 +432,12 @@ def on_connect(self): if str_if_bytes(self.read_response()) != "OK": raise ConnectionError("Invalid Database") + # if client caching is enabled, start tracking + if self.client_cache: + self.send_command("CLIENT", "TRACKING", "ON") + self.read_response() + self._parser.set_invalidation_push_handler(self._cache_invalidation_process) + def disconnect(self, *args): "Disconnects from the Redis server" self._parser.on_disconnect() @@ -426,6 +458,9 @@ def disconnect(self, *args): except OSError: pass + if self.client_cache: + self.client_cache.flush() + def _send_ping(self): """Send PING, expect PONG in return""" self.send_command("PING", check_health=False) @@ -573,11 +608,63 @@ def pack_commands(self, commands): output.append(SYM_EMPTY.join(pieces)) return output - def _is_socket_empty(self): + def _socket_is_empty(self): """Check if the socket is empty""" r, _, _ = select.select([self._sock], [], [], 0) return not bool(r) + def _cache_invalidation_process( + self, data: List[Union[str, Optional[List[str]]]] + ) -> None: + """ + Invalidate (delete) all redis commands associated with a specific key. + `data` is a list of strings, where the first string is the invalidation message + and the second string is the list of keys to invalidate. 
+ (if the list of keys is None, then all keys are invalidated) + """ + if data[1] is None: + self.client_cache.flush() + else: + for key in data[1]: + self.client_cache.invalidate(str_if_bytes(key)) + + def _get_from_local_cache(self, command: str): + """ + If the command is in the local cache, return the response + """ + if ( + self.client_cache is None + or command[0] in self.cache_blacklist + or command[0] not in self.cache_whitelist + ): + return None + while not self._socket_is_empty(): + self.read_response(push_request=True) + return self.client_cache.get(command) + + def _add_to_local_cache( + self, command: Tuple[str], response: ResponseT, keys: List[KeysT] + ): + """ + Add the command and response to the local cache if the command + is allowed to be cached + """ + if ( + self.client_cache is not None + and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist) + and (self.cache_whitelist == [] or command[0] in self.cache_whitelist) + ): + self.client_cache.set(command, response, keys) + + def delete_from_local_cache(self, command: str): + """ + Delete the command from the local cache + """ + try: + self.client_cache.delete(command) + except AttributeError: + pass + class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" diff --git a/tests/conftest.py b/tests/conftest.py index bad9f43e42..e56b5f6aed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -356,6 +356,7 @@ def _gen_cluster_mock_resp(r, response): connection = Mock(spec=Connection) connection.retry = Retry(NoBackoff(), 0) connection.read_response.return_value = response + connection._get_from_local_cache.return_value = None with mock.patch.object(r, "connection", connection): yield r diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 5d9e0b4f2e..c79b706abc 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -141,6 +141,7 @@ def _gen_cluster_mock_resp(r, response): connection = 
mock.AsyncMock(spec=Connection) connection.retry = Retry(NoBackoff(), 0) connection.read_response.return_value = response + connection._get_from_local_cache.return_value = None with mock.patch.object(r, "connection", connection): yield r diff --git a/tests/test_asyncio/test_cache.py b/tests/test_asyncio/test_cache.py index c837acfed1..92328b8391 100644 --- a/tests/test_asyncio/test_cache.py +++ b/tests/test_asyncio/test_cache.py @@ -2,36 +2,38 @@ import pytest import redis.asyncio as redis +from redis._cache import _LocalCache from redis.utils import HIREDIS_AVAILABLE @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") async def test_get_from_cache(): - r = redis.Redis(cache_enable=True, single_connection_client=True, protocol=3) + cache = _LocalCache() + r = redis.Redis(protocol=3, client_cache=cache) r2 = redis.Redis(protocol=3) # add key to redis await r.set("foo", "bar") # get key from redis and save in local cache assert await r.get("foo") == b"bar" # get key from local cache - assert r.client_cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo")) == b"bar" # change key in redis (cause invalidation) await r2.set("foo", "barbar") # send any command to redis (process invalidation in background) await r.ping() # the command is not in the local cache anymore - assert r.client_cache.get(("GET", "foo")) is None + assert cache.get(("GET", "foo")) is None # get key from redis assert await r.get("foo") == b"barbar" + await r.flushdb() await r.aclose() @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") async def test_cache_max_size(): - r = redis.Redis( - cache_enable=True, cache_max_size=3, single_connection_client=True, protocol=3 - ) + cache = _LocalCache(max_size=3) + r = redis.Redis(client_cache=cache, protocol=3) # add 3 keys to redis await r.set("foo", "bar") await r.set("foo2", "bar2") @@ -41,46 +43,42 @@ async def test_cache_max_size(): assert await r.get("foo2") == b"bar2" assert await r.get("foo3") == b"bar3" # 
get the 3 keys from local cache - assert r.client_cache.get(("GET", "foo")) == b"bar" - assert r.client_cache.get(("GET", "foo2")) == b"bar2" - assert r.client_cache.get(("GET", "foo3")) == b"bar3" + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo2")) == b"bar2" + assert cache.get(("GET", "foo3")) == b"bar3" # add 1 more key to redis (exceed the max size) await r.set("foo4", "bar4") assert await r.get("foo4") == b"bar4" # the first key is not in the local cache anymore - assert r.client_cache.get(("GET", "foo")) is None + assert cache.get(("GET", "foo")) is None + await r.flushdb() await r.aclose() @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") async def test_cache_ttl(): - r = redis.Redis( - cache_enable=True, cache_ttl=1, single_connection_client=True, protocol=3 - ) + cache = _LocalCache(ttl=1) + r = redis.Redis(client_cache=cache, protocol=3) # add key to redis await r.set("foo", "bar") # get key from redis and save in local cache assert await r.get("foo") == b"bar" # get key from local cache - assert r.client_cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo")) == b"bar" # wait for the key to expire time.sleep(1) # the key is not in the local cache anymore - assert r.client_cache.get(("GET", "foo")) is None + assert cache.get(("GET", "foo")) is None + await r.flushdb() await r.aclose() @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") async def test_cache_lfu_eviction(): - r = redis.Redis( - cache_enable=True, - cache_max_size=3, - cache_eviction_policy="lfu", - single_connection_client=True, - protocol=3, - ) + cache = _LocalCache(max_size=3, eviction_policy="lfu") + r = redis.Redis(client_cache=cache, protocol=3) # add 3 keys to redis await r.set("foo", "bar") await r.set("foo2", "bar2") @@ -90,40 +88,53 @@ async def test_cache_lfu_eviction(): assert await r.get("foo2") == b"bar2" assert await r.get("foo3") == b"bar3" # change the order of the keys in the cache - assert 
r.client_cache.get(("GET", "foo")) == b"bar" - assert r.client_cache.get(("GET", "foo")) == b"bar" - assert r.client_cache.get(("GET", "foo3")) == b"bar3" + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo3")) == b"bar3" # add 1 more key to redis (exceed the max size) await r.set("foo4", "bar4") assert await r.get("foo4") == b"bar4" # test the eviction policy - assert len(r.client_cache.cache) == 3 - assert r.client_cache.get(("GET", "foo")) == b"bar" - assert r.client_cache.get(("GET", "foo2")) is None + assert len(cache.cache) == 3 + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo2")) is None + await r.flushdb() await r.aclose() @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") async def test_cache_decode_response(): - r = redis.Redis( - decode_responses=True, - cache_enable=True, - single_connection_client=True, - protocol=3, - ) + cache = _LocalCache() + r = redis.Redis(decode_responses=True, client_cache=cache, protocol=3) await r.set("foo", "bar") # get key from redis and save in local cache assert await r.get("foo") == "bar" # get key from local cache - assert r.client_cache.get(("GET", "foo")) == "bar" + assert cache.get(("GET", "foo")) == "bar" # change key in redis (cause invalidation) await r.set("foo", "barbar") # send any command to redis (process invalidation in background) await r.ping() # the command is not in the local cache anymore - assert r.client_cache.get(("GET", "foo")) is None + assert cache.get(("GET", "foo")) is None # get key from redis assert await r.get("foo") == "barbar" + await r.flushdb() + await r.aclose() + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_cache_blacklist(): + cache = _LocalCache() + r = redis.Redis(client_cache=cache, cache_blacklist=["LLEN"], protocol=3) + # add list to redis + await r.lpush("mylist", "foo", "bar", "baz") + assert await r.llen("mylist") == 3 + 
assert await r.lindex("mylist", 1) == b"bar" + assert cache.get(("LLEN", "mylist")) is None + assert cache.get(("LINDEX", "mylist", 1)) == b"bar" + + await r.flushdb() await r.aclose() diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 55a1c3a2f6..4ff3808602 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -68,8 +68,9 @@ async def call_with_retry(self, _, __): in_use = False return "foo" - mock_conn = mock.MagicMock() + mock_conn = mock.AsyncMock(spec=Connection) mock_conn.retry = Retry_() + mock_conn._get_from_local_cache.return_value = None async def get_conn(_): # Validate only one client is created in single-client mode when diff --git a/tests/test_cache.py b/tests/test_cache.py index 45621fe77e..85df8b1a22 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -2,34 +2,37 @@ import pytest import redis +from redis._cache import _LocalCache from redis.utils import HIREDIS_AVAILABLE @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") def test_get_from_cache(): - r = redis.Redis(cache_enable=True, single_connection_client=True, protocol=3) + cache = _LocalCache() + r = redis.Redis(protocol=3, client_cache=cache) r2 = redis.Redis(protocol=3) # add key to redis r.set("foo", "bar") # get key from redis and save in local cache assert r.get("foo") == b"bar" # get key from local cache - assert r.client_cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo")) == b"bar" # change key in redis (cause invalidation) r2.set("foo", "barbar") # send any command to redis (process invalidation in background) r.ping() # the command is not in the local cache anymore - assert r.client_cache.get(("GET", "foo")) is None + assert cache.get(("GET", "foo")) is None # get key from redis assert r.get("foo") == b"barbar" + r.flushdb() + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") def test_cache_max_size(): - r = redis.Redis( - 
cache_enable=True, cache_max_size=3, single_connection_client=True, protocol=3 - ) + cache = _LocalCache(max_size=3) + r = redis.Redis(client_cache=cache, protocol=3) # add 3 keys to redis r.set("foo", "bar") r.set("foo2", "bar2") @@ -39,42 +42,40 @@ def test_cache_max_size(): assert r.get("foo2") == b"bar2" assert r.get("foo3") == b"bar3" # get the 3 keys from local cache - assert r.client_cache.get(("GET", "foo")) == b"bar" - assert r.client_cache.get(("GET", "foo2")) == b"bar2" - assert r.client_cache.get(("GET", "foo3")) == b"bar3" + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo2")) == b"bar2" + assert cache.get(("GET", "foo3")) == b"bar3" # add 1 more key to redis (exceed the max size) r.set("foo4", "bar4") assert r.get("foo4") == b"bar4" # the first key is not in the local cache anymore - assert r.client_cache.get(("GET", "foo")) is None + assert cache.get(("GET", "foo")) is None + + r.flushdb() @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") def test_cache_ttl(): - r = redis.Redis( - cache_enable=True, cache_ttl=1, single_connection_client=True, protocol=3 - ) + cache = _LocalCache(ttl=1) + r = redis.Redis(client_cache=cache, protocol=3) # add key to redis r.set("foo", "bar") # get key from redis and save in local cache assert r.get("foo") == b"bar" # get key from local cache - assert r.client_cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo")) == b"bar" # wait for the key to expire time.sleep(1) # the key is not in the local cache anymore - assert r.client_cache.get(("GET", "foo")) is None + assert cache.get(("GET", "foo")) is None + + r.flushdb() @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") def test_cache_lfu_eviction(): - r = redis.Redis( - cache_enable=True, - cache_max_size=3, - cache_eviction_policy="lfu", - single_connection_client=True, - protocol=3, - ) + cache = _LocalCache(max_size=3, eviction_policy="lfu") + r = redis.Redis(client_cache=cache, protocol=3) # 
add 3 keys to redis r.set("foo", "bar") r.set("foo2", "bar2") @@ -84,36 +85,50 @@ def test_cache_lfu_eviction(): assert r.get("foo2") == b"bar2" assert r.get("foo3") == b"bar3" # change the order of the keys in the cache - assert r.client_cache.get(("GET", "foo")) == b"bar" - assert r.client_cache.get(("GET", "foo")) == b"bar" - assert r.client_cache.get(("GET", "foo3")) == b"bar3" + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo3")) == b"bar3" # add 1 more key to redis (exceed the max size) r.set("foo4", "bar4") assert r.get("foo4") == b"bar4" # test the eviction policy - assert len(r.client_cache.cache) == 3 - assert r.client_cache.get(("GET", "foo")) == b"bar" - assert r.client_cache.get(("GET", "foo2")) is None + assert len(cache.cache) == 3 + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo2")) is None + + r.flushdb() @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") def test_cache_decode_response(): - r = redis.Redis( - decode_responses=True, - cache_enable=True, - single_connection_client=True, - protocol=3, - ) + cache = _LocalCache() + r = redis.Redis(decode_responses=True, client_cache=cache, protocol=3) r.set("foo", "bar") # get key from redis and save in local cache assert r.get("foo") == "bar" # get key from local cache - assert r.client_cache.get(("GET", "foo")) == "bar" + assert cache.get(("GET", "foo")) == "bar" # change key in redis (cause invalidation) r.set("foo", "barbar") # send any command to redis (process invalidation in background) r.ping() # the command is not in the local cache anymore - assert r.client_cache.get(("GET", "foo")) is None + assert cache.get(("GET", "foo")) is None # get key from redis assert r.get("foo") == "barbar" + + r.flushdb() + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +def test_cache_blacklist(): + cache = _LocalCache() + r = redis.Redis(client_cache=cache, 
cache_blacklist=["LLEN"], protocol=3) + # add list to redis + r.lpush("mylist", "foo", "bar", "baz") + assert r.llen("mylist") == 3 + assert r.lindex("mylist", 1) == b"bar" + assert cache.get(("LLEN", "mylist")) is None + assert cache.get(("LINDEX", "mylist", 1)) == b"bar" + + r.flushdb()