Skip to content

Commit

Permalink
RESP3 tests (#2780)
Browse files Browse the repository at this point in the history
* fix command response in resp3

* linters

* acl_log & acl_getuser

* client_info

* test_commands and test_asyncio/test_commands

* fix test_command_parser

* fix asyncio/test_connection/test_invalid_response

* linters

* all the tests

* push handler sharded pubsub

* Use assert_resp_response wherever possible

* fix test_xreadgroup

* fix cluster_zdiffstore and cluster_zinter

* fix review comments

* fix review comments

* linters
  • Loading branch information
dvora-h authored Jun 1, 2023
1 parent e8fc092 commit 326f351
Show file tree
Hide file tree
Showing 19 changed files with 812 additions and 705 deletions.
8 changes: 4 additions & 4 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,13 +671,13 @@ def __init__(
if self.encoder is None:
self.encoder = self.connection_pool.get_encoder()
if self.encoder.decode_responses:
self.health_check_response: Iterable[Union[str, bytes]] = [
"pong",
self.health_check_response = [
["pong", self.HEALTH_CHECK_MESSAGE],
self.HEALTH_CHECK_MESSAGE,
]
else:
self.health_check_response = [
b"pong",
[b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)],
self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
]
if self.push_handler_func is None:
Expand Down Expand Up @@ -807,7 +807,7 @@ async def parse_response(self, block: bool = True, timeout: float = 0):
conn, conn.read_response, timeout=read_timeout, push_request=True
)

if conn.health_check_interval and response == self.health_check_response:
if conn.health_check_interval and response in self.health_check_response:
# ignore the health check message as user might not expect it
return None
return response
Expand Down
2 changes: 2 additions & 0 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ def __init__(
kwargs.update({"retry": self.retry})

kwargs["response_callbacks"] = self.__class__.RESPONSE_CALLBACKS.copy()
if kwargs.get("protocol") in ["3", 3]:
kwargs["response_callbacks"].update(self.__class__.RESP3_RESPONSE_CALLBACKS)
self.connection_kwargs = kwargs

if startup_nodes:
Expand Down
28 changes: 25 additions & 3 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,16 +333,36 @@ def _error_message(self, exception):
async def on_connect(self) -> None:
"""Initialize the connection, authenticate and select a database"""
self._parser.on_connect(self)
parser = self._parser

auth_args = None
# if credential provider or username and/or password are set, authenticate
if self.credential_provider or (self.username or self.password):
cred_provider = (
self.credential_provider
or UsernamePasswordCredentialProvider(self.username, self.password)
)
auth_args = cred_provider.get_credentials()
# avoid checking health here -- PING will fail if we try
# to check the health prior to the AUTH
# if resp version is specified and we have auth args,
# we need to send them via HELLO
if auth_args and self.protocol not in [2, "2"]:
if isinstance(self._parser, _AsyncRESP2Parser):
self.set_parser(_AsyncRESP3Parser)
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
if len(auth_args) == 1:
auth_args = ["default", auth_args[0]]
await self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
response = await self.read_response()
if response.get(b"proto") not in [2, "2"] and response.get("proto") not in [
2,
"2",
]:
raise ConnectionError("Invalid RESP version")
# avoid checking health here -- PING will fail if we try
# to check the health prior to the AUTH
elif auth_args:
await self.send_command("AUTH", *auth_args, check_health=False)

try:
Expand All @@ -359,9 +379,11 @@ async def on_connect(self) -> None:
raise AuthenticationError("Invalid Username or Password")

# if resp version is specified, switch to it
if self.protocol != 2:
elif self.protocol != 2:
if isinstance(self._parser, _AsyncRESP2Parser):
self.set_parser(_AsyncRESP3Parser)
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
await self.send_command("HELLO", self.protocol)
response = await self.read_response()
Expand Down
38 changes: 24 additions & 14 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,15 @@ def parse_xinfo_stream(response, **options):
data["last-entry"] = (last[0], pairs_to_dict(last[1]))
else:
data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]}
data["groups"] = [
pairs_to_dict(group, decode_keys=True) for group in data["groups"]
]
if isinstance(data["groups"][0], list):
data["groups"] = [
pairs_to_dict(group, decode_keys=True) for group in data["groups"]
]
else:
data["groups"] = [
{str_if_bytes(k): v for k, v in group.items()}
for group in data["groups"]
]
return data


Expand Down Expand Up @@ -581,14 +587,15 @@ def parse_command_resp3(response, **options):
cmd_name = str_if_bytes(command[0])
cmd_dict["name"] = cmd_name
cmd_dict["arity"] = command[1]
cmd_dict["flags"] = command[2]
cmd_dict["flags"] = {str_if_bytes(flag) for flag in command[2]}
cmd_dict["first_key_pos"] = command[3]
cmd_dict["last_key_pos"] = command[4]
cmd_dict["step_count"] = command[5]
cmd_dict["acl_categories"] = command[6]
cmd_dict["tips"] = command[7]
cmd_dict["key_specifications"] = command[8]
cmd_dict["subcommands"] = command[9]
if len(command) > 7:
cmd_dict["tips"] = command[7]
cmd_dict["key_specifications"] = command[8]
cmd_dict["subcommands"] = command[9]

commands[cmd_name] = cmd_dict
return commands
Expand Down Expand Up @@ -626,17 +633,20 @@ def parse_acl_getuser(response, **options):
if data["channels"] == [""]:
data["channels"] = []
if "selectors" in data:
data["selectors"] = [
list(map(str_if_bytes, selector)) for selector in data["selectors"]
]
if data["selectors"] != [] and isinstance(data["selectors"][0], list):
data["selectors"] = [
list(map(str_if_bytes, selector)) for selector in data["selectors"]
]
elif data["selectors"] != []:
data["selectors"] = [
{str_if_bytes(k): str_if_bytes(v) for k, v in selector.items()}
for selector in data["selectors"]
]

# split 'commands' into separate 'categories' and 'commands' lists
commands, categories = [], []
for command in data["commands"].split(" "):
if "@" in command:
categories.append(command)
else:
commands.append(command)
categories.append(command) if "@" in command else commands.append(command)

data["commands"] = commands
data["categories"] = categories
Expand Down
22 changes: 19 additions & 3 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from redis.parsers import CommandsParser, Encoder
from redis.retry import Retry
from redis.utils import (
HIREDIS_AVAILABLE,
dict_merge,
list_keys_to_dict,
merge_result,
Expand Down Expand Up @@ -1608,7 +1609,15 @@ class ClusterPubSub(PubSub):
https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html
"""

def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs):
def __init__(
self,
redis_cluster,
node=None,
host=None,
port=None,
push_handler_func=None,
**kwargs,
):
"""
When a pubsub instance is created without specifying a node, a single
node will be transparently chosen for the pubsub connection on the
Expand All @@ -1633,7 +1642,10 @@ def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs):
self.node_pubsub_mapping = {}
self._pubsubs_generator = self._pubsubs_generator()
super().__init__(
**kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder
connection_pool=connection_pool,
encoder=redis_cluster.encoder,
push_handler_func=push_handler_func,
**kwargs,
)

def set_pubsub_node(self, cluster, node=None, host=None, port=None):
Expand Down Expand Up @@ -1717,14 +1729,18 @@ def execute_command(self, *args):
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
self.connection._parser.set_push_handler(self.push_handler_func)
connection = self.connection
self._execute(connection, connection.send_command, *args)

def _get_node_pubsub(self, node):
try:
return self.node_pubsub_mapping[node.name]
except KeyError:
pubsub = node.redis_connection.pubsub()
pubsub = node.redis_connection.pubsub(
push_handler_func=self.push_handler_func
)
self.node_pubsub_mapping[node.name] = pubsub
return pubsub

Expand Down
23 changes: 22 additions & 1 deletion redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,14 +276,33 @@ def _error_message(self, exception):
def on_connect(self):
"Initialize the connection, authenticate and select a database"
self._parser.on_connect(self)
parser = self._parser

auth_args = None
# if credential provider or username and/or password are set, authenticate
if self.credential_provider or (self.username or self.password):
cred_provider = (
self.credential_provider
or UsernamePasswordCredentialProvider(self.username, self.password)
)
auth_args = cred_provider.get_credentials()
# if resp version is specified and we have auth args,
# we need to send them via HELLO
if auth_args and self.protocol != 2:
if isinstance(self._parser, _RESP2Parser):
self.set_parser(_RESP3Parser)
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
if len(auth_args) == 1:
auth_args = ["default", auth_args[0]]
self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
response = self.read_response()
if response.get(b"proto") != int(self.protocol) and response.get(
"proto"
) != int(self.protocol):
raise ConnectionError("Invalid RESP version")
elif auth_args:
# avoid checking health here -- PING will fail if we try
# to check the health prior to the AUTH
self.send_command("AUTH", *auth_args, check_health=False)
Expand All @@ -302,9 +321,11 @@ def on_connect(self):
raise AuthenticationError("Invalid Username or Password")

# if resp version is specified, switch to it
if self.protocol != 2:
elif self.protocol != 2:
if isinstance(self._parser, _RESP2Parser):
self.set_parser(_RESP3Parser)
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
self.send_command("HELLO", self.protocol)
response = self.read_response()
Expand Down
3 changes: 2 additions & 1 deletion redis/parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .base import BaseParser
from .base import BaseParser, _AsyncRESPBase
from .commands import AsyncCommandsParser, CommandsParser
from .encoders import Encoder
from .hiredis import _AsyncHiredisParser, _HiredisParser
Expand All @@ -8,6 +8,7 @@
__all__ = [
"AsyncCommandsParser",
"_AsyncHiredisParser",
"_AsyncRESPBase",
"_AsyncRESP2Parser",
"_AsyncRESP3Parser",
"CommandsParser",
Expand Down
14 changes: 10 additions & 4 deletions redis/parsers/resp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,12 @@ def _read_response(self, disable_decoding=False, push_request=False):
# bool value
elif byte == b"#":
return response == b"t"
# bulk response and verbatim strings
elif byte in (b"$", b"="):
# bulk response
elif byte == b"$":
response = self._buffer.read(int(response))
# verbatim string response
elif byte == b"=":
response = self._buffer.read(int(response))[4:]
# array response
elif byte == b"*":
response = [
Expand Down Expand Up @@ -195,9 +198,12 @@ async def _read_response(
# bool value
elif byte == b"#":
return response == b"t"
# bulk response and verbatim strings
elif byte in (b"$", b"="):
# bulk response
elif byte == b"$":
response = await self._read(int(response))
# verbatim string response
elif byte == b"=":
response = (await self._read(int(response)))[4:]
# array response
elif byte == b"*":
response = [
Expand Down
27 changes: 25 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,8 +475,31 @@ def wait_for_command(client, monitor, command, key=None):


def is_resp2_connection(r):
if isinstance(r, redis.Redis):
if isinstance(r, redis.Redis) or isinstance(r, redis.asyncio.Redis):
protocol = r.connection_pool.connection_kwargs.get("protocol")
elif isinstance(r, redis.RedisCluster):
elif isinstance(r, redis.cluster.AbstractRedisCluster):
protocol = r.nodes_manager.connection_kwargs.get("protocol")
return protocol in ["2", 2, None]


def get_protocol_version(r):
if isinstance(r, redis.Redis) or isinstance(r, redis.asyncio.Redis):
return r.connection_pool.connection_kwargs.get("protocol")
elif isinstance(r, redis.cluster.AbstractRedisCluster):
return r.nodes_manager.connection_kwargs.get("protocol")


def assert_resp_response(r, response, resp2_expected, resp3_expected):
protocol = get_protocol_version(r)
if protocol in [2, "2", None]:
assert response == resp2_expected
else:
assert response == resp3_expected


def assert_resp_response_in(r, response, resp2_expected, resp3_expected):
protocol = get_protocol_version(r)
if protocol in [2, "2", None]:
assert response in resp2_expected
else:
assert response in resp3_expected
23 changes: 0 additions & 23 deletions tests/test_asyncio/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,29 +236,6 @@ async def wait_for_command(
return None


def get_protocol_version(r):
if isinstance(r, redis.Redis):
return r.connection_pool.connection_kwargs.get("protocol")
elif isinstance(r, redis.RedisCluster):
return r.nodes_manager.connection_kwargs.get("protocol")


def assert_resp_response(r, response, resp2_expected, resp3_expected):
protocol = get_protocol_version(r)
if protocol in [2, "2", None]:
assert response == resp2_expected
else:
assert response == resp3_expected


def assert_resp_response_in(r, response, resp2_expected, resp3_expected):
protocol = get_protocol_version(r)
if protocol in [2, "2", None]:
assert response in resp2_expected
else:
assert response in resp3_expected


# python 3.6 doesn't have the asynccontextmanager decorator. Provide it here.
class AsyncContextManager:
def __init__(self, async_generator):
Expand Down
Loading

0 comments on commit 326f351

Please sign in to comment.