From cd27dc3333523991cfacf06fe132969507071d86 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Mon, 13 Nov 2023 03:00:51 +0200 Subject: [PATCH 1/7] CSC --- redis/__init__.py | 2 + redis/cache.py | 100 ++++++++++++++++++++++++++++++ redis/client.py | 68 +++++++++++++++++---- redis/utils.py | 152 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 309 insertions(+), 13 deletions(-) create mode 100644 redis/cache.py diff --git a/redis/__init__.py b/redis/__init__.py index 495d2d99bb..9892068e8e 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -2,6 +2,7 @@ from redis import asyncio # noqa from redis.backoff import default_backoff +from redis.cache import _Cache from redis.client import Redis, StrictRedis from redis.cluster import RedisCluster from redis.connection import ( @@ -61,6 +62,7 @@ def int_or_str(value): VERSION = tuple([99, 99, 99]) __all__ = [ + "_Cache", "AuthenticationError", "AuthenticationWrongNumberOfArgsError", "BlockingConnectionPool", diff --git a/redis/cache.py b/redis/cache.py new file mode 100644 index 0000000000..c0975f7873 --- /dev/null +++ b/redis/cache.py @@ -0,0 +1,100 @@ +import random +import time +from collections import OrderedDict, defaultdict + + +class _Cache: + def __init__(self, max_size: int, ttl: int, eviction_policy: str, **kwargs): + self.max_size = max_size + self.ttl = ttl + self.eviction_policy = eviction_policy + self.cache = OrderedDict() + self.key_commands_map = defaultdict(set) + self.commands_ttl_list = [] + + def set(self, command, response): + keys_in_command = self.get_keys_from_command(command) + if len(self.cache) >= self.max_size: + self._evict() + self.cache[command] = { + "response": response, + "keys": keys_in_command, + "created_time": time.monotonic(), + } + if self.eviction_policy == "lfu": + self.cache[command]["access_count"] = 0 + self._update_key_commands_map(keys_in_command, command) + self.commands_ttl_list.append(command) + + def get(self, command): + if command in self.cache: + if self._is_expired(command): + del self.cache[command] + keys_in_command = self.cache[command]["keys"] + self._del_key_commands_map(keys_in_command, command) + return None + self._update_access(command) + return self.cache[command]["response"] + return None + + def delete(self, command): + if command in self.cache: + keys_in_command = self.cache[command]["keys"] + self._del_key_commands_map(keys_in_command, command) + self.commands_ttl_list.remove(command) + del self.cache[command] + + def delete_many(self, commands): + pass + + def flush(self): + self.cache.clear() + self.key_commands_map.clear() + self.commands_ttl_list = [] + + def _is_expired(self, command): + if self.ttl == 0: + return False + return time.monotonic() - self.cache[command]["created_time"] > self.ttl + + def _update_access(self, command): + if self.eviction_policy == "lru": + self.cache.move_to_end(command) + elif self.eviction_policy == "lfu": + self.cache[command]["access_count"] = ( + self.cache.get(command, {}).get("access_count", 0) + 1 + ) + self.cache.move_to_end(command) + elif self.eviction_policy == "random": + pass # Random eviction doesn't require updates + + def _evict(self): + if self._is_expired(self.commands_ttl_list[0]): + self.delete(self.commands_ttl_list[0]) + elif self.eviction_policy == "lru": + self.cache.popitem(last=False) + elif self.eviction_policy == "lfu": + min_access_command = min( + self.cache, key=lambda k: self.cache[k].get("access_count", 0) + ) + self.cache.pop(min_access_command) + elif self.eviction_policy == "random": + random_command = random.choice(list(self.cache.keys())) + self.cache.pop(random_command) + + def _update_key_commands_map(self, keys, command): + for key in keys: + self.key_commands_map[key].add(command) + + def _del_key_commands_map(self, keys, command): + for key in keys: + self.key_commands_map[key].remove(command) + + def invalidate(self, key): + if key in self.key_commands_map: + for command in self.key_commands_map[key]: + self.delete(command) + + def get_keys_from_command(self, command): + # Implement your function to extract keys from a Redis command here + pass diff --git a/redis/client.py b/redis/client.py index cb91c7a088..38bc7dee57 100755 --- a/redis/client.py +++ b/redis/client.py @@ -13,6 +13,7 @@ _RedisCallbacksRESP3, bool_ok, ) +from redis.cache import _Cache from redis.commands import ( CoreCommands, RedisModuleCommands, @@ -33,6 +34,8 @@ from redis.lock import Lock from redis.retry import Retry from redis.utils import ( + DEFAULT_BLACKLIST, + DEFAULT_WHITELIST, HIREDIS_AVAILABLE, _set_info_logger, get_lib_version, @@ -203,6 +206,13 @@ def __init__( redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, + client_caching: bool = False, + client_cache: Optional[_Cache] = None, + cache_max_size: int = 100, + cache_ttl: int = 0, + cache_eviction_policy: str = "lru", + cache_blacklist: List[str] = DEFAULT_BLACKLIST, + cache_whitelist: List[str] = DEFAULT_WHITELIST, ) -> None: """ Initialize a new Redis client. @@ -310,6 +320,12 @@ def __init__( else: self.response_callbacks.update(_RedisCallbacksRESP2) + self.client_cache = client_cache + self.cache_blacklist = cache_blacklist + self.cache_whitelist = cache_whitelist + if client_caching: + self.client_cache = _Cache(cache_max_size, cache_ttl, cache_eviction_policy) + def __repr__(self) -> str: return f"{type(self).__name__}<{repr(self.connection_pool)}>" @@ -525,23 +541,49 @@ def _disconnect_raise(self, conn, error): ): raise error + def get_from_local_cache(self, command): + if ( + self.client_cache is None + or command[0] in self.cache_blacklist + or command[0] not in self.cache_whitelist + ): + return None + return self.client_cache.get(command) + + def add_to_local_cache(self, command, response): + if ( + self.client_cache is not None + and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist) + and (self.cache_whitelist == [] or command[0] in self.cache_whitelist) + ): + self.client_cache.set(command, response) + + def delete_from_local_cache(self, command): + if self.client_cache is not None: + self.client_cache.delete(command) + # COMMAND EXECUTION AND PROTOCOL PARSING def execute_command(self, *args, **options): """Execute a command and return a parsed response""" - pool = self.connection_pool - command_name = args[0] - conn = self.connection or pool.get_connection(command_name, **options) + response_from_cache = self.get_from_local_cache(args) + if response_from_cache is not None: + return response_from_cache + else: + pool = self.connection_pool + command_name = args[0] + conn = self.connection or pool.get_connection(command_name, **options) - try: - return conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_raise(conn, error), - ) - finally: - if not self.connection: - pool.release(conn) + try: + response = conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) + self.add_to_local_cache(args, response) + finally: + if not self.connection: + pool.release(conn) def parse_response(self, connection, command_name, **options): """Parses a response from the Redis server""" diff --git a/redis/utils.py b/redis/utils.py index 01fdfed7a2..271af61ee9 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -145,3 +145,155 @@ def get_lib_version(): except metadata.PackageNotFoundError: libver = "99.99.99" return libver + + +DEFAULT_BLACKLIST = [ + "FT.AGGREGATE", + "FT.ALIASADD", + "FT.ALIASDEL", + "FT.ALIASUPDATE", + "FT.CURSOR", + "FT.EXPLAIN", + "FT.EXPLAINCLI", + "FT.GET", + "FT.INFO", + "FT.MGET", + "FT.PROFILE", + "FT.SEARCH", + "FT.SPELLCHECK", + "FT.SUGGET", + "FT.SUGLEN", + "FT.SYNDUMP", + "FT.TAGVALS", + "FT._ALIASADDIFNX", + "BF.CARD", + "BF.DEBUG", + "BF.EXISTS", + "BF.INFO", + "BF.MEXISTS", + "BF.SCANDUMP", + "CF.COMPACT", + "CF.COUNT", + "CF.DEBUG", + "CF.EXISTS", + "CF.INFO", + "CF.MEXISTS", + "CF.SCANDUMP", + "CMS.INFO", + "CMS.QUERY", + "EXPIRETIME", + "HRANDFIELD", + "JSON.DEBUG", + "PEXPIRETIME", + "PFCOUNT", + "PTTL", + "SRANDMEMBER", + "TDIGEST.BYRANK", + "TDIGEST.BYREVRANK", + "TDIGEST.CDF", + "TDIGEST.INFO", + "TDIGEST.MAX", + "TDIGEST.MIN", + "TDIGEST.QUANTILE", + "TDIGEST.RANK", + "TDIGEST.REVRANK", + "TDIGEST.TRIMMED_MEAN", + "TOPK.INFO", + "TOPK.LIST", + "TOPK.QUERY", + "TTL", +] + + +DEFAULT_WHITELIST = [ + "BITCOUNT", + "BITFIELD_RO", + "BITPOS", + "DBSIZE", + "DUMP", + "EVALSHA_RO", + "EVAL_RO", + "EXISTS", + "EXPIRETIME", + "FCALL_RO", + "GEODIST", + "GEOHASH", + "GEOPOS", + "GEORADIUSBYMEMBER_RO", + "GEORADIUS_RO", + "GEOSEARCH", + "GET", + "GETBIT", + "GETRANGE", + "HEXISTS", + "HGET", + "HGETALL", + "HKEYS", + "HLEN", + "HMGET", + "HRANDFIELD", + "HSCAN", + "HSTRLEN", + "HVALS", + "KEYS", + "LCS", + "LINDEX", + "LLEN", + "LOLWUT", + "LPOS", + "LRANGE", + "MEMORY USAGE", + "MGET", + "OBJECT ENCODING", + "OBJECT FREQ", + "OBJECT IDLETIME", + "OBJECT REFCOUNT", + "PEXPIRETIME", + "PFCOUNT", + "PTTL", + "RANDOMKEY", + "SCAN", + "SCARD", + "SDIFF", + "SINTER", + "SINTERCARD", + "SISMEMBER", + "SMEMBERS", + "SMISMEMBER", + "SORT_RO", + "SRANDMEMBER", + "SSCAN", + "STRLEN", + "SUBSTR", + "SUNION", + "TOUCH", + "TTL", + "TYPE", + "XINFO CONSUMERS", + "XINFO GROUPS", + "XINFO STREAM", + "XLEN", + "XPENDING", + "XRANGE", + "XREAD", + "XREVRANGE", + "ZCARD", + "ZCOUNT", + "ZDIFF", + "ZINTER", + "ZINTERCARD", + "ZLEXCOUNT", + "ZMSCORE", + "ZRANDMEMBER", + "ZRANGE", + "ZRANGEBYLEX", + "ZRANGEBYSCORE", + "ZRANK", + "ZREVRANGE", + "ZREVRANGEBYLEX", + "ZREVRANGEBYSCORE", + "ZREVRANK", + "ZSCAN", + "ZSCORE", + "ZUNION", +] From 6c69ac4e358a09f2b85c89de0ec071a71f15adce Mon Sep 17 00:00:00 2001 From: dvora-h Date: Mon, 13 Nov 2023 15:50:30 +0200 Subject: [PATCH 2/7] get keys from command --- redis/cache.py | 7 +- redis/client.py | 9 +- redis/commands/core.py | 115 ++++++++++++++------------ redis/commands/json/commands.py | 18 ++-- redis/commands/timeseries/commands.py | 8 +- redis/utils.py | 45 ++++------ 6 files changed, 98 insertions(+), 104 deletions(-) diff --git a/redis/cache.py b/redis/cache.py index c0975f7873..2a048e9bda 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -12,8 +12,7 @@ def __init__(self, max_size: int, ttl: int, eviction_policy: str, **kwargs): self.key_commands_map = defaultdict(set) self.commands_ttl_list = [] - def set(self, command, response): - keys_in_command = self.get_keys_from_command(command) + def set(self, command, response, keys_in_command): if len(self.cache) >= self.max_size: self._evict() self.cache[command] = { @@ -94,7 +93,3 @@ def invalidate(self, key): if key in self.key_commands_map: for command in self.key_commands_map[key]: self.delete(command) - - def get_keys_from_command(self, command): - # Implement your function to extract keys from a Redis command here - pass diff --git a/redis/client.py b/redis/client.py index 38bc7dee57..55bda19b48 100755 --- a/redis/client.py +++ b/redis/client.py @@ -550,13 +550,13 @@ def get_from_local_cache(self, command): return None return self.client_cache.get(command) - def add_to_local_cache(self, command, response): + def add_to_local_cache(self, command, response, keys): if ( self.client_cache is not None and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist) and (self.cache_whitelist == [] or command[0] in self.cache_whitelist) ): - self.client_cache.set(command, response) + self.client_cache.set(command, response, keys) def delete_from_local_cache(self, command): if self.client_cache is not None: @@ -565,12 +565,13 @@ def delete_from_local_cache(self, command): # COMMAND EXECUTION AND PROTOCOL PARSING def execute_command(self, *args, **options): """Execute a command and return a parsed response""" + command_name = args[0] + keys = options.pop("keys", None) response_from_cache = self.get_from_local_cache(args) if response_from_cache is not None: return response_from_cache else: pool = self.connection_pool - command_name = args[0] conn = self.connection or pool.get_connection(command_name, **options) try: @@ -580,7 +581,7 @@ def execute_command(self, *args, **options): ), lambda error: self._disconnect_raise(conn, error), ) - self.add_to_local_cache(args, response) + self.add_to_local_cache(args, response, keys) finally: if not self.connection: pool.release(conn) diff --git a/redis/commands/core.py b/redis/commands/core.py index e73553e47e..18db7fef17 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -1590,7 +1590,7 @@ def bitcount( raise DataError("Both start and end must be specified") if mode is not None: params.append(mode) - return self.execute_command("BITCOUNT", *params) + return self.execute_command("BITCOUNT", *params, keys=[key]) def bitfield( self: Union["Redis", "AsyncRedis"], @@ -1626,7 +1626,7 @@ def bitfield_ro( items = items or [] for encoding, offset in items: params.extend(["GET", encoding, offset]) - return self.execute_command("BITFIELD_RO", *params) + return self.execute_command("BITFIELD_RO", *params, keys=[key]) def bitop(self, operation: str, dest: KeyT, *keys: KeyT) -> ResponseT: """ @@ -1666,7 +1666,7 @@ def bitpos( if mode is not None: params.append(mode) - return self.execute_command("BITPOS", *params) + return self.execute_command("BITPOS", *params, keys=[key]) def copy( self, @@ -1733,7 +1733,7 @@ def exists(self, *names: KeyT) -> ResponseT: For more information see https://redis.io/commands/exists """ - return self.execute_command("EXISTS", *names) + return self.execute_command("EXISTS", *names, keys=names) __contains__ = exists @@ -1826,7 +1826,7 @@ def get(self, name: KeyT) -> ResponseT: For more information see https://redis.io/commands/get """ - return self.execute_command("GET", name) + return self.execute_command("GET", name, keys=[name]) def getdel(self, name: KeyT) -> ResponseT: """ @@ -1920,7 +1920,7 @@ def getbit(self, name: KeyT, offset: int) -> ResponseT: For more information see https://redis.io/commands/getbit """ - return self.execute_command("GETBIT", name, offset) + return self.execute_command("GETBIT", name, offset, keys=[name]) def getrange(self, key: KeyT, start: int, end: int) -> ResponseT: """ @@ -1929,7 +1929,7 @@ def getrange(self, key: KeyT, start: int, end: int) -> ResponseT: For more information see https://redis.io/commands/getrange """ - return self.execute_command("GETRANGE", key, start, end) + return self.execute_command("GETRANGE", key, start, end, keys=[key]) def getset(self, name: KeyT, value: EncodableT) -> ResponseT: """ @@ -2012,6 +2012,7 @@ def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT: options = {} if not args: options[EMPTY_RESPONSE] = [] + options["keys"] = keys return self.execute_command("MGET", *args, **options) def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: @@ -2458,14 +2459,14 @@ def strlen(self, name: KeyT) -> ResponseT: For more information see https://redis.io/commands/strlen """ - return self.execute_command("STRLEN", name) + return self.execute_command("STRLEN", name, keys=[name]) def substr(self, name: KeyT, start: int, end: int = -1) -> ResponseT: """ Return a substring of the string at key ``name``. ``start`` and ``end`` are 0-based integers specifying the portion of the string to return. """ - return self.execute_command("SUBSTR", name, start, end) + return self.execute_command("SUBSTR", name, start, end, keys=[name]) def touch(self, *args: KeyT) -> ResponseT: """ @@ -2490,7 +2491,7 @@ def type(self, name: KeyT) -> ResponseT: For more information see https://redis.io/commands/type """ - return self.execute_command("TYPE", name) + return self.execute_command("TYPE", name, keys=[name]) def watch(self, *names: KeyT) -> None: """ @@ -2543,7 +2544,7 @@ def lcs( pieces.extend(["MINMATCHLEN", minmatchlen]) if withmatchlen: pieces.append("WITHMATCHLEN") - return self.execute_command("LCS", *pieces) + return self.execute_command("LCS", *pieces, keys=[key1, key2]) class AsyncBasicKeyCommands(BasicKeyCommands): @@ -2682,7 +2683,7 @@ def lindex( For more information see https://redis.io/commands/lindex """ - return self.execute_command("LINDEX", name, index) + return self.execute_command("LINDEX", name, index, keys=[name]) def linsert( self, name: str, where: str, refvalue: str, value: str @@ -2704,7 +2705,7 @@ def llen(self, name: str) -> Union[Awaitable[int], int]: For more information see https://redis.io/commands/llen """ - return self.execute_command("LLEN", name) + return self.execute_command("LLEN", name, keys=[name]) def lpop( self, @@ -2751,7 +2752,7 @@ def lrange(self, name: str, start: int, end: int) -> Union[Awaitable[list], list For more information see https://redis.io/commands/lrange """ - return self.execute_command("LRANGE", name, start, end) + return self.execute_command("LRANGE", name, start, end, keys=[name]) def lrem(self, name: str, count: int, value: str) -> Union[Awaitable[int], int]: """ @@ -2874,7 +2875,7 @@ def lpos( if maxlen is not None: pieces.extend(["MAXLEN", maxlen]) - return self.execute_command("LPOS", *pieces) + return self.execute_command("LPOS", *pieces, keys=[name]) def sort( self, @@ -2946,6 +2947,7 @@ def sort( ) options = {"groups": len(get) if groups else None} + options["keys"] = [name] return self.execute_command("SORT", *pieces, **options) def sort_ro( @@ -3319,7 +3321,7 @@ def scard(self, name: str) -> Union[Awaitable[int], int]: For more information see https://redis.io/commands/scard """ - return self.execute_command("SCARD", name) + return self.execute_command("SCARD", name, keys=[name]) def sdiff(self, keys: List, *args: List) -> Union[Awaitable[list], list]: """ @@ -3328,7 +3330,7 @@ def sdiff(self, keys: List, *args: List) -> Union[Awaitable[list], list]: For more information see https://redis.io/commands/sdiff """ args = list_or_args(keys, args) - return self.execute_command("SDIFF", *args) + return self.execute_command("SDIFF", *args, keys=args) def sdiffstore( self, dest: str, keys: List, *args: List @@ -3349,7 +3351,7 @@ def sinter(self, keys: List, *args: List) -> Union[Awaitable[list], list]: For more information see https://redis.io/commands/sinter """ args = list_or_args(keys, args) - return self.execute_command("SINTER", *args) + return self.execute_command("SINTER", *args, keys=args) def sintercard( self, numkeys: int, keys: List[str], limit: int = 0 @@ -3364,7 +3366,7 @@ def sintercard( For more information see https://redis.io/commands/sintercard """ args = [numkeys, *keys, "LIMIT", limit] - return self.execute_command("SINTERCARD", *args) + return self.execute_command("SINTERCARD", *args, keys=keys) def sinterstore( self, dest: str, keys: List, *args: List @@ -3388,7 +3390,7 @@ def sismember( For more information see https://redis.io/commands/sismember """ - return self.execute_command("SISMEMBER", name, value) + return self.execute_command("SISMEMBER", name, value, keys=[name]) def smembers(self, name: str) -> Union[Awaitable[Set], Set]: """ @@ -3396,7 +3398,7 @@ def smembers(self, name: str) -> Union[Awaitable[Set], Set]: For more information see https://redis.io/commands/smembers """ - return self.execute_command("SMEMBERS", name) + return self.execute_command("SMEMBERS", name, keys=[name]) def smismember( self, name: str, values: List, *args: List @@ -3413,7 +3415,7 @@ def smismember( For more information see https://redis.io/commands/smismember """ args = list_or_args(values, args) - return self.execute_command("SMISMEMBER", name, *args) + return self.execute_command("SMISMEMBER", name, *args, keys=[name]) def smove(self, src: str, dst: str, value: str) -> Union[Awaitable[bool], bool]: """ @@ -3462,7 +3464,7 @@ def sunion(self, keys: List, *args: List) -> Union[Awaitable[List], List]: For more information see https://redis.io/commands/sunion """ args = list_or_args(keys, args) - return self.execute_command("SUNION", *args) + return self.execute_command("SUNION", *args, keys=args) def sunionstore( self, dest: str, keys: List, *args: List @@ -3820,7 +3822,7 @@ def xlen(self, name: KeyT) -> ResponseT: For more information see https://redis.io/commands/xlen """ - return self.execute_command("XLEN", name) + return self.execute_command("XLEN", name, keys=[name]) def xpending(self, name: KeyT, groupname: GroupT) -> ResponseT: """ @@ -3830,7 +3832,7 @@ def xpending(self, name: KeyT, groupname: GroupT) -> ResponseT: For more information see https://redis.io/commands/xpending """ - return self.execute_command("XPENDING", name, groupname) + return self.execute_command("XPENDING", name, groupname, keys=[name]) def xpending_range( self, @@ -3919,7 +3921,7 @@ def xrange( pieces.append(b"COUNT") pieces.append(str(count)) - return self.execute_command("XRANGE", name, *pieces) + return self.execute_command("XRANGE", name, *pieces, keys=[name]) def xread( self, @@ -3957,7 +3959,7 @@ def xread( keys, values = zip(*streams.items()) pieces.extend(keys) pieces.extend(values) - return self.execute_command("XREAD", *pieces) + return self.execute_command("XREAD", *pieces, keys=keys) def xreadgroup( self, @@ -4036,7 +4038,7 @@ def xrevrange( pieces.append(b"COUNT") pieces.append(str(count)) - return self.execute_command("XREVRANGE", name, *pieces) + return self.execute_command("XREVRANGE", name, *pieces, keys=[name]) def xtrim( self, @@ -4175,7 +4177,7 @@ def zcard(self, name: KeyT) -> ResponseT: For more information see https://redis.io/commands/zcard """ - return self.execute_command("ZCARD", name) + return self.execute_command("ZCARD", name, keys=[name]) def zcount(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> ResponseT: """ @@ -4184,7 +4186,7 @@ def zcount(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> ResponseT: For more information see https://redis.io/commands/zcount """ - return self.execute_command("ZCOUNT", name, min, max) + return self.execute_command("ZCOUNT", name, min, max, keys=[name]) def zdiff(self, keys: KeysT, withscores: bool = False) -> ResponseT: """ @@ -4196,7 +4198,7 @@ def zdiff(self, keys: KeysT, withscores: bool = False) -> ResponseT: pieces = [len(keys), *keys] if withscores: pieces.append("WITHSCORES") - return self.execute_command("ZDIFF", *pieces) + return self.execute_command("ZDIFF", *pieces, keys=keys) def zdiffstore(self, dest: KeyT, keys: KeysT) -> ResponseT: """ @@ -4264,7 +4266,7 @@ def zintercard( For more information see https://redis.io/commands/zintercard """ args = [numkeys, *keys, "LIMIT", limit] - return self.execute_command("ZINTERCARD", *args) + return self.execute_command("ZINTERCARD", *args, keys=keys) def zlexcount(self, name, min, max): """ @@ -4273,7 +4275,7 @@ def zlexcount(self, name, min, max): For more information see https://redis.io/commands/zlexcount """ - return self.execute_command("ZLEXCOUNT", name, min, max) + return self.execute_command("ZLEXCOUNT", name, min, max, keys=[name]) def zpopmax(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: """ @@ -4456,6 +4458,7 @@ def _zrange( if withscores: pieces.append("WITHSCORES") options = {"withscores": withscores, "score_cast_func": score_cast_func} + options["keys"] = [name] return self.execute_command(*pieces, **options) def zrange( @@ -4544,6 +4547,7 @@ def zrevrange( if withscores: pieces.append(b"WITHSCORES") options = {"withscores": withscores, "score_cast_func": score_cast_func} + options["keys"] = name return self.execute_command(*pieces, **options) def zrangestore( @@ -4618,7 +4622,7 @@ def zrangebylex( pieces = ["ZRANGEBYLEX", name, min, max] if start is not None and num is not None: pieces.extend([b"LIMIT", start, num]) - return self.execute_command(*pieces) + return self.execute_command(*pieces, keys=[name]) def zrevrangebylex( self, @@ -4642,7 +4646,7 @@ def zrevrangebylex( pieces = ["ZREVRANGEBYLEX", name, max, min] if start is not None and num is not None: pieces.extend(["LIMIT", start, num]) - return self.execute_command(*pieces) + return self.execute_command(*pieces, keys=[name]) def zrangebyscore( self, @@ -4676,6 +4680,7 @@ def zrangebyscore( if withscores: pieces.append("WITHSCORES") options = {"withscores": withscores, "score_cast_func": score_cast_func} + options["keys"] = [name] return self.execute_command(*pieces, **options) def zrevrangebyscore( @@ -4710,6 +4715,7 @@ def zrevrangebyscore( if withscores: pieces.append("WITHSCORES") options = {"withscores": withscores, "score_cast_func": score_cast_func} + options["keys"] = [name] return self.execute_command(*pieces, **options) def zrank( @@ -4727,8 +4733,8 @@ def zrank( For more information see https://redis.io/commands/zrank """ if withscore: - return self.execute_command("ZRANK", name, value, "WITHSCORE") - return self.execute_command("ZRANK", name, value) + return self.execute_command("ZRANK", name, value, "WITHSCORE", keys=[name]) + return self.execute_command("ZRANK", name, value, keys=[name]) def zrem(self, name: KeyT, *values: FieldT) -> ResponseT: """ @@ -4786,8 +4792,10 @@ def zrevrank( For more information see https://redis.io/commands/zrevrank """ if withscore: - return self.execute_command("ZREVRANK", name, value, "WITHSCORE") - return self.execute_command("ZREVRANK", name, value) + return self.execute_command( + "ZREVRANK", name, value, "WITHSCORE", keys=[name] + ) + return self.execute_command("ZREVRANK", name, value, keys=[name]) def zscore(self, name: KeyT, value: EncodableT) -> ResponseT: """ @@ -4795,7 +4803,7 @@ def zscore(self, name: KeyT, value: EncodableT) -> ResponseT: For more information see https://redis.io/commands/zscore """ - return self.execute_command("ZSCORE", name, value) + return self.execute_command("ZSCORE", name, value, keys=[name]) def zunion( self, @@ -4842,7 +4850,7 @@ def zmscore(self, key: KeyT, members: List[str]) -> ResponseT: if not members: raise DataError("ZMSCORE members must be a non-empty list") pieces = [key] + members - return self.execute_command("ZMSCORE", *pieces) + return self.execute_command("ZMSCORE", *pieces, keys=[key]) def _zaggregate( self, @@ -4872,6 +4880,7 @@ def _zaggregate( raise DataError("aggregate can be sum, min or max.") if options.get("withscores", False): pieces.append(b"WITHSCORES") + options["keys"] = keys return self.execute_command(*pieces, **options) @@ -4933,7 +4942,7 @@ def hexists(self, name: str, key: str) -> Union[Awaitable[bool], bool]: For more information see https://redis.io/commands/hexists """ - return self.execute_command("HEXISTS", name, key) + return self.execute_command("HEXISTS", name, key, keys=[name]) def hget( self, name: str, key: str @@ -4943,7 +4952,7 @@ def hget( For more information see https://redis.io/commands/hget """ - return self.execute_command("HGET", name, key) + return self.execute_command("HGET", name, key, keys=[name]) def hgetall(self, name: str) -> Union[Awaitable[dict], dict]: """ @@ -4951,7 +4960,7 @@ def hgetall(self, name: str) -> Union[Awaitable[dict], dict]: For more information see https://redis.io/commands/hgetall """ - return self.execute_command("HGETALL", name) + return self.execute_command("HGETALL", name, keys=[name]) def hincrby( self, name: str, key: str, amount: int = 1 @@ -4979,7 +4988,7 @@ def hkeys(self, name: str) -> Union[Awaitable[List], List]: For more information see https://redis.io/commands/hkeys """ - return self.execute_command("HKEYS", name) + return self.execute_command("HKEYS", name, keys=[name]) def hlen(self, name: str) -> Union[Awaitable[int], int]: """ @@ -4987,7 +4996,7 @@ def hlen(self, name: str) -> Union[Awaitable[int], int]: For more information see https://redis.io/commands/hlen """ - return self.execute_command("HLEN", name) + return self.execute_command("HLEN", name, keys=[name]) def hset( self, @@ -5054,7 +5063,7 @@ def hmget(self, name: str, keys: List, *args: List) -> Union[Awaitable[List], Li For more information see https://redis.io/commands/hmget """ args = list_or_args(keys, args) - return self.execute_command("HMGET", name, *args) + return self.execute_command("HMGET", name, *args, keys=[name]) def hvals(self, name: str) -> Union[Awaitable[List], List]: """ @@ -5062,7 +5071,7 @@ def hvals(self, name: str) -> Union[Awaitable[List], List]: For more information see https://redis.io/commands/hvals """ - return self.execute_command("HVALS", name) + return self.execute_command("HVALS", name, keys=[name]) def hstrlen(self, name: str, key: str) -> Union[Awaitable[int], int]: """ @@ -5071,7 +5080,7 @@ def hstrlen(self, name: str, key: str) -> Union[Awaitable[int], int]: For more information see https://redis.io/commands/hstrlen """ - return self.execute_command("HSTRLEN", name, key) + return self.execute_command("HSTRLEN", name, key, keys=[name]) AsyncHashCommands = HashCommands @@ -5464,7 +5473,7 @@ def geodist( raise DataError("GEODIST invalid unit") elif unit: pieces.append(unit) - return self.execute_command("GEODIST", *pieces) + return self.execute_command("GEODIST", *pieces, keys=[name]) def geohash(self, name: KeyT, *values: FieldT) -> ResponseT: """ @@ -5473,7 +5482,7 @@ def geohash(self, name: KeyT, *values: FieldT) -> ResponseT: For more information see https://redis.io/commands/geohash """ - return self.execute_command("GEOHASH", name, *values) + return self.execute_command("GEOHASH", name, *values, keys=[name]) def geopos(self, name: KeyT, *values: FieldT) -> ResponseT: """ @@ -5483,7 +5492,7 @@ def geopos(self, name: KeyT, *values: FieldT) -> ResponseT: For more information see https://redis.io/commands/geopos """ - return self.execute_command("GEOPOS", name, *values) + return self.execute_command("GEOPOS", name, *values, keys=[name]) def georadius( self, @@ -5823,6 +5832,8 @@ def _geosearchgeneric( if kwargs[arg_name]: pieces.append(byte_repr) + kwargs["keys"] = [args[0] if command == "GEOSEARCH" else args[1]] + return self.execute_command(command, *pieces, **kwargs) diff --git a/redis/commands/json/commands.py b/redis/commands/json/commands.py index 0f92e0d6c9..ef0cb205a5 100644 --- a/redis/commands/json/commands.py +++ b/redis/commands/json/commands.py @@ -49,7 +49,7 @@ def arrindex( if stop is not None: pieces.append(stop) - return self.execute_command("JSON.ARRINDEX", *pieces) + return self.execute_command("JSON.ARRINDEX", *pieces, keys=[name]) def arrinsert( self, name: str, path: str, index: int, *args: List[JsonType] @@ -72,7 +72,7 @@ def arrlen( For more information see `JSON.ARRLEN `_. """ # noqa - return self.execute_command("JSON.ARRLEN", name, str(path)) + return self.execute_command("JSON.ARRLEN", name, str(path), keys=[name]) def arrpop( self, @@ -102,14 +102,14 @@ def type(self, name: str, path: Optional[str] = Path.root_path()) -> List[str]: For more information see `JSON.TYPE `_. """ # noqa - return self.execute_command("JSON.TYPE", name, str(path)) + return self.execute_command("JSON.TYPE", name, str(path), keys=[name]) def resp(self, name: str, path: Optional[str] = Path.root_path()) -> List: """Return the JSON value under ``path`` at key ``name``. For more information see `JSON.RESP `_. """ # noqa - return self.execute_command("JSON.RESP", name, str(path)) + return self.execute_command("JSON.RESP", name, str(path), keys=[name]) def objkeys( self, name: str, path: Optional[str] = Path.root_path() @@ -119,7 +119,7 @@ def objkeys( For more information see `JSON.OBJKEYS `_. """ # noqa - return self.execute_command("JSON.OBJKEYS", name, str(path)) + return self.execute_command("JSON.OBJKEYS", name, str(path), keys=[name]) def objlen(self, name: str, path: Optional[str] = Path.root_path()) -> int: """Return the length of the dictionary JSON value under ``path`` at key @@ -127,7 +127,7 @@ def objlen(self, name: str, path: Optional[str] = Path.root_path()) -> int: For more information see `JSON.OBJLEN `_. """ # noqa - return self.execute_command("JSON.OBJLEN", name, str(path)) + return self.execute_command("JSON.OBJLEN", name, str(path), keys=[name]) def numincrby(self, name: str, path: str, number: int) -> str: """Increment the numeric (integer or floating point) JSON value under @@ -197,7 +197,7 @@ def get( # Handle case where key doesn't exist. The JSONDecoder would raise a # TypeError exception since it can't decode None try: - return self.execute_command("JSON.GET", *pieces) + return self.execute_command("JSON.GET", *pieces, keys=[name]) except TypeError: return None @@ -211,7 +211,7 @@ def mget(self, keys: List[str], path: str) -> List[JsonType]: pieces = [] pieces += keys pieces.append(str(path)) - return self.execute_command("JSON.MGET", *pieces) + return self.execute_command("JSON.MGET", *pieces, keys=keys) def set( self, @@ -364,7 +364,7 @@ def strlen(self, name: str, path: Optional[str] = None) -> List[Union[int, None] pieces = [name] if path is not None: pieces.append(str(path)) - return self.execute_command("JSON.STRLEN", *pieces) + return self.execute_command("JSON.STRLEN", *pieces, keys=[name]) def toggle( self, name: str, path: Optional[str] = Path.root_path() diff --git a/redis/commands/timeseries/commands.py b/redis/commands/timeseries/commands.py index 13e3cdf498..1cb183d087 100644 --- a/redis/commands/timeseries/commands.py +++ b/redis/commands/timeseries/commands.py @@ -425,7 +425,7 @@ def range( bucket_timestamp, empty, ) - return self.execute_command(RANGE_CMD, *params) + return self.execute_command(RANGE_CMD, *params, keys=[key]) def revrange( self, @@ -497,7 +497,7 @@ def revrange( bucket_timestamp, empty, ) - return self.execute_command(REVRANGE_CMD, *params) + return self.execute_command(REVRANGE_CMD, *params, keys=[key]) def __mrange_params( self, @@ -721,7 +721,7 @@ def get(self, key: KeyT, latest: Optional[bool] = False): """ # noqa params = [key] self._append_latest(params, latest) - return self.execute_command(GET_CMD, *params) + return self.execute_command(GET_CMD, *params, keys=[key]) def mget( self, @@ -761,7 +761,7 @@ def info(self, key: KeyT): For more information: https://redis.io/commands/ts.info/ """ # noqa - return self.execute_command(INFO_CMD, key) + return self.execute_command(INFO_CMD, key, keys=[key]) def queryindex(self, filters: List[str]): """# noqa diff --git a/redis/utils.py b/redis/utils.py index 271af61ee9..f913ce4e99 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -166,6 +166,7 @@ def get_lib_version(): "FT.SYNDUMP", "FT.TAGVALS", "FT._ALIASADDIFNX", + "FT._ALIASDELIFX", "BF.CARD", "BF.DEBUG", "BF.EXISTS", @@ -181,6 +182,7 @@ def get_lib_version(): "CF.SCANDUMP", "CMS.INFO", "CMS.QUERY", + "DUMP", "EXPIRETIME", "HRANDFIELD", "JSON.DEBUG", @@ -201,6 +203,7 @@ def get_lib_version(): "TOPK.INFO", "TOPK.LIST", "TOPK.QUERY", + "TOUCH", "TTL", ] @@ -209,13 +212,7 @@ def get_lib_version(): "BITCOUNT", "BITFIELD_RO", "BITPOS", - "DBSIZE", - "DUMP", - "EVALSHA_RO", - "EVAL_RO", "EXISTS", - "EXPIRETIME", - "FCALL_RO", "GEODIST", "GEOHASH", "GEOPOS", @@ -231,28 +228,23 @@ def get_lib_version(): "HKEYS", "HLEN", "HMGET", - "HRANDFIELD", - "HSCAN", "HSTRLEN", "HVALS", - "KEYS", + "JSON.ARRINDEX", + "JSON.ARRLEN", + "JSON.GET", + "JSON.MGET", + "JSON.OBJKEYS", + "JSON.OBJLEN", + "JSON.RESP", + "JSON.STRLEN", + "JSON.TYPE", "LCS", "LINDEX", "LLEN", - "LOLWUT", "LPOS", "LRANGE", - "MEMORY USAGE", "MGET", - "OBJECT ENCODING", - "OBJECT FREQ", - "OBJECT IDLETIME", - "OBJECT REFCOUNT", - "PEXPIRETIME", - "PFCOUNT", - "PTTL", - "RANDOMKEY", - "SCAN", "SCARD", "SDIFF", "SINTER", @@ -261,17 +253,14 @@ def get_lib_version(): "SMEMBERS", "SMISMEMBER", "SORT_RO", - "SRANDMEMBER", - "SSCAN", "STRLEN", "SUBSTR", "SUNION", - "TOUCH", - "TTL", + "TS.GET", + "TS.INFO", + "TS.RANGE", + "TS.REVRANGE", "TYPE", - "XINFO CONSUMERS", - "XINFO GROUPS", - "XINFO STREAM", "XLEN", "XPENDING", "XRANGE", @@ -284,7 +273,6 @@ def get_lib_version(): "ZINTERCARD", "ZLEXCOUNT", "ZMSCORE", - "ZRANDMEMBER", "ZRANGE", "ZRANGEBYLEX", "ZRANGEBYSCORE", @@ -293,7 +281,6 @@ def get_lib_version(): "ZREVRANGEBYLEX", "ZREVRANGEBYSCORE", "ZREVRANK", - "ZSCAN", "ZSCORE", "ZUNION", ] From 2973a5379632f7f80012ebe8dc4cb43618d52e5d Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 16 Nov 2023 00:40:28 +0200 Subject: [PATCH 3/7] fix review comments --- redis/cache.py | 291 ++++++++++++++++++++++++++++++++++---- redis/client.py | 45 ++++-- redis/commands/cluster.py | 2 +- redis/commands/core.py | 4 +- redis/typing.py | 1 + redis/utils.py | 139 ------------------ 6 files changed, 295 insertions(+), 187 deletions(-) diff --git a/redis/cache.py b/redis/cache.py index 2a048e9bda..6ecc619367 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -1,10 +1,185 @@ import random import time from collections import OrderedDict, defaultdict +from enum import Enum +from typing import List + +from redis.typing import KeyT, ResponseT + +DEFAULT_EVICTION_POLICY = "lru" + + +DEFAULT_BLACKLIST = [ + "BF.CARD", + "BF.DEBUG", + "BF.EXISTS", + "BF.INFO", + "BF.MEXISTS", + "BF.SCANDUMP", + "CF.COMPACT", + "CF.COUNT", + "CF.DEBUG", + "CF.EXISTS", + "CF.INFO", + "CF.MEXISTS", + "CF.SCANDUMP", + "CMS.INFO", + "CMS.QUERY", + "DUMP", + "EXPIRETIME", + "FT.AGGREGATE", + "FT.ALIASADD", + "FT.ALIASDEL", + "FT.ALIASUPDATE", + "FT.CURSOR", + "FT.EXPLAIN", + "FT.EXPLAINCLI", + "FT.GET", + "FT.INFO", + "FT.MGET", + "FT.PROFILE", + "FT.SEARCH", + "FT.SPELLCHECK", + "FT.SUGGET", + "FT.SUGLEN", + "FT.SYNDUMP", + "FT.TAGVALS", + "FT._ALIASADDIFNX", + "FT._ALIASDELIFX", + "HRANDFIELD", + "JSON.DEBUG", + "PEXPIRETIME", + "PFCOUNT", + "PTTL", + "SRANDMEMBER", + "TDIGEST.BYRANK", + "TDIGEST.BYREVRANK", + "TDIGEST.CDF", + "TDIGEST.INFO", + "TDIGEST.MAX", + "TDIGEST.MIN", + "TDIGEST.QUANTILE", + "TDIGEST.RANK", + "TDIGEST.REVRANK", + "TDIGEST.TRIMMED_MEAN", + "TOPK.INFO", + "TOPK.LIST", + "TOPK.QUERY", + "TOUCH", + "TTL", +] + + +DEFAULT_WHITELIST = [ + "BITCOUNT", + "BITFIELD_RO", + "BITPOS", + "EXISTS", + "GEODIST", + "GEOHASH", + "GEOPOS", + "GEORADIUSBYMEMBER_RO", + "GEORADIUS_RO", + "GEOSEARCH", + "GET", + "GETBIT", + "GETRANGE", + "HEXISTS", + "HGET", + "HGETALL", + "HKEYS", + "HLEN", + "HMGET", + "HSTRLEN", + "HVALS", + "JSON.ARRINDEX", + "JSON.ARRLEN", + "JSON.GET", + "JSON.MGET", + "JSON.OBJKEYS", + "JSON.OBJLEN", + "JSON.RESP", + "JSON.STRLEN", + "JSON.TYPE", + "LCS", + "LINDEX", + "LLEN", + "LPOS", + "LRANGE", + "MGET", + "SCARD", + "SDIFF", + "SINTER", + "SINTERCARD", + "SISMEMBER", + "SMEMBERS", + "SMISMEMBER", + "SORT_RO", + "STRLEN", + "SUBSTR", + "SUNION", + "TS.GET", + "TS.INFO", + "TS.RANGE", + "TS.REVRANGE", + "TYPE", + "XLEN", + "XPENDING", + "XRANGE", + "XREAD", + "XREVRANGE", + "ZCARD", + "ZCOUNT", + "ZDIFF", + "ZINTER", + "ZINTERCARD", + "ZLEXCOUNT", + "ZMSCORE", + "ZRANGE", + "ZRANGEBYLEX", + "ZRANGEBYSCORE", + "ZRANK", + "ZREVRANGE", + "ZREVRANGEBYLEX", + "ZREVRANGEBYSCORE", + "ZREVRANK", + "ZSCORE", + "ZUNION", +] + +_RESPONSE = "response" +_KEYS = "keys" +_CTIME = "ctime" +_ACCESS_COUNT = "access_count" + + +class EvictionPolicy(Enum): + LRU = "lru" + LFU = "lfu" + RANDOM = "random" class _Cache: - def __init__(self, max_size: int, ttl: int, eviction_policy: str, **kwargs): + """ + A caching mechanism for storing redis commands and their responses. + + Args: + max_size (int): The maximum number of commands to be stored in the cache. + ttl (int): The time-to-live for each command in seconds. + eviction_policy (EvictionPolicy): The eviction policy to use for removing commands when the cache is full. + + Attributes: + max_size (int): The maximum number of commands to be stored in the cache. + ttl (int): The time-to-live for each command in seconds. + eviction_policy (EvictionPolicy): The eviction policy used for cache management. + cache (OrderedDict): The ordered dictionary to store commands and their metadata. + key_commands_map (defaultdict): A mapping of keys to the set of commands that use each key. + commands_ttl_list (list): A list to keep track of the commands in the order they were added. # noqa + """ + + def __init__( + self, max_size: int, ttl: int, eviction_policy: EvictionPolicy, **kwargs + ): self.max_size = max_size self.ttl = ttl self.eviction_policy = eviction_policy @@ -12,33 +187,51 @@ def __init__(self, max_size: int, ttl: int, eviction_policy: str, **kwargs): self.key_commands_map = defaultdict(set) self.commands_ttl_list = [] - def set(self, command, response, keys_in_command): + def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]): + """ + Set a redis command and its response in the cache. + + Args: + command (str): The redis command. + response (ResponseT): The response associated with the command. + keys_in_command (List[KeyT]): The list of keys used in the command. + """ if len(self.cache) >= self.max_size: self._evict() self.cache[command] = { - "response": response, - "keys": keys_in_command, - "created_time": time.monotonic(), + _RESPONSE: response, + _KEYS: keys_in_command, + _CTIME: time.monotonic(), + _ACCESS_COUNT: 0, # Used only for LFU } - if self.eviction_policy == "lfu": - self.cache[command]["access_count"] = 0 self._update_key_commands_map(keys_in_command, command) self.commands_ttl_list.append(command) - def get(self, command): + def get(self, command: str) -> ResponseT: + """ + Get the response for a redis command from the cache. + + Args: + command (str): The redis command. + + Returns: + ResponseT: The response associated with the command, or None if the command is not in the cache. # noqa + """ if command in self.cache: if self._is_expired(command): - del self.cache[command] - keys_in_command = self.cache[command]["keys"] - self._del_key_commands_map(keys_in_command, command) - return None + self.delete(command) self._update_access(command) return self.cache[command]["response"] - return None - def delete(self, command): + def delete(self, command: str): + """ + Delete a redis command and its metadata from the cache. + + Args: + command (str): The redis command to be deleted. + """ if command in self.cache: - keys_in_command = self.cache[command]["keys"] + keys_in_command = self.cache[command].get("keys") self._del_key_commands_map(keys_in_command, command) self.commands_ttl_list.remove(command) del self.cache[command] @@ -47,49 +240,87 @@ def delete_many(self, commands): pass def flush(self): + """Clear the entire cache, removing all redis commands and metadata.""" self.cache.clear() self.key_commands_map.clear() self.commands_ttl_list = [] - def _is_expired(self, command): + def _is_expired(self, command: str) -> bool: + """ + Check if a redis command has expired based on its time-to-live. + + Args: + command (str): The redis command. + + Returns: + bool: True if the command has expired, False otherwise. + """ if self.ttl == 0: return False - return time.monotonic() - self.cache[command]["created_time"] > self.ttl + return time.monotonic() - self.cache[command]["ctime"] > self.ttl + + def _update_access(self, command: str): + """ + Update the access information for a redis command based on the eviction policy. - def _update_access(self, command): - if self.eviction_policy == "lru": + Args: + command (str): The redis command. + """ + if self.eviction_policy == EvictionPolicy.LRU: self.cache.move_to_end(command) - elif self.eviction_policy == "lfu": + elif self.eviction_policy == EvictionPolicy.LFU: self.cache[command]["access_count"] = ( self.cache.get(command, {}).get("access_count", 0) + 1 ) self.cache.move_to_end(command) - elif self.eviction_policy == "random": + elif self.eviction_policy == EvictionPolicy.RANDOM: pass # Random eviction doesn't require updates def _evict(self): + """Evict a redis command from the cache based on the eviction policy.""" if self._is_expired(self.commands_ttl_list[0]): self.delete(self.commands_ttl_list[0]) - elif self.eviction_policy == "lru": + elif self.eviction_policy == EvictionPolicy.LRU: self.cache.popitem(last=False) - elif self.eviction_policy == "lfu": + elif self.eviction_policy == EvictionPolicy.LFU: min_access_command = min( self.cache, key=lambda k: self.cache[k].get("access_count", 0) ) self.cache.pop(min_access_command) - elif self.eviction_policy == "random": + elif self.eviction_policy == EvictionPolicy.RANDOM: random_command = random.choice(list(self.cache.keys())) self.cache.pop(random_command) - def _update_key_commands_map(self, keys, command): + def _update_key_commands_map(self, keys: List[KeyT], command: str): + """ + Update the key_commands_map with command that uses the keys. + + Args: + keys (List[KeyT]): The list of keys used in the command. + command (str): The redis command. + """ for key in keys: self.key_commands_map[key].add(command) - def _del_key_commands_map(self, keys, command): + def _del_key_commands_map(self, keys: List[KeyT], command: str): + """ + Remove a redis command from the key_commands_map. + + Args: + keys (List[KeyT]): The list of keys used in the redis command. + command (str): The redis command. + """ for key in keys: self.key_commands_map[key].remove(command) - def invalidate(self, key): - if key in self.key_commands_map: - for command in self.key_commands_map[key]: - self.delete(command) + def invalidate(self, key: KeyT): + """ + Invalidate (delete) all redis commands associated with a specific key. + + Args: + key (KeyT): The key to be invalidated. + """ + if key not in self.key_commands_map: + return + for command in self.key_commands_map[key]: + self.delete(command) diff --git a/redis/client.py b/redis/client.py index 55bda19b48..2f21f33bd1 100755 --- a/redis/client.py +++ b/redis/client.py @@ -13,7 +13,12 @@ _RedisCallbacksRESP3, bool_ok, ) -from redis.cache import _Cache +from redis.cache import ( + DEFAULT_BLACKLIST, + DEFAULT_EVICTION_POLICY, + DEFAULT_WHITELIST, + _Cache, +) from redis.commands import ( CoreCommands, RedisModuleCommands, @@ -33,9 +38,8 @@ ) from redis.lock import Lock from redis.retry import Retry +from redis.typing import KeysT, ResponseT from redis.utils import ( - DEFAULT_BLACKLIST, - DEFAULT_WHITELIST, HIREDIS_AVAILABLE, _set_info_logger, get_lib_version, @@ -206,11 +210,11 @@ def __init__( redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - client_caching: bool = False, + cache_enable: bool = False, client_cache: Optional[_Cache] = None, cache_max_size: int = 100, cache_ttl: int = 0, - cache_eviction_policy: str = "lru", + cache_eviction_policy: str = DEFAULT_EVICTION_POLICY, cache_blacklist: List[str] = DEFAULT_BLACKLIST, cache_whitelist: List[str] = DEFAULT_WHITELIST, ) -> None: @@ -321,10 +325,11 @@ def __init__( self.response_callbacks.update(_RedisCallbacksRESP2) self.client_cache = client_cache - self.cache_blacklist = cache_blacklist - self.cache_whitelist = cache_whitelist - if client_caching: + if cache_enable: self.client_cache = _Cache(cache_max_size, cache_ttl, cache_eviction_policy) + if self.client_cache is not None: + self.cache_blacklist = cache_blacklist + self.cache_whitelist = cache_whitelist def __repr__(self) -> str: return f"{type(self).__name__}<{repr(self.connection_pool)}>" @@ -541,7 +546,10 @@ def _disconnect_raise(self, conn, error): ): raise error - def get_from_local_cache(self, command): + def _get_from_local_cache(self, command: str): + """ + If the command is in the local cache, return the response + """ if ( self.client_cache is None or command[0] in self.cache_blacklist @@ -550,7 +558,11 @@ def get_from_local_cache(self, command): return None return self.client_cache.get(command) - def add_to_local_cache(self, command, response, keys): + def _add_to_local_cache(self, command: str, response: ResponseT, keys: List[KeysT]): + """ + Add the command and response to the local cache if the command + is allowed to be cached + """ if ( self.client_cache is not None and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist) @@ -558,16 +570,21 @@ def add_to_local_cache(self, command, response, keys): ): self.client_cache.set(command, response, keys) - def delete_from_local_cache(self, command): - if self.client_cache is not None: + def delete_from_local_cache(self, command: str): + """ + Delete the command from the local cache + """ + try: self.client_cache.delete(command) + except AttributeError: + pass # COMMAND EXECUTION AND PROTOCOL PARSING def execute_command(self, *args, **options): """Execute a command and return a parsed response""" command_name = args[0] keys = options.pop("keys", None) - response_from_cache = self.get_from_local_cache(args) + response_from_cache = self._get_from_local_cache(args) if response_from_cache is not None: return response_from_cache else: @@ -581,7 +598,7 @@ def execute_command(self, *args, **options): ), lambda error: self._disconnect_raise(conn, error), ) - self.add_to_local_cache(args, response, keys) + self._add_to_local_cache(args, response, keys) finally: if not self.connection: pool.release(conn) diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index 14b8741443..8637f6c247 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -23,6 +23,7 @@ KeysT, KeyT, PatternT, + ResponseT, ) from .core import ( @@ -40,7 +41,6 @@ ManagementCommands, ModuleCommands, PubSubCommands, - ResponseT, ScriptCommands, ) from .helpers import list_or_args diff --git a/redis/commands/core.py b/redis/commands/core.py index 18db7fef17..f97724d030 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -5,7 +5,6 @@ import warnings from typing import ( TYPE_CHECKING, - Any, AsyncIterator, Awaitable, Callable, @@ -37,6 +36,7 @@ KeysT, KeyT, PatternT, + ResponseT, ScriptTextT, StreamIdT, TimeoutSecT, @@ -49,8 +49,6 @@ from redis.asyncio.client import Redis as AsyncRedis from redis.client import Redis -ResponseT = Union[Awaitable, Any] - class ACLCommands(CommandsProtocol): """ diff --git a/redis/typing.py b/redis/typing.py index 56a1e99ba7..d1cd5568a3 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -33,6 +33,7 @@ PatternT = _StringLikeT # Patterns matched against keys, fields etc FieldT = EncodableT # Fields within hash tables, streams and geo commands KeysT = Union[KeyT, Iterable[KeyT]] +ResponseT = Union[Awaitable, Any] ChannelT = _StringLikeT GroupT = _StringLikeT # Consumer group ConsumerT = _StringLikeT # Consumer name diff --git a/redis/utils.py b/redis/utils.py index f913ce4e99..01fdfed7a2 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -145,142 +145,3 @@ def get_lib_version(): except metadata.PackageNotFoundError: libver = "99.99.99" return libver - - -DEFAULT_BLACKLIST = [ - "FT.AGGREGATE", - "FT.ALIASADD", - "FT.ALIASDEL", - "FT.ALIASUPDATE", - "FT.CURSOR", - "FT.EXPLAIN", - "FT.EXPLAINCLI", - "FT.GET", - "FT.INFO", - "FT.MGET", - "FT.PROFILE", - "FT.SEARCH", - "FT.SPELLCHECK", - "FT.SUGGET", - "FT.SUGLEN", - "FT.SYNDUMP", - "FT.TAGVALS", - "FT._ALIASADDIFNX", - "FT._ALIASDELIFX", - "BF.CARD", - "BF.DEBUG", - "BF.EXISTS", - "BF.INFO", - "BF.MEXISTS", - "BF.SCANDUMP", - "CF.COMPACT", - "CF.COUNT", - "CF.DEBUG", - "CF.EXISTS", - "CF.INFO", - "CF.MEXISTS", - "CF.SCANDUMP", - "CMS.INFO", - "CMS.QUERY", - "DUMP", - "EXPIRETIME", - "HRANDFIELD", - "JSON.DEBUG", - "PEXPIRETIME", - "PFCOUNT", - "PTTL", - "SRANDMEMBER", - "TDIGEST.BYRANK", - "TDIGEST.BYREVRANK", - "TDIGEST.CDF", - "TDIGEST.INFO", - "TDIGEST.MAX", - "TDIGEST.MIN", - "TDIGEST.QUANTILE", - "TDIGEST.RANK", - "TDIGEST.REVRANK", - "TDIGEST.TRIMMED_MEAN", - "TOPK.INFO", - "TOPK.LIST", - "TOPK.QUERY", - "TOUCH", - "TTL", -] - - -DEFAULT_WHITELIST = [ - "BITCOUNT", - "BITFIELD_RO", - "BITPOS", - "EXISTS", - "GEODIST", - "GEOHASH", - "GEOPOS", - "GEORADIUSBYMEMBER_RO", - "GEORADIUS_RO", - "GEOSEARCH", - "GET", - "GETBIT", - "GETRANGE", - "HEXISTS", - "HGET", - "HGETALL", - "HKEYS", - "HLEN", - "HMGET", - "HSTRLEN", - "HVALS", - "JSON.ARRINDEX", - "JSON.ARRLEN", - "JSON.GET", - "JSON.MGET", - "JSON.OBJKEYS", - "JSON.OBJLEN", - "JSON.RESP", - "JSON.STRLEN", - "JSON.TYPE", - "LCS", - "LINDEX", - "LLEN", - "LPOS", - "LRANGE", - "MGET", - "SCARD", - "SDIFF", - "SINTER", - "SINTERCARD", - "SISMEMBER", - "SMEMBERS", - "SMISMEMBER", - "SORT_RO", - "STRLEN", - "SUBSTR", - "SUNION", - "TS.GET", - "TS.INFO", - "TS.RANGE", - "TS.REVRANGE", - "TYPE", - "XLEN", - "XPENDING", - "XRANGE", - "XREAD", - "XREVRANGE", - "ZCARD", - "ZCOUNT", - "ZDIFF", - "ZINTER", - "ZINTERCARD", - "ZLEXCOUNT", - "ZMSCORE", - "ZRANGE", - "ZRANGEBYLEX", - "ZRANGEBYSCORE", - "ZRANK", - "ZREVRANGE", - "ZREVRANGEBYLEX", - "ZREVRANGEBYSCORE", - "ZREVRANK", - "ZSCORE", - "ZUNION", -] From ec73468359cf8ca9e9cf5b5204661e565dc6c398 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 16 Nov 2023 00:54:31 +0200 Subject: [PATCH 4/7] return respone in execute_command --- redis/client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/redis/client.py b/redis/client.py index 2f21f33bd1..ea8430792f 100755 --- a/redis/client.py +++ b/redis/client.py @@ -599,6 +599,7 @@ def execute_command(self, *args, **options): lambda error: self._disconnect_raise(conn, error), ) self._add_to_local_cache(args, response, keys) + return response finally: if not self.connection: pool.release(conn) From a8780fb36907782537bcfee563e090b0efcaf4b0 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 16 Nov 2023 01:29:39 +0200 Subject: [PATCH 5/7] fix tests --- redis/asyncio/client.py | 2 ++ redis/asyncio/cluster.py | 2 ++ redis/asyncio/sentinel.py | 1 + redis/client.py | 1 + redis/cluster.py | 2 ++ redis/commands/timeseries/utils.py | 2 +- redis/sentinel.py | 1 + 7 files changed, 10 insertions(+), 1 deletion(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index acc89941f2..23a825c607 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -597,6 +597,7 @@ async def _disconnect_raise(self, conn: Connection, error: Exception): async def execute_command(self, *args, **options): """Execute a command and return a parsed response""" await self.initialize() + options.pop("keys", None) # the keys is used only for client side caching pool = self.connection_pool command_name = args[0] conn = self.connection or await pool.get_connection(command_name, **options) @@ -1275,6 +1276,7 @@ def multi(self): def execute_command( self, *args, **kwargs ) -> Union["Pipeline", Awaitable["Pipeline"]]: + kwargs.pop("keys", None) # the keys is used only for client side caching if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 636144a9c7..f77a861f68 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -682,6 +682,7 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: :raises RedisClusterException: if target_nodes is not provided & the command can't be mapped to a slot """ + kwargs.pop("keys", None) # the keys is used only for client side caching command = args[0] target_nodes = [] target_nodes_specified = False @@ -1447,6 +1448,7 @@ def execute_command( or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] - Rest of the kwargs are passed to the Redis connection """ + kwargs.pop("keys", None) # the keys is used only for client side caching self._command_stack.append( PipelineCommand(len(self._command_stack), *args, **kwargs) ) diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index 6834fb194f..3a829178bc 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -220,6 +220,7 @@ async def execute_command(self, *args, **kwargs): once - If set to True, then execute the resulting command on a single node at random, rather than across the entire sentinel cluster. """ + kwargs.pop("keys", None) # the keys is used only for client side caching once = bool(kwargs.get("once", False)) if "once" in kwargs.keys(): kwargs.pop("once") diff --git a/redis/client.py b/redis/client.py index ea8430792f..5f47597024 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1314,6 +1314,7 @@ def multi(self) -> None: self.explicit_transaction = True def execute_command(self, *args, **kwargs): + kwargs.pop("keys", None) # the keys is used only for client side caching if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) diff --git a/redis/cluster.py b/redis/cluster.py index 873d586c4a..e5555e5883 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1060,6 +1060,7 @@ def execute_command(self, *args, **kwargs): list dict """ + kwargs.pop("keys", None) # the keys is used only for client side caching target_nodes_specified = False is_default_node = False target_nodes = None @@ -1962,6 +1963,7 @@ def execute_command(self, *args, **kwargs): """ Wrapper function for pipeline_execute_command """ + kwargs.pop("keys", None) # the keys is used only for client side caching return self.pipeline_execute_command(*args, **kwargs) def pipeline_execute_command(self, *args, **options): diff --git a/redis/commands/timeseries/utils.py b/redis/commands/timeseries/utils.py index c49b040271..12ed656277 100644 --- a/redis/commands/timeseries/utils.py +++ b/redis/commands/timeseries/utils.py @@ -5,7 +5,7 @@ def list_to_dict(aList): return {nativestr(aList[i][0]): nativestr(aList[i][1]) for i in range(len(aList))} -def parse_range(response): +def parse_range(response, **kwargs): """Parse range response. Used by TS.RANGE and TS.REVRANGE.""" return [tuple((r[0], float(r[1]))) for r in response] diff --git a/redis/sentinel.py b/redis/sentinel.py index 41f308d1ee..ede141a21a 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -244,6 +244,7 @@ def execute_command(self, *args, **kwargs): once - If set to True, then execute the resulting command on a single node at random, rather than across the entire sentinel cluster. """ + kwargs.pop("keys", None) # the keys is used only for client side caching once = bool(kwargs.get("once", False)) if "once" in kwargs.keys(): kwargs.pop("once") From e06d140835f99a9949a6bf0a9bd4bc20e790b458 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 16 Nov 2023 10:10:28 +0200 Subject: [PATCH 6/7] fix comments --- redis/__init__.py | 4 ++-- redis/asyncio/client.py | 4 ++-- redis/asyncio/cluster.py | 4 ++-- redis/asyncio/sentinel.py | 2 +- redis/cache.py | 2 +- redis/client.py | 8 ++++---- redis/cluster.py | 4 ++-- redis/sentinel.py | 2 +- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/redis/__init__.py b/redis/__init__.py index 9892068e8e..7bf6839453 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -2,7 +2,7 @@ from redis import asyncio # noqa from redis.backoff import default_backoff -from redis.cache import _Cache +from redis.cache import _LocalChace from redis.client import Redis, StrictRedis from redis.cluster import RedisCluster from redis.connection import ( @@ -62,7 +62,7 @@ def int_or_str(value): VERSION = tuple([99, 99, 99]) __all__ = [ - "_Cache", + "_LocalChace", "AuthenticationError", "AuthenticationWrongNumberOfArgsError", "BlockingConnectionPool", diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 23a825c607..8a8f54dc9c 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -597,7 +597,7 @@ async def _disconnect_raise(self, conn: Connection, error: Exception): async def execute_command(self, *args, **options): """Execute a command and return a parsed response""" await self.initialize() - options.pop("keys", None) # the keys is used only for client side caching + options.pop("keys", None) # the keys are used only for client side caching pool = self.connection_pool command_name = args[0] conn = self.connection or await pool.get_connection(command_name, **options) @@ -1276,7 +1276,7 @@ def multi(self): def execute_command( self, *args, **kwargs ) -> Union["Pipeline", Awaitable["Pipeline"]]: - kwargs.pop("keys", None) # the keys is used only for client side caching + kwargs.pop("keys", None) # the keys are used only for client side caching if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index f77a861f68..ebc7e4a4cb 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -682,7 +682,7 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: :raises RedisClusterException: if target_nodes is not provided & the command can't be mapped to a slot """ - kwargs.pop("keys", None) # the keys is used only for client side caching + kwargs.pop("keys", None) # the keys are used only for client side caching command = args[0] target_nodes = [] target_nodes_specified = False @@ -1448,7 +1448,7 @@ def execute_command( or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] - Rest of the kwargs are passed to the Redis connection """ - kwargs.pop("keys", None) # the keys is used only for client side caching + kwargs.pop("keys", None) # the keys are used only for client side caching self._command_stack.append( PipelineCommand(len(self._command_stack), *args, **kwargs) ) diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index 3a829178bc..56f5e9d651 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -220,7 +220,7 @@ async def execute_command(self, *args, **kwargs): once - If set to True, then execute the resulting command on a single node at random, rather than across the entire sentinel cluster. """ - kwargs.pop("keys", None) # the keys is used only for client side caching + kwargs.pop("keys", None) # the keys are used only for client side caching once = bool(kwargs.get("once", False)) if "once" in kwargs.keys(): kwargs.pop("once") diff --git a/redis/cache.py b/redis/cache.py index 6ecc619367..5a689d0ebd 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -159,7 +159,7 @@ class EvictionPolicy(Enum): RANDOM = "random" -class _Cache: +class _LocalChace: """ A caching mechanism for storing redis commands and their responses. diff --git a/redis/client.py b/redis/client.py index 5f47597024..952f1cd04a 100755 --- a/redis/client.py +++ b/redis/client.py @@ -17,7 +17,7 @@ DEFAULT_BLACKLIST, DEFAULT_EVICTION_POLICY, DEFAULT_WHITELIST, - _Cache, + _LocalChace, ) from redis.commands import ( CoreCommands, @@ -211,7 +211,7 @@ def __init__( credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, cache_enable: bool = False, - client_cache: Optional[_Cache] = None, + client_cache: Optional[_LocalChace] = None, cache_max_size: int = 100, cache_ttl: int = 0, cache_eviction_policy: str = DEFAULT_EVICTION_POLICY, @@ -326,7 +326,7 @@ def __init__( self.client_cache = client_cache if cache_enable: - self.client_cache = _Cache(cache_max_size, cache_ttl, cache_eviction_policy) + self.client_cache = _LocalChace(cache_max_size, cache_ttl, cache_eviction_policy) if self.client_cache is not None: self.cache_blacklist = cache_blacklist self.cache_whitelist = cache_whitelist @@ -1314,7 +1314,7 @@ def multi(self) -> None: self.explicit_transaction = True def execute_command(self, *args, **kwargs): - kwargs.pop("keys", None) # the keys is used only for client side caching + kwargs.pop("keys", None) # the keys are used only for client side caching if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) diff --git a/redis/cluster.py b/redis/cluster.py index e5555e5883..4de11b4e8c 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1060,7 +1060,7 @@ def execute_command(self, *args, **kwargs): list dict """ - kwargs.pop("keys", None) # the keys is used only for client side caching + kwargs.pop("keys", None) # the keys are used only for client side caching target_nodes_specified = False is_default_node = False target_nodes = None @@ -1963,7 +1963,7 @@ def execute_command(self, *args, **kwargs): """ Wrapper function for pipeline_execute_command """ - kwargs.pop("keys", None) # the keys is used only for client side caching + kwargs.pop("keys", None) # the keys are used only for client side caching return self.pipeline_execute_command(*args, **kwargs) def pipeline_execute_command(self, *args, **options): diff --git a/redis/sentinel.py b/redis/sentinel.py index ede141a21a..a1ae5c5275 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -244,7 +244,7 @@ def execute_command(self, *args, **kwargs): once - If set to True, then execute the resulting command on a single node at random, rather than across the entire sentinel cluster. """ - kwargs.pop("keys", None) # the keys is used only for client side caching + kwargs.pop("keys", None) # the keys are used only for client side caching once = bool(kwargs.get("once", False)) if "once" in kwargs.keys(): kwargs.pop("once") From 63bb591c2464d005834fbe2debee6841a4156429 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 16 Nov 2023 10:17:19 +0200 Subject: [PATCH 7/7] linters --- redis/client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/redis/client.py b/redis/client.py index 952f1cd04a..2ff3706d16 100755 --- a/redis/client.py +++ b/redis/client.py @@ -326,7 +326,9 @@ def __init__( self.client_cache = client_cache if cache_enable: - self.client_cache = _LocalChace(cache_max_size, cache_ttl, cache_eviction_policy) + self.client_cache = _LocalChace( + cache_max_size, cache_ttl, cache_eviction_policy + ) if self.client_cache is not None: self.cache_blacklist = cache_blacklist self.cache_whitelist = cache_whitelist