From ee1d7cfc53e34c14f2f9d89163f3abe57c431d80 Mon Sep 17 00:00:00 2001 From: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Date: Sun, 18 Feb 2024 17:45:50 +0200 Subject: [PATCH] Python: Allow chaining function calls on transaction (#987) --- CHANGELOG.md | 2 +- .../glide/async_commands/transaction.py | 282 ++++++++++-------- python/python/tests/test_transaction.py | 11 + 3 files changed, 167 insertions(+), 128 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c854978839..191325aaf6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ #### Features * Python, Node: Added support in Lua Scripts ([#775](https://github.com/aws/glide-for-redis/pull/775), [#860](https://github.com/aws/glide-for-redis/pull/860)) -* Node: Allow chaining function calls on transaction. ([#902](https://github.com/aws/glide-for-redis/pull/902)) +* Python, Node: Allow chaining function calls on transaction. ([#902](https://github.com/aws/glide-for-redis/pull/902)), ([#987](https://github.com/aws/glide-for-redis/pull/987)) #### Fixes * Core: Fixed `Connection Refused` error not to close the client ([#872](https://github.com/aws/glide-for-redis/pull/872)) diff --git a/python/python/glide/async_commands/transaction.py b/python/python/glide/async_commands/transaction.py index 9a4f5b101e..7f0b638435 100644 --- a/python/python/glide/async_commands/transaction.py +++ b/python/python/glide/async_commands/transaction.py @@ -1,7 +1,7 @@ # Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 import threading -from typing import List, Mapping, Optional, Tuple, Union +from typing import List, Mapping, Optional, Tuple, TypeVar, Union from glide.async_commands.core import ( ConditionalChange, @@ -14,6 +14,8 @@ ) from glide.protobuf.redis_request_pb2 import RequestType +TTransaction = TypeVar("TTransaction", bound="BaseTransaction") + class BaseTransaction: """ @@ -25,8 +27,7 @@ class BaseTransaction: Example: transaction = BaseTransaction() - >>> transaction.set("key", "value") - >>> transaction.get("key") + >>> transaction.set("key", "value").get("key") >>> await client.exec(transaction) [OK , "value"] """ @@ -35,18 +36,21 @@ def __init__(self) -> None: self.commands: List[Tuple[RequestType.ValueType, List[str]]] = [] self.lock = threading.Lock() - def append_command(self, request_type: RequestType.ValueType, args: List[str]): + def append_command( + self: TTransaction, request_type: RequestType.ValueType, args: List[str] + ) -> TTransaction: self.lock.acquire() try: self.commands.append((request_type, args)) finally: self.lock.release() + return self def clear(self): with self.lock: self.commands.clear() - def get(self, key: str): + def get(self: TTransaction, key: str) -> TTransaction: """ Get the value associated with the given key, or null if no such value exists. See https://redis.io/commands/get/ for details. @@ -57,16 +61,16 @@ def get(self, key: str): Command response: Optional[str]: If the key exists, returns the value of the key as a string. Otherwise, return None. """ - self.append_command(RequestType.GetString, [key]) + return self.append_command(RequestType.GetString, [key]) def set( - self, + self: TTransaction, key: str, value: str, conditional_set: Union[ConditionalChange, None] = None, expiry: Union[ExpirySet, None] = None, return_old_value: bool = False, - ): + ) -> TTransaction: """ Set the given key with the given value. Return value is dependent on the passed options. See https://redis.io/commands/set/ for details. @@ -102,9 +106,9 @@ def set( args.append("GET") if expiry is not None: args.extend(expiry.get_cmd_args()) - self.append_command(RequestType.SetString, args) + return self.append_command(RequestType.SetString, args) - def custom_command(self, command_args: List[str]): + def custom_command(self: TTransaction, command_args: List[str]) -> TTransaction: """ Executes a single command, without checking inputs. @remarks - This function should only be used for single-response commands. Commands that don't return response (such as SUBSCRIBE), or that return potentially more than a single response (such as XREAD), or that change the client's behavior (such as entering pub/sub mode on RESP2 connections) shouldn't be called using this function. @@ -119,12 +123,12 @@ def custom_command(self, command_args: List[str]): Command response: TResult: The returning value depends on the executed command. """ - self.append_command(RequestType.CustomCommand, command_args) + return self.append_command(RequestType.CustomCommand, command_args) def info( - self, + self: TTransaction, sections: Optional[List[InfoSection]] = None, - ): + ) -> TTransaction: """ Get information and statistics about the Redis server. See https://redis.io/commands/info/ for details. @@ -137,9 +141,9 @@ def info( str: Returns a string containing the information for the sections requested. """ args = [section.value for section in sections] if sections else [] - self.append_command(RequestType.Info, args) + return self.append_command(RequestType.Info, args) - def delete(self, keys: List[str]): + def delete(self: TTransaction, keys: List[str]) -> TTransaction: """ Delete one or more keys from the database. A key is ignored if it does not exist. See https://redis.io/commands/del/ for details. @@ -150,9 +154,9 @@ def delete(self, keys: List[str]): Command response: int: The number of keys that were deleted. """ - self.append_command(RequestType.Del, keys) + return self.append_command(RequestType.Del, keys) - def config_get(self, parameters: List[str]): + def config_get(self: TTransaction, parameters: List[str]) -> TTransaction: """ Get the values of configuration parameters. See https://redis.io/commands/config-get/ for details. @@ -163,9 +167,11 @@ def config_get(self, parameters: List[str]): Command response: Dict[str, str]: A dictionary of values corresponding to the configuration parameters. """ - self.append_command(RequestType.ConfigGet, parameters) + return self.append_command(RequestType.ConfigGet, parameters) - def config_set(self, parameters_map: Mapping[str, str]): + def config_set( + self: TTransaction, parameters_map: Mapping[str, str] + ) -> TTransaction: """ Set configuration parameters to the specified values. See https://redis.io/commands/config-set/ for details. @@ -180,9 +186,9 @@ def config_set(self, parameters_map: Mapping[str, str]): parameters: List[str] = [] for pair in parameters_map.items(): parameters.extend(pair) - self.append_command(RequestType.ConfigSet, parameters) + return self.append_command(RequestType.ConfigSet, parameters) - def config_resetstat(self): + def config_resetstat(self: TTransaction) -> TTransaction: """ Resets the statistics reported by Redis using the INFO and LATENCY HISTOGRAM commands. See https://redis.io/commands/config-resetstat/ for details. @@ -190,9 +196,9 @@ def config_resetstat(self): Command response: OK: a simple OK response. """ - self.append_command(RequestType.ConfigResetStat, []) + return self.append_command(RequestType.ConfigResetStat, []) - def mset(self, key_value_map: Mapping[str, str]): + def mset(self: TTransaction, key_value_map: Mapping[str, str]) -> TTransaction: """ Set multiple keys to multiple values in a single atomic operation. See https://redis.io/commands/mset/ for more details. @@ -206,9 +212,9 @@ def mset(self, key_value_map: Mapping[str, str]): parameters: List[str] = [] for pair in key_value_map.items(): parameters.extend(pair) - self.append_command(RequestType.MSet, parameters) + return self.append_command(RequestType.MSet, parameters) - def mget(self, keys: List[str]): + def mget(self: TTransaction, keys: List[str]) -> TTransaction: """ Retrieve the values of multiple keys. See https://redis.io/commands/mget/ for more details. @@ -220,9 +226,9 @@ def mget(self, keys: List[str]): List[str]: A list of values corresponding to the provided keys. If a key is not found, its corresponding value in the list will be None. """ - self.append_command(RequestType.MGet, keys) + return self.append_command(RequestType.MGet, keys) - def config_rewrite(self): + def config_rewrite(self: TTransaction) -> TTransaction: """ Rewrite the configuration file with the current configuration. See https://redis.io/commands/config-rewrite/ for details. @@ -230,9 +236,9 @@ def config_rewrite(self): Command response: OK: OK is returned when the configuration was rewritten properly. Otherwise, the transaction fails with an error. """ - self.append_command(RequestType.ConfigRewrite, []) + return self.append_command(RequestType.ConfigRewrite, []) - def client_id(self): + def client_id(self: TTransaction) -> TTransaction: """ Returns the current connection id. See https://redis.io/commands/client-id/ for more information. @@ -240,9 +246,9 @@ def client_id(self): Command response: int: the id of the client. """ - self.append_command(RequestType.ClientId, []) + return self.append_command(RequestType.ClientId, []) - def incr(self, key: str): + def incr(self: TTransaction, key: str) -> TTransaction: """ Increments the number stored at `key` by one. If `key` does not exist, it is set to 0 before performing the @@ -255,9 +261,9 @@ def incr(self, key: str): Command response: int: the value of `key` after the increment. """ - self.append_command(RequestType.Incr, [key]) + return self.append_command(RequestType.Incr, [key]) - def incrby(self, key: str, amount: int): + def incrby(self: TTransaction, key: str, amount: int) -> TTransaction: """ Increments the number stored at `key` by `amount`. If the key does not exist, it is set to 0 before performing the operation. @@ -270,9 +276,9 @@ def incrby(self, key: str, amount: int): Command response: int: The value of `key` after the increment. """ - self.append_command(RequestType.IncrBy, [key, str(amount)]) + return self.append_command(RequestType.IncrBy, [key, str(amount)]) - def incrbyfloat(self, key: str, amount: float): + def incrbyfloat(self: TTransaction, key: str, amount: float) -> TTransaction: """ Increment the string representing a floating point number stored at `key` by `amount`. By using a negative increment value, the value stored at the `key` is decremented. @@ -286,9 +292,9 @@ def incrbyfloat(self, key: str, amount: float): Command response: float: The value of key after the increment. """ - self.append_command(RequestType.IncrByFloat, [key, str(amount)]) + return self.append_command(RequestType.IncrByFloat, [key, str(amount)]) - def ping(self, message: Optional[str] = None): + def ping(self: TTransaction, message: Optional[str] = None) -> TTransaction: """ Ping the Redis server. See https://redis.io/commands/ping/ for more details. @@ -301,9 +307,9 @@ def ping(self, message: Optional[str] = None): str: "PONG" if `message` is not provided, otherwise return a copy of `message`. """ argument = [] if message is None else [message] - self.append_command(RequestType.Ping, argument) + return self.append_command(RequestType.Ping, argument) - def decr(self, key: str): + def decr(self: TTransaction, key: str) -> TTransaction: """ Decrements the number stored at `key` by one. If the key does not exist, it is set to 0 before performing the operation. @@ -315,9 +321,9 @@ def decr(self, key: str): Command response: int: the value of `key` after the decrement. """ - self.append_command(RequestType.Decr, [key]) + return self.append_command(RequestType.Decr, [key]) - def decrby(self, key: str, amount: int): + def decrby(self: TTransaction, key: str, amount: int) -> TTransaction: """ Decrements the number stored at `key` by `amount`. If the key does not exist, it is set to 0 before performing the operation. @@ -330,9 +336,11 @@ def decrby(self, key: str, amount: int): Command response: int: The value of `key` after the decrement. """ - self.append_command(RequestType.DecrBy, [key, str(amount)]) + return self.append_command(RequestType.DecrBy, [key, str(amount)]) - def hset(self, key: str, field_value_map: Mapping[str, str]): + def hset( + self: TTransaction, key: str, field_value_map: Mapping[str, str] + ) -> TTransaction: """ Sets the specified fields to their respective values in the hash stored at `key`. See https://redis.io/commands/hset/ for more details. @@ -348,9 +356,9 @@ def hset(self, key: str, field_value_map: Mapping[str, str]): field_value_list: List[str] = [key] for pair in field_value_map.items(): field_value_list.extend(pair) - self.append_command(RequestType.HashSet, field_value_list) + return self.append_command(RequestType.HashSet, field_value_list) - def hget(self, key: str, field: str): + def hget(self: TTransaction, key: str, field: str) -> TTransaction: """ Retrieves the value associated with `field` in the hash stored at `key`. See https://redis.io/commands/hget/ for more details. @@ -363,9 +371,9 @@ def hget(self, key: str, field: str): Optional[str]: The value associated `field` in the hash. Returns None if `field` is not presented in the hash or `key` does not exist. """ - self.append_command(RequestType.HashGet, [key, field]) + return self.append_command(RequestType.HashGet, [key, field]) - def hincrby(self, key: str, field: str, amount: int): + def hincrby(self: TTransaction, key: str, field: str, amount: int) -> TTransaction: """ Increment or decrement the value of a `field` in the hash stored at `key` by the specified amount. By using a negative increment value, the value stored at `field` in the hash stored at `key` is decremented. @@ -381,9 +389,11 @@ def hincrby(self, key: str, field: str, amount: int): Command response: int: The value of the specified field in the hash stored at `key` after the increment or decrement. """ - self.append_command(RequestType.HashIncrBy, [key, field, str(amount)]) + return self.append_command(RequestType.HashIncrBy, [key, field, str(amount)]) - def hincrbyfloat(self, key: str, field: str, amount: float): + def hincrbyfloat( + self: TTransaction, key: str, field: str, amount: float + ) -> TTransaction: """ Increment or decrement the floating-point value stored at `field` in the hash stored at `key` by the specified amount. @@ -400,9 +410,11 @@ def hincrbyfloat(self, key: str, field: str, amount: float): Command response: float: The value of the specified field in the hash stored at `key` after the increment as a string. """ - self.append_command(RequestType.HashIncrByFloat, [key, field, str(amount)]) + return self.append_command( + RequestType.HashIncrByFloat, [key, field, str(amount)] + ) - def hexists(self, key: str, field: str): + def hexists(self: TTransaction, key: str, field: str) -> TTransaction: """ Check if a field exists in the hash stored at `key`. See https://redis.io/commands/hexists/ for more details. @@ -415,9 +427,9 @@ def hexists(self, key: str, field: str): bool: Returns 'True' if the hash contains the specified field. If the hash does not contain the field, or if the key does not exist, it returns 'False'. """ - self.append_command(RequestType.HashExists, [key, field]) + return self.append_command(RequestType.HashExists, [key, field]) - def hlen(self, key: str): + def hlen(self: TTransaction, key: str) -> TTransaction: """ Returns the number of fields contained in the hash stored at `key`. @@ -430,9 +442,9 @@ def hlen(self, key: str): int: The number of fields in the hash, or 0 when the key does not exist. If `key` holds a value that is not a hash, the transaction fails with an error. """ - self.append_command(RequestType.HLen, [key]) + return self.append_command(RequestType.HLen, [key]) - def client_getname(self): + def client_getname(self: TTransaction) -> TTransaction: """ Get the name of the connection on which the transaction is being executed. See https://redis.io/commands/client-getname/ for more details. @@ -441,9 +453,9 @@ def client_getname(self): Optional[str]: Returns the name of the client connection as a string if a name is set, or None if no name is assigned. """ - self.append_command(RequestType.ClientGetName, []) + return self.append_command(RequestType.ClientGetName, []) - def hgetall(self, key: str): + def hgetall(self: TTransaction, key: str) -> TTransaction: """ Returns all fields and values of the hash stored at `key`. See https://redis.io/commands/hgetall/ for details. @@ -456,9 +468,9 @@ def hgetall(self, key: str): its value. If `key` does not exist, it returns an empty dictionary. """ - self.append_command(RequestType.HashGetAll, [key]), + return self.append_command(RequestType.HashGetAll, [key]) - def hmget(self, key: str, fields: List[str]): + def hmget(self: TTransaction, key: str, fields: List[str]) -> TTransaction: """ Retrieve the values associated with specified fields in the hash stored at `key`. See https://redis.io/commands/hmget/ for details. @@ -472,9 +484,9 @@ def hmget(self, key: str, fields: List[str]): For every field that does not exist in the hash, a null value is returned. If `key` does not exist, it is treated as an empty hash, and the function returns a list of null values. """ - self.append_command(RequestType.HashMGet, [key] + fields) + return self.append_command(RequestType.HashMGet, [key] + fields) - def hdel(self, key: str, fields: List[str]): + def hdel(self: TTransaction, key: str, fields: List[str]) -> TTransaction: """ Remove specified fields from the hash stored at `key`. See https://redis.io/commands/hdel/ for more details. @@ -487,9 +499,9 @@ def hdel(self, key: str, fields: List[str]): int: The number of fields that were removed from the hash, excluding specified but non-existing fields. If `key` does not exist, it is treated as an empty hash, and the function returns 0. """ - self.append_command(RequestType.HashDel, [key] + fields) + return self.append_command(RequestType.HashDel, [key] + fields) - def lpush(self, key: str, elements: List[str]): + def lpush(self: TTransaction, key: str, elements: List[str]) -> TTransaction: """ Insert all the specified values at the head of the list stored at `key`. `elements` are inserted one after the other to the head of the list, from the leftmost element @@ -503,9 +515,9 @@ def lpush(self, key: str, elements: List[str]): Command response: int: The length of the list after the push operations. """ - self.append_command(RequestType.LPush, [key] + elements) + return self.append_command(RequestType.LPush, [key] + elements) - def lpop(self, key: str): + def lpop(self: TTransaction, key: str) -> TTransaction: """ Remove and return the first elements of the list stored at `key`. The command pops a single element from the beginning of the list. @@ -518,9 +530,9 @@ def lpop(self, key: str): Optional[str]: The value of the first element. If `key` does not exist, None will be returned. """ - self.append_command(RequestType.LPop, [key]) + return self.append_command(RequestType.LPop, [key]) - def lpop_count(self, key: str, count: int): + def lpop_count(self: TTransaction, key: str, count: int) -> TTransaction: """ Remove and return up to `count` elements from the list stored at `key`, depending on the list's length. See https://redis.io/commands/lpop/ for details. @@ -533,9 +545,9 @@ def lpop_count(self, key: str, count: int): Optional[List[str]]: A a list of popped elements will be returned depending on the list's length. If `key` does not exist, None will be returned. """ - self.append_command(RequestType.LPop, [key, str(count)]) + return self.append_command(RequestType.LPop, [key, str(count)]) - def lrange(self, key: str, start: int, end: int): + def lrange(self: TTransaction, key: str, start: int, end: int) -> TTransaction: """ Retrieve the specified elements of the list stored at `key` within the given range. The offsets `start` and `end` are zero-based indexes, with 0 being the first element of the list, 1 being the next @@ -554,9 +566,9 @@ def lrange(self, key: str, start: int, end: int): If `end` exceeds the actual end of the list, the range will stop at the actual end of the list. If `key` does not exist an empty list will be returned. """ - self.append_command(RequestType.LRange, [key, str(start), str(end)]) + return self.append_command(RequestType.LRange, [key, str(start), str(end)]) - def rpush(self, key: str, elements: List[str]): + def rpush(self: TTransaction, key: str, elements: List[str]) -> TTransaction: """Inserts all the specified values at the tail of the list stored at `key`. `elements` are inserted one after the other to the tail of the list, from the leftmost element to the rightmost element. If `key` does not exist, it is created as empty list before performing the push operations. @@ -570,9 +582,9 @@ def rpush(self, key: str, elements: List[str]): int: The length of the list after the push operations. If `key` holds a value that is not a list, the transaction fails. """ - self.append_command(RequestType.RPush, [key] + elements) + return self.append_command(RequestType.RPush, [key] + elements) - def rpop(self, key: str, count: Optional[int] = None): + def rpop(self: TTransaction, key: str, count: Optional[int] = None) -> TTransaction: """ Removes and returns the last elements of the list stored at `key`. The command pops a single element from the end of the list. @@ -585,9 +597,9 @@ def rpop(self, key: str, count: Optional[int] = None): Optional[str]: The value of the last element. If `key` does not exist, None will be returned. """ - self.append_command(RequestType.RPop, [key]) + return self.append_command(RequestType.RPop, [key]) - def rpop_count(self, key: str, count: int): + def rpop_count(self: TTransaction, key: str, count: int) -> TTransaction: """ Removes and returns up to `count` elements from the list stored at `key`, depending on the list's length. See https://redis.io/commands/rpop/ for details. @@ -600,9 +612,9 @@ def rpop_count(self, key: str, count: int): Optional[List[str]: A list of popped elements will be returned depending on the list's length. If `key` does not exist, None will be returned. """ - self.append_command(RequestType.RPop, [key, str(count)]) + return self.append_command(RequestType.RPop, [key, str(count)]) - def sadd(self, key: str, members: List[str]): + def sadd(self: TTransaction, key: str, members: List[str]) -> TTransaction: """ Add specified members to the set stored at `key`. Specified members that are already a member of this set are ignored. @@ -616,9 +628,9 @@ def sadd(self, key: str, members: List[str]): Command response: int: The number of members that were added to the set, excluding members already present. """ - self.append_command(RequestType.SAdd, [key] + members) + return self.append_command(RequestType.SAdd, [key] + members) - def srem(self, key: str, members: List[str]): + def srem(self: TTransaction, key: str, members: List[str]) -> TTransaction: """ Remove specified members from the set stored at `key`. Specified members that are not a member of this set are ignored. @@ -632,9 +644,9 @@ def srem(self, key: str, members: List[str]): int: The number of members that were removed from the set, excluding non-existing members. If `key` does not exist, it is treated as an empty set and this command returns 0. """ - self.append_command(RequestType.SRem, [key] + members) + return self.append_command(RequestType.SRem, [key] + members) - def smembers(self, key: str): + def smembers(self: TTransaction, key: str) -> TTransaction: """ Retrieve all the members of the set value stored at `key`. See https://redis.io/commands/smembers/ for details. @@ -646,9 +658,9 @@ def smembers(self, key: str): Set[str]: A set of all members of the set. If `key` does not exist an empty list will be returned. """ - self.append_command(RequestType.SMembers, [key]) + return self.append_command(RequestType.SMembers, [key]) - def scard(self, key: str): + def scard(self: TTransaction, key: str) -> TTransaction: """ Retrieve the set cardinality (number of elements) of the set stored at `key`. See https://redis.io/commands/scard/ for details. @@ -659,9 +671,9 @@ def scard(self, key: str): Commands response: int: The cardinality (number of elements) of the set, or 0 if the key does not exist. """ - self.append_command(RequestType.SCard, [key]) + return self.append_command(RequestType.SCard, [key]) - def ltrim(self, key: str, start: int, end: int): + def ltrim(self: TTransaction, key: str, start: int, end: int) -> TTransaction: """ Trim an existing list so that it will contain only the specified range of elements specified. The offsets `start` and `end` are zero-based indexes, with 0 being the first element of the list, 1 being the next @@ -682,9 +694,9 @@ def ltrim(self, key: str, start: int, end: int): If `end` exceeds the actual end of the list, it will be treated like the last element of the list. f `key` does not exist, the response will be "OK" without changes to the database. """ - self.append_command(RequestType.LTrim, [key, str(start), str(end)]) + return self.append_command(RequestType.LTrim, [key, str(start), str(end)]) - def lrem(self, key: str, count: int, element: str): + def lrem(self: TTransaction, key: str, count: int, element: str) -> TTransaction: """ Removes the first `count` occurrences of elements equal to `element` from the list stored at `key`. If `count` is positive, it removes elements equal to `element` moving from head to tail. @@ -702,9 +714,9 @@ def lrem(self, key: str, count: int, element: str): int: The number of removed elements. If `key` does not exist, 0 is returned. """ - self.append_command(RequestType.LRem, [key, str(count), element]) + return self.append_command(RequestType.LRem, [key, str(count), element]) - def llen(self, key: str): + def llen(self: TTransaction, key: str) -> TTransaction: """ Get the length of the list stored at `key`. See https://redis.io/commands/llen/ for details. @@ -716,9 +728,9 @@ def llen(self, key: str): int: The length of the list at the specified key. If `key` does not exist, it is interpreted as an empty list and 0 is returned. """ - self.append_command(RequestType.LLen, [key]) + return self.append_command(RequestType.LLen, [key]) - def exists(self, keys: List[str]): + def exists(self: TTransaction, keys: List[str]) -> TTransaction: """ Returns the number of keys in `keys` that exist in the database. See https://redis.io/commands/exists/ for more details. @@ -730,9 +742,9 @@ def exists(self, keys: List[str]): int: The number of keys that exist. If the same existing key is mentioned in `keys` multiple times, it will be counted multiple times. """ - self.append_command(RequestType.Exists, keys) + return self.append_command(RequestType.Exists, keys) - def unlink(self, keys: List[str]): + def unlink(self: TTransaction, keys: List[str]) -> TTransaction: """ Unlink (delete) multiple keys from the database. A key is ignored if it does not exist. @@ -746,9 +758,14 @@ def unlink(self, keys: List[str]): Commands response: int: The number of keys that were unlinked. """ - self.append_command(RequestType.Unlink, keys) + return self.append_command(RequestType.Unlink, keys) - def expire(self, key: str, seconds: int, option: Optional[ExpireOptions] = None): + def expire( + self: TTransaction, + key: str, + seconds: int, + option: Optional[ExpireOptions] = None, + ) -> TTransaction: """ Sets a timeout on `key` in seconds. After the timeout has expired, the key will automatically be deleted. If `key` already has an existing expire set, the time to live is updated to the new value. @@ -768,11 +785,14 @@ def expire(self, key: str, seconds: int, option: Optional[ExpireOptions] = None) args: List[str] = ( [key, str(seconds)] if option is None else [key, str(seconds), option.value] ) - self.append_command(RequestType.Expire, args) + return self.append_command(RequestType.Expire, args) def expireat( - self, key: str, unix_seconds: int, option: Optional[ExpireOptions] = None - ): + self: TTransaction, + key: str, + unix_seconds: int, + option: Optional[ExpireOptions] = None, + ) -> TTransaction: """ Sets a timeout on `key` using an absolute Unix timestamp (seconds since January 1, 1970) instead of specifying the number of seconds. @@ -796,11 +816,14 @@ def expireat( if option is None else [key, str(unix_seconds), option.value] ) - self.append_command(RequestType.ExpireAt, args) + return self.append_command(RequestType.ExpireAt, args) def pexpire( - self, key: str, milliseconds: int, option: Optional[ExpireOptions] = None - ): + self: TTransaction, + key: str, + milliseconds: int, + option: Optional[ExpireOptions] = None, + ) -> TTransaction: """ Sets a timeout on `key` in milliseconds. After the timeout has expired, the key will automatically be deleted. If `key` already has an existing expire set, the time to live is updated to the new value. @@ -822,11 +845,14 @@ def pexpire( if option is None else [key, str(milliseconds), option.value] ) - self.append_command(RequestType.PExpire, args) + return self.append_command(RequestType.PExpire, args) def pexpireat( - self, key: str, unix_milliseconds: int, option: Optional[ExpireOptions] = None - ): + self: TTransaction, + key: str, + unix_milliseconds: int, + option: Optional[ExpireOptions] = None, + ) -> TTransaction: """ Sets a timeout on `key` using an absolute Unix timestamp in milliseconds (milliseconds since January 1, 1970) instead of specifying the number of milliseconds. @@ -850,9 +876,9 @@ def pexpireat( if option is None else [key, str(unix_milliseconds), option.value] ) - self.append_command(RequestType.PExpireAt, args) + return self.append_command(RequestType.PExpireAt, args) - def ttl(self, key: str): + def ttl(self: TTransaction, key: str) -> TTransaction: """ Returns the remaining time to live of `key` that has a timeout. See https://redis.io/commands/ttl/ for more details. @@ -863,9 +889,9 @@ def ttl(self, key: str): Commands response: int: TTL in seconds, -2 if `key` does not exist or -1 if `key` exists but has no associated expire. """ - self.append_command(RequestType.TTL, [key]) + return self.append_command(RequestType.TTL, [key]) - def type(self, key: str): + def type(self: TTransaction, key: str) -> TTransaction: """ Returns the string representation of the type of the value stored at `key`. @@ -878,16 +904,16 @@ def type(self, key: str): str: If the key exists, the type of the stored value is returned. Otherwise, a "none" string is returned. """ - self.append_command(RequestType.Type, [key]) + return self.append_command(RequestType.Type, [key]) def zadd( - self, + self: TTransaction, key: str, members_scores: Mapping[str, float], existing_options: Optional[ConditionalChange] = None, update_condition: Optional[UpdateOptions] = None, changed: bool = False, - ): + ) -> TTransaction: """ Adds members with their scores to the sorted set stored at `key`. If a member is already a part of the sorted set, its score is updated. @@ -931,16 +957,16 @@ def zadd( ] args += members_scores_list - self.append_command(RequestType.Zadd, args) + return self.append_command(RequestType.Zadd, args) def zadd_incr( - self, + self: TTransaction, key: str, member: str, increment: float, existing_options: Optional[ConditionalChange] = None, update_condition: Optional[UpdateOptions] = None, - ): + ) -> TTransaction: """ Increments the score of member in the sorted set stored at `key` by `increment`. If `member` does not exist in the sorted set, it is added with `increment` as its score (as if its previous score was 0.0). @@ -980,9 +1006,9 @@ def zadd_incr( ) args += [str(increment), member] - self.append_command(RequestType.Zadd, args) + return self.append_command(RequestType.Zadd, args) - def zcard(self, key: str): + def zcard(self: TTransaction, key: str) -> TTransaction: """ Returns the cardinality (number of elements) of the sorted set stored at `key`. @@ -995,14 +1021,14 @@ def zcard(self, key: str): int: The number of elements in the sorted set. If `key` does not exist, it is treated as an empty sorted set, and the command returns 0. """ - self.append_command(RequestType.Zcard, [key]) + return self.append_command(RequestType.Zcard, [key]) def zcount( - self, + self: TTransaction, key: str, min_score: Union[InfBound, ScoreLimit], max_score: Union[InfBound, ScoreLimit], - ): + ) -> TTransaction: """ Returns the number of members in the sorted set stored at `key` with scores between `min_score` and `max_score`. @@ -1022,13 +1048,15 @@ def zcount( If key does not exist, 0 is returned. If `max_score` < `min_score`, 0 is returned. """ - self.append_command(RequestType.Zcount, [key, min_score.value, max_score.value]) + return self.append_command( + RequestType.Zcount, [key, min_score.value, max_score.value] + ) def zrem( - self, + self: TTransaction, key: str, members: List[str], - ): + ) -> TTransaction: """ Removes the specified members from the sorted set stored at `key`. Specified members that are not a member of this set are ignored. @@ -1043,9 +1071,9 @@ def zrem( int: The number of members that were removed from the sorted set, not including non-existing members. If `key` does not exist, it is treated as an empty sorted set, and the command returns 0. """ - self.append_command(RequestType.Zrem, [key] + members) + return self.append_command(RequestType.Zrem, [key] + members) - def zscore(self, key: str, member: str): + def zscore(self: TTransaction, key: str, member: str) -> TTransaction: """ Returns the score of `member` in the sorted set stored at `key`. @@ -1060,7 +1088,7 @@ def zscore(self, key: str, member: str): If `member` does not exist in the sorted set, None is returned. If `key` does not exist, None is returned. """ - self.append_command(RequestType.ZScore, [key, member]) + return self.append_command(RequestType.ZScore, [key, member]) class Transaction(BaseTransaction): @@ -1082,7 +1110,7 @@ class Transaction(BaseTransaction): """ # TODO: add MOVE, SLAVEOF and all SENTINEL commands - def select(self, index: int): + def select(self, index: int) -> "Transaction": """ Change the currently selected Redis database. See https://redis.io/commands/select/ for details. @@ -1093,7 +1121,7 @@ def select(self, index: int): Command response: A simple OK response. """ - self.append_command(RequestType.Select, [str(index)]) + return self.append_command(RequestType.Select, [str(index)]) class ClusterTransaction(BaseTransaction): diff --git a/python/python/tests/test_transaction.py b/python/python/tests/test_transaction.py index 2a73b145d1..2926dfad03 100644 --- a/python/python/tests/test_transaction.py +++ b/python/python/tests/test_transaction.py @@ -291,3 +291,14 @@ def test_transaction_clear(self): transaction.select(1) transaction.clear() assert len(transaction.commands) == 0 + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_transaction_chaining_calls(self, redis_client: TRedisClient): + cluster_mode = isinstance(redis_client, RedisClusterClient) + key = get_random_string(3) + + transaction = ClusterTransaction() if cluster_mode else Transaction() + transaction.set(key, "value").get(key).delete([key]) + + assert await redis_client.exec(transaction) == [OK, "value", 1]