diff --git a/redis/__init__.py b/redis/__init__.py index 495d2d99bb..7bf6839453 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -2,6 +2,7 @@ from redis import asyncio # noqa from redis.backoff import default_backoff +from redis.cache import _LocalChace from redis.client import Redis, StrictRedis from redis.cluster import RedisCluster from redis.connection import ( @@ -61,6 +62,7 @@ def int_or_str(value): VERSION = tuple([99, 99, 99]) __all__ = [ + "_LocalChace", "AuthenticationError", "AuthenticationWrongNumberOfArgsError", "BlockingConnectionPool", diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index acc89941f2..8a8f54dc9c 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -597,6 +597,7 @@ async def _disconnect_raise(self, conn: Connection, error: Exception): async def execute_command(self, *args, **options): """Execute a command and return a parsed response""" await self.initialize() + options.pop("keys", None) # the keys are used only for client side caching pool = self.connection_pool command_name = args[0] conn = self.connection or await pool.get_connection(command_name, **options) @@ -1275,6 +1276,7 @@ def multi(self): def execute_command( self, *args, **kwargs ) -> Union["Pipeline", Awaitable["Pipeline"]]: + kwargs.pop("keys", None) # the keys are used only for client side caching if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 636144a9c7..ebc7e4a4cb 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -682,6 +682,7 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: :raises RedisClusterException: if target_nodes is not provided & the command can't be mapped to a slot """ + kwargs.pop("keys", None) # the keys are used only for client side caching command = args[0] target_nodes = [] target_nodes_specified = False @@ -1447,6 +1448,7 @@ def execute_command( or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] - Rest of the kwargs are passed to the Redis connection """ + kwargs.pop("keys", None) # the keys are used only for client side caching self._command_stack.append( PipelineCommand(len(self._command_stack), *args, **kwargs) ) diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index 6834fb194f..56f5e9d651 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -220,6 +220,7 @@ async def execute_command(self, *args, **kwargs): once - If set to True, then execute the resulting command on a single node at random, rather than across the entire sentinel cluster. """ + kwargs.pop("keys", None) # the keys are used only for client side caching once = bool(kwargs.get("once", False)) if "once" in kwargs.keys(): kwargs.pop("once") diff --git a/redis/cache.py b/redis/cache.py new file mode 100644 index 0000000000..5a689d0ebd --- /dev/null +++ b/redis/cache.py @@ -0,0 +1,326 @@ +import random +import time +from collections import OrderedDict, defaultdict +from enum import Enum +from typing import List + +from redis.typing import KeyT, ResponseT + +DEFAULT_EVICTION_POLICY = "lru" + + +DEFAULT_BLACKLIST = [ + "BF.CARD", + "BF.DEBUG", + "BF.EXISTS", + "BF.INFO", + "BF.MEXISTS", + "BF.SCANDUMP", + "CF.COMPACT", + "CF.COUNT", + "CF.DEBUG", + "CF.EXISTS", + "CF.INFO", + "CF.MEXISTS", + "CF.SCANDUMP", + "CMS.INFO", + "CMS.QUERY", + "DUMP", + "EXPIRETIME", + "FT.AGGREGATE", + "FT.ALIASADD", + "FT.ALIASDEL", + "FT.ALIASUPDATE", + "FT.CURSOR", + "FT.EXPLAIN", + "FT.EXPLAINCLI", + "FT.GET", + "FT.INFO", + "FT.MGET", + "FT.PROFILE", + "FT.SEARCH", + "FT.SPELLCHECK", + "FT.SUGGET", + "FT.SUGLEN", + "FT.SYNDUMP", + "FT.TAGVALS", + "FT._ALIASADDIFNX", + "FT._ALIASDELIFX", + "HRANDFIELD", + "JSON.DEBUG", + "PEXPIRETIME", + "PFCOUNT", + "PTTL", + "SRANDMEMBER", + "TDIGEST.BYRANK", + "TDIGEST.BYREVRANK", + "TDIGEST.CDF", + "TDIGEST.INFO", + "TDIGEST.MAX", + "TDIGEST.MIN", + "TDIGEST.QUANTILE", + "TDIGEST.RANK", + "TDIGEST.REVRANK", + "TDIGEST.TRIMMED_MEAN", + "TOPK.INFO", + "TOPK.LIST", + "TOPK.QUERY", + "TOUCH", + "TTL", +] + + +DEFAULT_WHITELIST = [ + "BITCOUNT", + "BITFIELD_RO", + "BITPOS", + "EXISTS", + "GEODIST", + "GEOHASH", + "GEOPOS", + "GEORADIUSBYMEMBER_RO", + "GEORADIUS_RO", + "GEOSEARCH", + "GET", + "GETBIT", + "GETRANGE", + "HEXISTS", + "HGET", + "HGETALL", + "HKEYS", + "HLEN", + "HMGET", + "HSTRLEN", + "HVALS", + "JSON.ARRINDEX", + "JSON.ARRLEN", + "JSON.GET", + "JSON.MGET", + "JSON.OBJKEYS", + "JSON.OBJLEN", + "JSON.RESP", + "JSON.STRLEN", + "JSON.TYPE", + "LCS", + "LINDEX", + "LLEN", + "LPOS", + "LRANGE", + "MGET", + "SCARD", + "SDIFF", + "SINTER", + "SINTERCARD", + "SISMEMBER", + "SMEMBERS", + "SMISMEMBER", + "SORT_RO", + "STRLEN", + "SUBSTR", + "SUNION", + "TS.GET", + "TS.INFO", + "TS.RANGE", + "TS.REVRANGE", + "TYPE", + "XLEN", + "XPENDING", + "XRANGE", + "XREAD", + "XREVRANGE", + "ZCARD", + "ZCOUNT", + "ZDIFF", + "ZINTER", + "ZINTERCARD", + "ZLEXCOUNT", + "ZMSCORE", + "ZRANGE", + "ZRANGEBYLEX", + "ZRANGEBYSCORE", + "ZRANK", + "ZREVRANGE", + "ZREVRANGEBYLEX", + "ZREVRANGEBYSCORE", + "ZREVRANK", + "ZSCORE", + "ZUNION", +] + +_RESPONSE = "response" +_KEYS = "keys" +_CTIME = "ctime" +_ACCESS_COUNT = "access_count" + + +class EvictionPolicy(Enum): + LRU = "lru" + LFU = "lfu" + RANDOM = "random" + + +class _LocalChace: + """ + A caching mechanism for storing redis commands and their responses. + + Args: + max_size (int): The maximum number of commands to be stored in the cache. + ttl (int): The time-to-live for each command in seconds. + eviction_policy (EvictionPolicy): The eviction policy to use for removing commands when the cache is full. + + Attributes: + max_size (int): The maximum number of commands to be stored in the cache. + ttl (int): The time-to-live for each command in seconds. + eviction_policy (EvictionPolicy): The eviction policy used for cache management. + cache (OrderedDict): The ordered dictionary to store commands and their metadata. + key_commands_map (defaultdict): A mapping of keys to the set of commands that use each key. + commands_ttl_list (list): A list to keep track of the commands in the order they were added. # noqa + """ + + def __init__( + self, max_size: int, ttl: int, eviction_policy: EvictionPolicy, **kwargs + ): + self.max_size = max_size + self.ttl = ttl + self.eviction_policy = eviction_policy + self.cache = OrderedDict() + self.key_commands_map = defaultdict(set) + self.commands_ttl_list = [] + + def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]): + """ + Set a redis command and its response in the cache. + + Args: + command (str): The redis command. + response (ResponseT): The response associated with the command. + keys_in_command (List[KeyT]): The list of keys used in the command. + """ + if len(self.cache) >= self.max_size: + self._evict() + self.cache[command] = { + _RESPONSE: response, + _KEYS: keys_in_command, + _CTIME: time.monotonic(), + _ACCESS_COUNT: 0, # Used only for LFU + } + self._update_key_commands_map(keys_in_command, command) + self.commands_ttl_list.append(command) + + def get(self, command: str) -> ResponseT: + """ + Get the response for a redis command from the cache. + + Args: + command (str): The redis command. + + Returns: + ResponseT: The response associated with the command, or None if the command is not in the cache. # noqa + """ + if command in self.cache: + if self._is_expired(command): + self.delete(command) + self._update_access(command) + return self.cache[command]["response"] + + def delete(self, command: str): + """ + Delete a redis command and its metadata from the cache. + + Args: + command (str): The redis command to be deleted. + """ + if command in self.cache: + keys_in_command = self.cache[command].get("keys") + self._del_key_commands_map(keys_in_command, command) + self.commands_ttl_list.remove(command) + del self.cache[command] + + def delete_many(self, commands): + pass + + def flush(self): + """Clear the entire cache, removing all redis commands and metadata.""" + self.cache.clear() + self.key_commands_map.clear() + self.commands_ttl_list = [] + + def _is_expired(self, command: str) -> bool: + """ + Check if a redis command has expired based on its time-to-live. + + Args: + command (str): The redis command. + + Returns: + bool: True if the command has expired, False otherwise. + """ + if self.ttl == 0: + return False + return time.monotonic() - self.cache[command]["ctime"] > self.ttl + + def _update_access(self, command: str): + """ + Update the access information for a redis command based on the eviction policy. + + Args: + command (str): The redis command. + """ + if self.eviction_policy == EvictionPolicy.LRU: + self.cache.move_to_end(command) + elif self.eviction_policy == EvictionPolicy.LFU: + self.cache[command]["access_count"] = ( + self.cache.get(command, {}).get("access_count", 0) + 1 + ) + self.cache.move_to_end(command) + elif self.eviction_policy == EvictionPolicy.RANDOM: + pass # Random eviction doesn't require updates + + def _evict(self): + """Evict a redis command from the cache based on the eviction policy.""" + if self._is_expired(self.commands_ttl_list[0]): + self.delete(self.commands_ttl_list[0]) + elif self.eviction_policy == EvictionPolicy.LRU: + self.cache.popitem(last=False) + elif self.eviction_policy == EvictionPolicy.LFU: + min_access_command = min( + self.cache, key=lambda k: self.cache[k].get("access_count", 0) + ) + self.cache.pop(min_access_command) + elif self.eviction_policy == EvictionPolicy.RANDOM: + random_command = random.choice(list(self.cache.keys())) + self.cache.pop(random_command) + + def _update_key_commands_map(self, keys: List[KeyT], command: str): + """ + Update the key_commands_map with command that uses the keys. + + Args: + keys (List[KeyT]): The list of keys used in the command. + command (str): The redis command. + """ + for key in keys: + self.key_commands_map[key].add(command) + + def _del_key_commands_map(self, keys: List[KeyT], command: str): + """ + Remove a redis command from the key_commands_map. + + Args: + keys (List[KeyT]): The list of keys used in the redis command. + command (str): The redis command. + """ + for key in keys: + self.key_commands_map[key].remove(command) + + def invalidate(self, key: KeyT): + """ + Invalidate (delete) all redis commands associated with a specific key. + + Args: + key (KeyT): The key to be invalidated. + """ + if key not in self.key_commands_map: + return + for command in self.key_commands_map[key]: + self.delete(command) diff --git a/redis/client.py b/redis/client.py index cb91c7a088..2ff3706d16 100755 --- a/redis/client.py +++ b/redis/client.py @@ -13,6 +13,12 @@ _RedisCallbacksRESP3, bool_ok, ) +from redis.cache import ( + DEFAULT_BLACKLIST, + DEFAULT_EVICTION_POLICY, + DEFAULT_WHITELIST, + _LocalChace, +) from redis.commands import ( CoreCommands, RedisModuleCommands, @@ -32,6 +38,7 @@ ) from redis.lock import Lock from redis.retry import Retry +from redis.typing import KeysT, ResponseT from redis.utils import ( HIREDIS_AVAILABLE, _set_info_logger, @@ -203,6 +210,13 @@ def __init__( redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, + cache_enable: bool = False, + client_cache: Optional[_LocalChace] = None, + cache_max_size: int = 100, + cache_ttl: int = 0, + cache_eviction_policy: str = DEFAULT_EVICTION_POLICY, + cache_blacklist: List[str] = DEFAULT_BLACKLIST, + cache_whitelist: List[str] = DEFAULT_WHITELIST, ) -> None: """ Initialize a new Redis client. @@ -310,6 +324,15 @@ def __init__( else: self.response_callbacks.update(_RedisCallbacksRESP2) + self.client_cache = client_cache + if cache_enable: + self.client_cache = _LocalChace( + cache_max_size, cache_ttl, cache_eviction_policy + ) + if self.client_cache is not None: + self.cache_blacklist = cache_blacklist + self.cache_whitelist = cache_whitelist + def __repr__(self) -> str: return f"{type(self).__name__}<{repr(self.connection_pool)}>" @@ -525,23 +548,63 @@ def _disconnect_raise(self, conn, error): ): raise error + def _get_from_local_cache(self, command: str): + """ + If the command is in the local cache, return the response + """ + if ( + self.client_cache is None + or command[0] in self.cache_blacklist + or command[0] not in self.cache_whitelist + ): + return None + return self.client_cache.get(command) + + def _add_to_local_cache(self, command: str, response: ResponseT, keys: List[KeysT]): + """ + Add the command and response to the local cache if the command + is allowed to be cached + """ + if ( + self.client_cache is not None + and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist) + and (self.cache_whitelist == [] or command[0] in self.cache_whitelist) + ): + self.client_cache.set(command, response, keys) + + def delete_from_local_cache(self, command: str): + """ + Delete the command from the local cache + """ + try: + self.client_cache.delete(command) + except AttributeError: + pass + # COMMAND EXECUTION AND PROTOCOL PARSING def execute_command(self, *args, **options): """Execute a command and return a parsed response""" - pool = self.connection_pool command_name = args[0] - conn = self.connection or pool.get_connection(command_name, **options) + keys = options.pop("keys", None) + response_from_cache = self._get_from_local_cache(args) + if response_from_cache is not None: + return response_from_cache + else: + pool = self.connection_pool + conn = self.connection or pool.get_connection(command_name, **options) - try: - return conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_raise(conn, error), - ) - finally: - if not self.connection: - pool.release(conn) + try: + response = conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) + self._add_to_local_cache(args, response, keys) + return response + finally: + if not self.connection: + pool.release(conn) def parse_response(self, connection, command_name, **options): """Parses a response from the Redis server""" @@ -1253,6 +1316,7 @@ def multi(self) -> None: self.explicit_transaction = True def execute_command(self, *args, **kwargs): + kwargs.pop("keys", None) # the keys are used only for client side caching if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) diff --git a/redis/cluster.py b/redis/cluster.py index 873d586c4a..4de11b4e8c 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1060,6 +1060,7 @@ def execute_command(self, *args, **kwargs): list dict """ + kwargs.pop("keys", None) # the keys are used only for client side caching target_nodes_specified = False is_default_node = False target_nodes = None @@ -1962,6 +1963,7 @@ def execute_command(self, *args, **kwargs): """ Wrapper function for pipeline_execute_command """ + kwargs.pop("keys", None) # the keys are used only for client side caching return self.pipeline_execute_command(*args, **kwargs) def pipeline_execute_command(self, *args, **options): diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index 14b8741443..8637f6c247 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -23,6 +23,7 @@ KeysT, KeyT, PatternT, + ResponseT, ) from .core import ( @@ -40,7 +41,6 @@ ManagementCommands, ModuleCommands, PubSubCommands, - ResponseT, ScriptCommands, ) from .helpers import list_or_args diff --git a/redis/commands/core.py b/redis/commands/core.py index e73553e47e..f97724d030 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -5,7 +5,6 @@ import warnings from typing import ( TYPE_CHECKING, - Any, AsyncIterator, Awaitable, Callable, @@ -37,6 +36,7 @@ KeysT, KeyT, PatternT, + ResponseT, ScriptTextT, StreamIdT, TimeoutSecT, @@ -49,8 +49,6 @@ from redis.asyncio.client import Redis as AsyncRedis from redis.client import Redis -ResponseT = Union[Awaitable, Any] - class ACLCommands(CommandsProtocol): """ @@ -1590,7 +1588,7 @@ def bitcount( raise DataError("Both start and end must be specified") if mode is not None: params.append(mode) - return self.execute_command("BITCOUNT", *params) + return self.execute_command("BITCOUNT", *params, keys=[key]) def bitfield( self: Union["Redis", "AsyncRedis"], @@ -1626,7 +1624,7 @@ def bitfield_ro( items = items or [] for encoding, offset in items: params.extend(["GET", encoding, offset]) - return self.execute_command("BITFIELD_RO", *params) + return self.execute_command("BITFIELD_RO", *params, keys=[key]) def bitop(self, operation: str, dest: KeyT, *keys: KeyT) -> ResponseT: """ @@ -1666,7 +1664,7 @@ def bitpos( if mode is not None: params.append(mode) - return self.execute_command("BITPOS", *params) + return self.execute_command("BITPOS", *params, keys=[key]) def copy( self, @@ -1733,7 +1731,7 @@ def exists(self, *names: KeyT) -> ResponseT: For more information see https://redis.io/commands/exists """ - return self.execute_command("EXISTS", *names) + return self.execute_command("EXISTS", *names, keys=names) __contains__ = exists @@ -1826,7 +1824,7 @@ def get(self, name: KeyT) -> ResponseT: For more information see https://redis.io/commands/get """ - return self.execute_command("GET", name) + return self.execute_command("GET", name, keys=[name]) def getdel(self, name: KeyT) -> ResponseT: """ @@ -1920,7 +1918,7 @@ def getbit(self, name: KeyT, offset: int) -> ResponseT: For more information see https://redis.io/commands/getbit """ - return self.execute_command("GETBIT", name, offset) + return self.execute_command("GETBIT", name, offset, keys=[name]) def getrange(self, key: KeyT, start: int, end: int) -> ResponseT: """ @@ -1929,7 +1927,7 @@ def getrange(self, key: KeyT, start: int, end: int) -> ResponseT: For more information see https://redis.io/commands/getrange """ - return self.execute_command("GETRANGE", key, start, end) + return self.execute_command("GETRANGE", key, start, end, keys=[key]) def getset(self, name: KeyT, value: EncodableT) -> ResponseT: """ @@ -2012,6 +2010,7 @@ def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT: options = {} if not args: options[EMPTY_RESPONSE] = [] + options["keys"] = keys return self.execute_command("MGET", *args, **options) def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: @@ -2458,14 +2457,14 @@ def strlen(self, name: KeyT) -> ResponseT: For more information see https://redis.io/commands/strlen """ - return self.execute_command("STRLEN", name) + return self.execute_command("STRLEN", name, keys=[name]) def substr(self, name: KeyT, start: int, end: int = -1) -> ResponseT: """ Return a substring of the string at key ``name``. ``start`` and ``end`` are 0-based integers specifying the portion of the string to return. """ - return self.execute_command("SUBSTR", name, start, end) + return self.execute_command("SUBSTR", name, start, end, keys=[name]) def touch(self, *args: KeyT) -> ResponseT: """ @@ -2490,7 +2489,7 @@ def type(self, name: KeyT) -> ResponseT: For more information see https://redis.io/commands/type """ - return self.execute_command("TYPE", name) + return self.execute_command("TYPE", name, keys=[name]) def watch(self, *names: KeyT) -> None: """ @@ -2543,7 +2542,7 @@ def lcs( pieces.extend(["MINMATCHLEN", minmatchlen]) if withmatchlen: pieces.append("WITHMATCHLEN") - return self.execute_command("LCS", *pieces) + return self.execute_command("LCS", *pieces, keys=[key1, key2]) class AsyncBasicKeyCommands(BasicKeyCommands): @@ -2682,7 +2681,7 @@ def lindex( For more information see https://redis.io/commands/lindex """ - return self.execute_command("LINDEX", name, index) + return self.execute_command("LINDEX", name, index, keys=[name]) def linsert( self, name: str, where: str, refvalue: str, value: str @@ -2704,7 +2703,7 @@ def llen(self, name: str) -> Union[Awaitable[int], int]: For more information see https://redis.io/commands/llen """ - return self.execute_command("LLEN", name) + return self.execute_command("LLEN", name, keys=[name]) def lpop( self, @@ -2751,7 +2750,7 @@ def lrange(self, name: str, start: int, end: int) -> Union[Awaitable[list], list For more information see https://redis.io/commands/lrange """ - return self.execute_command("LRANGE", name, start, end) + return self.execute_command("LRANGE", name, start, end, keys=[name]) def lrem(self, name: str, count: int, value: str) -> Union[Awaitable[int], int]: """ @@ -2874,7 +2873,7 @@ def lpos( if maxlen is not None: pieces.extend(["MAXLEN", maxlen]) - return self.execute_command("LPOS", *pieces) + return self.execute_command("LPOS", *pieces, keys=[name]) def sort( self, @@ -2946,6 +2945,7 @@ def sort( ) options = {"groups": len(get) if groups else None} + options["keys"] = [name] return self.execute_command("SORT", *pieces, **options) def sort_ro( @@ -3319,7 +3319,7 @@ def scard(self, name: str) -> Union[Awaitable[int], int]: For more information see https://redis.io/commands/scard """ - return self.execute_command("SCARD", name) + return self.execute_command("SCARD", name, keys=[name]) def sdiff(self, keys: List, *args: List) -> Union[Awaitable[list], list]: """ @@ -3328,7 +3328,7 @@ def sdiff(self, keys: List, *args: List) -> Union[Awaitable[list], list]: For more information see https://redis.io/commands/sdiff """ args = list_or_args(keys, args) - return self.execute_command("SDIFF", *args) + return self.execute_command("SDIFF", *args, keys=args) def sdiffstore( self, dest: str, keys: List, *args: List @@ -3349,7 +3349,7 @@ def sinter(self, keys: List, *args: List) -> Union[Awaitable[list], list]: For more information see https://redis.io/commands/sinter """ args = list_or_args(keys, args) - return self.execute_command("SINTER", *args) + return self.execute_command("SINTER", *args, keys=args) def sintercard( self, numkeys: int, keys: List[str], limit: int = 0 @@ -3364,7 +3364,7 @@ def sintercard( For more information see https://redis.io/commands/sintercard """ args = [numkeys, *keys, "LIMIT", limit] - return self.execute_command("SINTERCARD", *args) + return self.execute_command("SINTERCARD", *args, keys=keys) def sinterstore( self, dest: str, keys: List, *args: List @@ -3388,7 +3388,7 @@ def sismember( For more information see https://redis.io/commands/sismember """ - return self.execute_command("SISMEMBER", name, value) + return self.execute_command("SISMEMBER", name, value, keys=[name]) def smembers(self, name: str) -> Union[Awaitable[Set], Set]: """ @@ -3396,7 +3396,7 @@ def smembers(self, name: str) -> Union[Awaitable[Set], Set]: For more information see https://redis.io/commands/smembers """ - return self.execute_command("SMEMBERS", name) + return self.execute_command("SMEMBERS", name, keys=[name]) def smismember( self, name: str, values: List, *args: List @@ -3413,7 +3413,7 @@ def smismember( For more information see https://redis.io/commands/smismember """ args = list_or_args(values, args) - return self.execute_command("SMISMEMBER", name, *args) + return self.execute_command("SMISMEMBER", name, *args, keys=[name]) def smove(self, src: str, dst: str, value: str) -> Union[Awaitable[bool], bool]: """ @@ -3462,7 +3462,7 @@ def sunion(self, keys: List, *args: List) -> Union[Awaitable[List], List]: For more information see https://redis.io/commands/sunion """ args = list_or_args(keys, args) - return self.execute_command("SUNION", *args) + return self.execute_command("SUNION", *args, keys=args) def sunionstore( self, dest: str, keys: List, *args: List @@ -3820,7 +3820,7 @@ def xlen(self, name: KeyT) -> ResponseT: For more information see https://redis.io/commands/xlen """ - return self.execute_command("XLEN", name) + return self.execute_command("XLEN", name, keys=[name]) def xpending(self, name: KeyT, groupname: GroupT) -> ResponseT: """ @@ -3830,7 +3830,7 @@ def xpending(self, name: KeyT, groupname: GroupT) -> ResponseT: For more information see https://redis.io/commands/xpending """ - return self.execute_command("XPENDING", name, groupname) + return self.execute_command("XPENDING", name, groupname, keys=[name]) def xpending_range( self, @@ -3919,7 +3919,7 @@ def xrange( pieces.append(b"COUNT") pieces.append(str(count)) - return self.execute_command("XRANGE", name, *pieces) + return self.execute_command("XRANGE", name, *pieces, keys=[name]) def xread( self, @@ -3957,7 +3957,7 @@ def xread( keys, values = zip(*streams.items()) pieces.extend(keys) pieces.extend(values) - return self.execute_command("XREAD", *pieces) + return self.execute_command("XREAD", *pieces, keys=keys) def xreadgroup( self, @@ -4036,7 +4036,7 @@ def xrevrange( pieces.append(b"COUNT") pieces.append(str(count)) - return self.execute_command("XREVRANGE", name, *pieces) + return self.execute_command("XREVRANGE", name, *pieces, keys=[name]) def xtrim( self, @@ -4175,7 +4175,7 @@ def zcard(self, name: KeyT) -> ResponseT: For more information see https://redis.io/commands/zcard """ - return self.execute_command("ZCARD", name) + return self.execute_command("ZCARD", name, keys=[name]) def zcount(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> ResponseT: """ @@ -4184,7 +4184,7 @@ def zcount(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> ResponseT: For more information see https://redis.io/commands/zcount """ - return self.execute_command("ZCOUNT", name, min, max) + return self.execute_command("ZCOUNT", name, min, max, keys=[name]) def zdiff(self, keys: KeysT, withscores: bool = False) -> ResponseT: """ @@ -4196,7 +4196,7 @@ def zdiff(self, keys: KeysT, withscores: bool = False) -> ResponseT: pieces = [len(keys), *keys] if withscores: pieces.append("WITHSCORES") - return self.execute_command("ZDIFF", *pieces) + return self.execute_command("ZDIFF", *pieces, keys=keys) def zdiffstore(self, dest: KeyT, keys: KeysT) -> ResponseT: """ @@ -4264,7 +4264,7 @@ def zintercard( For more information see https://redis.io/commands/zintercard """ args = [numkeys, *keys, "LIMIT", limit] - return self.execute_command("ZINTERCARD", *args) + return self.execute_command("ZINTERCARD", *args, keys=keys) def zlexcount(self, name, min, max): """ @@ -4273,7 +4273,7 @@ def zlexcount(self, name, min, max): For more information see https://redis.io/commands/zlexcount """ - return self.execute_command("ZLEXCOUNT", name, min, max) + return self.execute_command("ZLEXCOUNT", name, min, max, keys=[name]) def zpopmax(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: """ @@ -4456,6 +4456,7 @@ def _zrange( if withscores: pieces.append("WITHSCORES") options = {"withscores": withscores, "score_cast_func": score_cast_func} + options["keys"] = [name] return self.execute_command(*pieces, **options) def zrange( @@ -4544,6 +4545,7 @@ def zrevrange( if withscores: pieces.append(b"WITHSCORES") options = {"withscores": withscores, "score_cast_func": score_cast_func} + options["keys"] = name return self.execute_command(*pieces, **options) def zrangestore( @@ -4618,7 +4620,7 @@ def zrangebylex( pieces = ["ZRANGEBYLEX", name, min, max] if start is not None and num is not None: pieces.extend([b"LIMIT", start, num]) - return self.execute_command(*pieces) + return self.execute_command(*pieces, keys=[name]) def zrevrangebylex( self, @@ -4642,7 +4644,7 @@ def zrevrangebylex( pieces = ["ZREVRANGEBYLEX", name, max, min] if start is not None and num is not None: pieces.extend(["LIMIT", start, num]) - return self.execute_command(*pieces) + return self.execute_command(*pieces, keys=[name]) def zrangebyscore( self, @@ -4676,6 +4678,7 @@ def zrangebyscore( if withscores: pieces.append("WITHSCORES") options = {"withscores": withscores, "score_cast_func": score_cast_func} + options["keys"] = [name] return self.execute_command(*pieces, **options) def zrevrangebyscore( @@ -4710,6 +4713,7 @@ def zrevrangebyscore( if withscores: pieces.append("WITHSCORES") options = {"withscores": withscores, "score_cast_func": score_cast_func} + options["keys"] = [name] return self.execute_command(*pieces, **options) def zrank( @@ -4727,8 +4731,8 @@ def zrank( For more information see https://redis.io/commands/zrank """ if withscore: - return self.execute_command("ZRANK", name, value, "WITHSCORE") - return self.execute_command("ZRANK", name, value) + return self.execute_command("ZRANK", name, value, "WITHSCORE", keys=[name]) + return self.execute_command("ZRANK", name, value, keys=[name]) def zrem(self, name: KeyT, *values: FieldT) -> ResponseT: """ @@ -4786,8 +4790,10 @@ def zrevrank( For more information see https://redis.io/commands/zrevrank """ if withscore: - return self.execute_command("ZREVRANK", name, value, "WITHSCORE") - return self.execute_command("ZREVRANK", name, value) + return self.execute_command( + "ZREVRANK", name, value, "WITHSCORE", keys=[name] + ) + return self.execute_command("ZREVRANK", name, value, keys=[name]) def zscore(self, name: KeyT, value: EncodableT) -> ResponseT: """ @@ -4795,7 +4801,7 @@ def zscore(self, name: KeyT, value: EncodableT) -> ResponseT: For more information see https://redis.io/commands/zscore """ - return self.execute_command("ZSCORE", name, value) + return self.execute_command("ZSCORE", name, value, keys=[name]) def zunion( self, @@ -4842,7 +4848,7 @@ def zmscore(self, key: KeyT, members: List[str]) -> ResponseT: if not members: raise DataError("ZMSCORE members must be a non-empty list") pieces = [key] + members - return self.execute_command("ZMSCORE", *pieces) + return self.execute_command("ZMSCORE", *pieces, keys=[key]) def _zaggregate( self, @@ -4872,6 +4878,7 @@ def _zaggregate( raise DataError("aggregate can be sum, min or max.") if options.get("withscores", False): pieces.append(b"WITHSCORES") + options["keys"] = keys return self.execute_command(*pieces, **options) @@ -4933,7 +4940,7 @@ def hexists(self, name: str, key: str) -> Union[Awaitable[bool], bool]: For more information see https://redis.io/commands/hexists """ - return self.execute_command("HEXISTS", name, key) + return self.execute_command("HEXISTS", name, key, keys=[name]) def hget( self, name: str, key: str @@ -4943,7 +4950,7 @@ def hget( For more information see https://redis.io/commands/hget """ - return self.execute_command("HGET", name, key) + return self.execute_command("HGET", name, key, keys=[name]) def hgetall(self, name: str) -> Union[Awaitable[dict], dict]: """ @@ -4951,7 +4958,7 @@ def hgetall(self, name: str) -> Union[Awaitable[dict], dict]: For more information see https://redis.io/commands/hgetall """ - return self.execute_command("HGETALL", name) + return self.execute_command("HGETALL", name, keys=[name]) def hincrby( self, name: str, key: str, amount: int = 1 @@ -4979,7 +4986,7 @@ def hkeys(self, name: str) -> Union[Awaitable[List], List]: For more information see https://redis.io/commands/hkeys """ - return self.execute_command("HKEYS", name) + return self.execute_command("HKEYS", name, keys=[name]) def hlen(self, name: str) -> Union[Awaitable[int], int]: """ @@ -4987,7 +4994,7 @@ def hlen(self, name: str) -> Union[Awaitable[int], int]: For more information see https://redis.io/commands/hlen """ - return self.execute_command("HLEN", name) + return self.execute_command("HLEN", name, keys=[name]) def hset( self, @@ -5054,7 +5061,7 @@ def hmget(self, name: str, keys: List, *args: List) -> Union[Awaitable[List], Li For more information see https://redis.io/commands/hmget """ args = list_or_args(keys, args) - return self.execute_command("HMGET", name, *args) + return self.execute_command("HMGET", name, *args, keys=[name]) def hvals(self, name: str) -> Union[Awaitable[List], List]: """ @@ -5062,7 +5069,7 @@ def hvals(self, name: str) -> Union[Awaitable[List], List]: For more information see https://redis.io/commands/hvals """ - return self.execute_command("HVALS", name) + return self.execute_command("HVALS", name, keys=[name]) def hstrlen(self, name: str, key: str) -> Union[Awaitable[int], int]: """ @@ -5071,7 +5078,7 @@ def hstrlen(self, name: str, key: str) -> Union[Awaitable[int], int]: For more information see https://redis.io/commands/hstrlen """ - return self.execute_command("HSTRLEN", name, key) + return self.execute_command("HSTRLEN", name, key, keys=[name]) AsyncHashCommands = HashCommands @@ -5464,7 +5471,7 @@ def geodist( raise DataError("GEODIST invalid unit") elif unit: pieces.append(unit) - return self.execute_command("GEODIST", *pieces) + return self.execute_command("GEODIST", *pieces, keys=[name]) def geohash(self, name: KeyT, *values: FieldT) -> ResponseT: """ @@ -5473,7 +5480,7 @@ def geohash(self, name: KeyT, *values: FieldT) -> ResponseT: For more information see https://redis.io/commands/geohash """ - return self.execute_command("GEOHASH", name, *values) + return self.execute_command("GEOHASH", name, *values, keys=[name]) def geopos(self, name: KeyT, *values: FieldT) -> ResponseT: """ @@ -5483,7 +5490,7 @@ def geopos(self, name: KeyT, *values: FieldT) -> ResponseT: For more information see https://redis.io/commands/geopos """ - return self.execute_command("GEOPOS", name, *values) + return self.execute_command("GEOPOS", name, *values, keys=[name]) def georadius( self, @@ -5823,6 +5830,8 @@ def _geosearchgeneric( if kwargs[arg_name]: pieces.append(byte_repr) + kwargs["keys"] = [args[0] if command == "GEOSEARCH" else args[1]] + return self.execute_command(command, *pieces, **kwargs) diff --git a/redis/commands/json/commands.py b/redis/commands/json/commands.py index 0f92e0d6c9..ef0cb205a5 100644 --- a/redis/commands/json/commands.py +++ b/redis/commands/json/commands.py @@ -49,7 +49,7 @@ def arrindex( if stop is not None: pieces.append(stop) - return self.execute_command("JSON.ARRINDEX", *pieces) + return self.execute_command("JSON.ARRINDEX", *pieces, keys=[name]) def arrinsert( self, name: str, path: str, index: int, *args: List[JsonType] @@ -72,7 +72,7 @@ def arrlen( For more information see `JSON.ARRLEN `_. """ # noqa - return self.execute_command("JSON.ARRLEN", name, str(path)) + return self.execute_command("JSON.ARRLEN", name, str(path), keys=[name]) def arrpop( self, @@ -102,14 +102,14 @@ def type(self, name: str, path: Optional[str] = Path.root_path()) -> List[str]: For more information see `JSON.TYPE `_. """ # noqa - return self.execute_command("JSON.TYPE", name, str(path)) + return self.execute_command("JSON.TYPE", name, str(path), keys=[name]) def resp(self, name: str, path: Optional[str] = Path.root_path()) -> List: """Return the JSON value under ``path`` at key ``name``. For more information see `JSON.RESP `_. """ # noqa - return self.execute_command("JSON.RESP", name, str(path)) + return self.execute_command("JSON.RESP", name, str(path), keys=[name]) def objkeys( self, name: str, path: Optional[str] = Path.root_path() @@ -119,7 +119,7 @@ def objkeys( For more information see `JSON.OBJKEYS `_. """ # noqa - return self.execute_command("JSON.OBJKEYS", name, str(path)) + return self.execute_command("JSON.OBJKEYS", name, str(path), keys=[name]) def objlen(self, name: str, path: Optional[str] = Path.root_path()) -> int: """Return the length of the dictionary JSON value under ``path`` at key @@ -127,7 +127,7 @@ def objlen(self, name: str, path: Optional[str] = Path.root_path()) -> int: For more information see `JSON.OBJLEN `_. """ # noqa - return self.execute_command("JSON.OBJLEN", name, str(path)) + return self.execute_command("JSON.OBJLEN", name, str(path), keys=[name]) def numincrby(self, name: str, path: str, number: int) -> str: """Increment the numeric (integer or floating point) JSON value under @@ -197,7 +197,7 @@ def get( # Handle case where key doesn't exist. The JSONDecoder would raise a # TypeError exception since it can't decode None try: - return self.execute_command("JSON.GET", *pieces) + return self.execute_command("JSON.GET", *pieces, keys=[name]) except TypeError: return None @@ -211,7 +211,7 @@ def mget(self, keys: List[str], path: str) -> List[JsonType]: pieces = [] pieces += keys pieces.append(str(path)) - return self.execute_command("JSON.MGET", *pieces) + return self.execute_command("JSON.MGET", *pieces, keys=keys) def set( self, @@ -364,7 +364,7 @@ def strlen(self, name: str, path: Optional[str] = None) -> List[Union[int, None] pieces = [name] if path is not None: pieces.append(str(path)) - return self.execute_command("JSON.STRLEN", *pieces) + return self.execute_command("JSON.STRLEN", *pieces, keys=[name]) def toggle( self, name: str, path: Optional[str] = Path.root_path() diff --git a/redis/commands/timeseries/commands.py b/redis/commands/timeseries/commands.py index 13e3cdf498..1cb183d087 100644 --- a/redis/commands/timeseries/commands.py +++ b/redis/commands/timeseries/commands.py @@ -425,7 +425,7 @@ def range( bucket_timestamp, empty, ) - return self.execute_command(RANGE_CMD, *params) + return self.execute_command(RANGE_CMD, *params, keys=[key]) def revrange( self, @@ -497,7 +497,7 @@ def revrange( bucket_timestamp, empty, ) - return self.execute_command(REVRANGE_CMD, *params) + return self.execute_command(REVRANGE_CMD, *params, keys=[key]) def __mrange_params( self, @@ -721,7 +721,7 @@ def get(self, key: KeyT, latest: Optional[bool] = False): """ # noqa params = [key] self._append_latest(params, latest) - return self.execute_command(GET_CMD, *params) + return self.execute_command(GET_CMD, *params, keys=[key]) def mget( self, @@ -761,7 +761,7 @@ def info(self, key: KeyT): For more information: https://redis.io/commands/ts.info/ """ # noqa - return self.execute_command(INFO_CMD, key) + return self.execute_command(INFO_CMD, key, keys=[key]) def queryindex(self, filters: List[str]): """# noqa diff --git a/redis/commands/timeseries/utils.py b/redis/commands/timeseries/utils.py index c49b040271..12ed656277 100644 --- a/redis/commands/timeseries/utils.py +++ b/redis/commands/timeseries/utils.py @@ -5,7 +5,7 @@ def list_to_dict(aList): return {nativestr(aList[i][0]): nativestr(aList[i][1]) for i in range(len(aList))} -def parse_range(response): +def parse_range(response, **kwargs): """Parse range response. Used by TS.RANGE and TS.REVRANGE.""" return [tuple((r[0], float(r[1]))) for r in response] diff --git a/redis/sentinel.py b/redis/sentinel.py index 41f308d1ee..a1ae5c5275 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -244,6 +244,7 @@ def execute_command(self, *args, **kwargs): once - If set to True, then execute the resulting command on a single node at random, rather than across the entire sentinel cluster. """ + kwargs.pop("keys", None) # the keys are used only for client side caching once = bool(kwargs.get("once", False)) if "once" in kwargs.keys(): kwargs.pop("once") diff --git a/redis/typing.py b/redis/typing.py index 56a1e99ba7..d1cd5568a3 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -33,6 +33,7 @@ PatternT = _StringLikeT # Patterns matched against keys, fields etc FieldT = EncodableT # Fields within hash tables, streams and geo commands KeysT = Union[KeyT, Iterable[KeyT]] +ResponseT = Union[Awaitable, Any] ChannelT = _StringLikeT GroupT = _StringLikeT # Consumer group ConsumerT = _StringLikeT # Consumer name