Skip to content

Commit

Permalink
tests: add 'connect' tests for all Redis connection classes
Browse files Browse the repository at this point in the history
  • Loading branch information
woutdenolf committed Mar 22, 2023
1 parent 0f0050c commit 1344266
Showing 1 changed file with 166 additions and 0 deletions.
166 changes: 166 additions & 0 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import logging
import re
import socket
import ssl
import threading

import pytest

from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection

from .ssl_certificates import get_ssl_filename

_logger = logging.getLogger(__name__)


_CLIENT_NAME = "test-suite-client"
_CMD_SEP = b"\r\n"
_SUCCESS_RESP = b"+OK" + _CMD_SEP
_ERROR_RESP = b"-ERR" + _CMD_SEP
_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP}


@pytest.fixture
def tcp_address():
with socket.socket() as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()


@pytest.fixture
def uds_address(tmpdir):
return tmpdir / "uds.sock"


def test_tcp_connect(tcp_address):
host, port = tcp_address
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME)
_assert_connect(conn, tcp_address)


def test_uds_connect(uds_address):
path = str(uds_address)
conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME)
_assert_connect(conn, path)


@pytest.mark.ssl
def test_tcp_ssl_connect(tcp_address):
host, port = tcp_address
certfile = get_ssl_filename("server-cert.pem")
keyfile = get_ssl_filename("server-key.pem")
conn = SSLConnection(
host=host, port=port, client_name=_CLIENT_NAME, ssl_ca_certs=certfile
)
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)


def _assert_connect(conn, server_address, certfile=None, keyfile=None):
ready = threading.Event()
stop = threading.Event()
t = threading.Thread(
target=_redis_mock_server,
args=(server_address, ready, stop),
kwargs={"certfile": certfile, "keyfile": keyfile},
)
t.start()
try:
ready.wait()
conn.connect()
conn.disconnect()
finally:
stop.set()
t.join(timeout=5)


def _redis_mock_server(server_address, ready, stop, certfile=None, keyfile=None):
try:
if isinstance(server_address, str):
family = socket.AF_UNIX
mockname = "Redis mock server (UDS)"
elif certfile:
family = socket.AF_INET
mockname = "Redis mock server (TCP-SSL)"
else:
family = socket.AF_INET
mockname = "Redis mock server (TCP)"

with socket.socket(family, socket.SOCK_STREAM) as s:
s.bind(server_address)
s.listen(1)
s.settimeout(0.1)

if certfile:
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.minimum_version = ssl.TLSVersion.TLSv1_2
context.load_cert_chain(certfile=certfile, keyfile=keyfile)

_logger.info("Start %s: %s", mockname, server_address)
ready.set()

# Wait for a client connection
while not stop.is_set():
try:
sconn, _ = s.accept()
sconn.settimeout(0.1)
break
except socket.timeout:
pass
if stop.is_set():
_logger.info("Exit %s: %s", mockname, server_address)
return

# Handle commands from the client
with sconn:
if certfile:
with context.wrap_socket(sconn, server_side=True) as wconn:
_redis_mock_server_handle(wconn, stop, mockname)
else:
_redis_mock_server_handle(sconn, stop, mockname)
_logger.info("Exit %s: %s", mockname, server_address)
except BaseException as e:
_logger.exception("Error in %s: %s", mockname, e)
raise


def _redis_mock_server_handle(conn, stop, mockname):
buffer = b""
command = None
command_ptr = None
fragment_length = None
while not stop.is_set() or buffer:
try:
buffer += conn.recv(1024)
except socket.timeout:
continue
if not buffer:
continue
parts = re.split(_CMD_SEP, buffer)
buffer = parts[-1]
for fragment in parts[:-1]:
fragment = fragment.decode()
_logger.info("Command fragment in %s: %s", mockname, fragment)

if fragment.startswith("*") and command is None:
command = [None for _ in range(int(fragment[1:]))]
command_ptr = 0
fragment_length = None
continue

if fragment.startswith("$") and command[command_ptr] is None:
fragment_length = int(fragment[1:])
continue

assert len(fragment) == fragment_length
command[command_ptr] = fragment
command_ptr += 1

if command_ptr < len(command):
continue

command = " ".join(command)
_logger.info("Command in %s: %s", mockname, command)
resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP)
_logger.info("Response from %s: %s", mockname, resp)
conn.sendall(resp)
command = None

0 comments on commit 1344266

Please sign in to comment.