Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for PubSub with RESP3 parser #2721

Merged
merged 5 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from redis.lock import Lock
from redis.retry import Retry
from redis.utils import safe_str, str_if_bytes
from redis.utils import HIREDIS_AVAILABLE, safe_str, str_if_bytes

SYM_EMPTY = b""
EMPTY_RESPONSE = "EMPTY_RESPONSE"
Expand Down Expand Up @@ -1429,6 +1429,7 @@ def __init__(
shard_hint=None,
ignore_subscribe_messages=False,
encoder=None,
push_handler=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

push_handler_func

and the type hint that goes with it

):
self.connection_pool = connection_pool
self.shard_hint = shard_hint
Expand All @@ -1438,6 +1439,7 @@ def __init__(
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
self.encoder = encoder
self.push_handler = push_handler
if self.encoder is None:
self.encoder = self.connection_pool.get_encoder()
self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE)
Expand Down Expand Up @@ -1515,6 +1517,8 @@ def execute_command(self, *args):
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
if self.push_handler is not None and not HIREDIS_AVAILABLE:
self.connection._parser.set_push_handler(self.push_handler)
connection = self.connection
kwargs = {"check_health": not self.subscribed}
if not self.subscribed:
Expand Down Expand Up @@ -1580,7 +1584,7 @@ def try_read():
return None
else:
conn.connect()
return conn.read_response()
return conn.read_response(push_request=True)

response = self._execute(conn, try_read)

Expand Down Expand Up @@ -1739,8 +1743,8 @@ def ping(self, message=None):
"""
Ping the Redis server
"""
message = "" if message is None else message
return self.execute_command("PING", message)
args = ["PING", message] if message is not None else ["PING"]
return self.execute_command(*args)

def handle_message(self, response, ignore_subscribe_messages=False):
"""
Expand All @@ -1750,6 +1754,8 @@ def handle_message(self, response, ignore_subscribe_messages=False):
"""
if response is None:
return None
if isinstance(response, bytes):
response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
message_type = str_if_bytes(response[0])
if message_type == "pmessage":
message = {
Expand Down
12 changes: 9 additions & 3 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,13 +406,18 @@ def can_read(self, timeout=0):
self.disconnect()
raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

def read_response(self, disable_decoding=False):
def read_response(self, disable_decoding=False, push_request=False):
"""Read the response from a previously sent command"""

host_error = self._host_error()

try:
response = self._parser.read_response(disable_decoding=disable_decoding)
if self.protocol == "3" and not HIREDIS_AVAILABLE:
response = self._parser.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
response = self._parser.read_response(disable_decoding=disable_decoding)
except socket.timeout:
self.disconnect()
raise TimeoutError(f"Timeout reading from {host_error}")
Expand Down Expand Up @@ -705,8 +710,9 @@ def _connect(self):
class UnixDomainSocketConnection(AbstractConnection):
"Manages UDS communication to and from a Redis server"

def __init__(self, path="", **kwargs):
def __init__(self, path="", socket_timeout=None, **kwargs):
self.path = path
self.socket_timeout = socket_timeout
super().__init__(**kwargs)

def repr_pieces(self):
Expand Down
39 changes: 35 additions & 4 deletions redis/parsers/resp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,36 @@

from ..exceptions import ConnectionError, InvalidResponse, ResponseError
from ..typing import EncodableT
from ..utils import INFO_LOGGER
from .base import _AsyncRESPBase, _RESPBase
from .socket import SERVER_CLOSED_CONNECTION_ERROR


class _RESP3Parser(_RESPBase):
"""RESP3 protocol implementation"""

def read_response(self, disable_decoding=False):
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.push_handler = self.handle_push_response

def handle_push_response(self, response):
INFO_LOGGER.info("Push response: " + str(response))
return response

def read_response(self, disable_decoding=False, push_request=False):
pos = self._buffer.get_pos()
try:
result = self._read_response(disable_decoding=disable_decoding)
result = self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
except BaseException:
self._buffer.rewind(pos)
raise
else:
self._buffer.purge()
return result

def _read_response(self, disable_decoding=False):
def _read_response(self, disable_decoding=False, push_request=False):
raw = self._buffer.readline()
if not raw:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
Expand Down Expand Up @@ -77,16 +88,36 @@ def _read_response(self, disable_decoding=False):
response = {
self._read_response(
disable_decoding=disable_decoding
): self._read_response(disable_decoding=disable_decoding)
): self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
for _ in range(int(response))
}
# push response
elif byte == b">":
response = [
self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
for _ in range(int(response))
]
res = self.push_handler(response)
if not push_request:
return self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return res
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response

def set_push_handler(self, push_handler):
self.push_handler = push_handler


class _AsyncRESP3Parser(_AsyncRESPBase):
async def read_response(self, disable_decoding: bool = False):
Expand Down
10 changes: 10 additions & 0 deletions redis/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from contextlib import contextmanager
from functools import wraps
from typing import Any, Dict, Mapping, Union
Expand Down Expand Up @@ -117,3 +118,12 @@ def wrapper(*args, **kwargs):
return wrapper

return decorator


logger = logging.getLogger("push_response")
chayim marked this conversation as resolved.
Show resolved Hide resolved
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
logger.addHandler(handler)

INFO_LOGGER = logger
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

INFO_LOGGER Should instead be _info_logger, so that we don't get caught in the same issue as the parsers rename

37 changes: 31 additions & 6 deletions tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,14 @@

import redis
from redis.exceptions import ConnectionError
from redis.utils import HIREDIS_AVAILABLE

from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt
from .conftest import (
_get_client,
is_resp2_connection,
skip_if_redis_enterprise,
skip_if_server_version_lt,
)


def wait_for_message(pubsub, timeout=0.5, ignore_subscribe_messages=False):
Expand Down Expand Up @@ -352,6 +358,23 @@ def test_unicode_pattern_message_handler(self, r):
)


class TestPubSubRESP3Handler:
def my_handler(self, message):
self.message = ["my handler", message]

@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
def test_push_handler(self, r):
if is_resp2_connection(r):
return
p = r.pubsub(push_handler=self.my_handler)
p.subscribe("foo")
assert wait_for_message(p) is None
assert self.message == ["my handler", [b"subscribe", b"foo", 1]]
assert r.publish("foo", "test message") == 1
assert wait_for_message(p) is None
assert self.message == ["my handler", [b"message", b"foo", b"test message"]]


class TestPubSubAutoDecoding:
"These tests only validate that we get unicode values back"

Expand Down Expand Up @@ -767,13 +790,15 @@ def get_msg():
assert msg is not None
# timeout waiting for another message which never arrives
assert is_connected()
with patch("redis.parsers._RESP2Parser.read_response") as mock1:
with patch("redis.parsers._RESP2Parser.read_response") as mock1, patch(
"redis.parsers._HiredisParser.read_response"
) as mock2, patch("redis.parsers._RESP3Parser.read_response") as mock3:
mock1.side_effect = BaseException("boom")
with patch("redis.parsers._HiredisParser.read_response") as mock2:
mock2.side_effect = BaseException("boom")
mock2.side_effect = BaseException("boom")
mock3.side_effect = BaseException("boom")

with pytest.raises(BaseException):
get_msg()
with pytest.raises(BaseException):
get_msg()

# the timeout on the read should not cause disconnect
assert is_connected()