Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit tests for the connect method of all Redis connection classes #2631

Merged
merged 3 commits into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tests/ssl_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import os


def get_ssl_filename(name):
root = os.path.join(os.path.dirname(__file__), "..")
cert_dir = os.path.abspath(os.path.join(root, "docker", "stunnel", "keys"))
if not os.path.isdir(cert_dir): # github actions package validation case
cert_dir = os.path.abspath(
os.path.join(root, "..", "docker", "stunnel", "keys")
)
if not os.path.isdir(cert_dir):
raise IOError(f"No SSL certificates found. They should be in {cert_dir}")

return os.path.join(cert_dir, name)
15 changes: 3 additions & 12 deletions tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import binascii
import datetime
import os
import warnings
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union
from urllib.parse import urlparse
Expand Down Expand Up @@ -36,6 +35,7 @@
skip_unless_arch_bits,
)

from ..ssl_utils import get_ssl_filename
from .compat import mock

pytestmark = pytest.mark.onlycluster
Expand Down Expand Up @@ -2744,17 +2744,8 @@ class TestSSL:
appropriate port.
"""

ROOT = os.path.join(os.path.dirname(__file__), "../..")
CERT_DIR = os.path.abspath(os.path.join(ROOT, "docker", "stunnel", "keys"))
if not os.path.isdir(CERT_DIR): # github actions package validation case
CERT_DIR = os.path.abspath(
os.path.join(ROOT, "..", "docker", "stunnel", "keys")
)
if not os.path.isdir(CERT_DIR):
raise IOError(f"No SSL certificates found. They should be in {CERT_DIR}")

SERVER_CERT = os.path.join(CERT_DIR, "server-cert.pem")
SERVER_KEY = os.path.join(CERT_DIR, "server-key.pem")
SERVER_CERT = get_ssl_filename("server-cert.pem")
SERVER_KEY = get_ssl_filename("server-key.pem")

@pytest_asyncio.fixture()
def create_client(self, request: FixtureRequest) -> Callable[..., RedisCluster]:
Expand Down
145 changes: 145 additions & 0 deletions tests/test_asyncio/test_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import asyncio
import logging
import re
import socket
import ssl

import pytest

from redis.asyncio.connection import (
Connection,
SSLConnection,
UnixDomainSocketConnection,
)

from ..ssl_utils import get_ssl_filename

_logger = logging.getLogger(__name__)


_CLIENT_NAME = "test-suite-client"
_CMD_SEP = b"\r\n"
_SUCCESS_RESP = b"+OK" + _CMD_SEP
_ERROR_RESP = b"-ERR" + _CMD_SEP
_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP}


@pytest.fixture
def tcp_address():
with socket.socket() as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()


@pytest.fixture
def uds_address(tmpdir):
return tmpdir / "uds.sock"


async def test_tcp_connect(tcp_address):
host, port = tcp_address
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10)
await _assert_connect(conn, tcp_address)


async def test_uds_connect(uds_address):
path = str(uds_address)
conn = UnixDomainSocketConnection(
path=path, client_name=_CLIENT_NAME, socket_timeout=10
)
await _assert_connect(conn, path)


@pytest.mark.ssl
async def test_tcp_ssl_connect(tcp_address):
host, port = tcp_address
certfile = get_ssl_filename("server-cert.pem")
keyfile = get_ssl_filename("server-key.pem")
conn = SSLConnection(
host=host,
port=port,
client_name=_CLIENT_NAME,
ssl_ca_certs=certfile,
socket_timeout=10,
)
await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)


async def _assert_connect(conn, server_address, certfile=None, keyfile=None):
stop_event = asyncio.Event()
finished = asyncio.Event()

async def _handler(reader, writer):
try:
return await _redis_request_handler(reader, writer, stop_event)
finally:
finished.set()

if isinstance(server_address, str):
server = await asyncio.start_unix_server(_handler, path=server_address)
elif certfile:
host, port = server_address
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.minimum_version = ssl.TLSVersion.TLSv1_2
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
server = await asyncio.start_server(_handler, host=host, port=port, ssl=context)
else:
host, port = server_address
server = await asyncio.start_server(_handler, host=host, port=port)

async with server as aserver:
await aserver.start_serving()
try:
await conn.connect()
await conn.disconnect()
finally:
stop_event.set()
aserver.close()
await aserver.wait_closed()
await finished.wait()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There must be a better way to make sure any running handlers are finished after closing the server.



async def _redis_request_handler(reader, writer, stop_event):
buffer = b""
command = None
command_ptr = None
fragment_length = None
while not stop_event.is_set() or buffer:
_logger.info(str(stop_event.is_set()))
try:
buffer += await asyncio.wait_for(reader.read(1024), timeout=0.5)
except TimeoutError:
continue
if not buffer:
continue
parts = re.split(_CMD_SEP, buffer)
buffer = parts[-1]
for fragment in parts[:-1]:
fragment = fragment.decode()
_logger.info("Command fragment: %s", fragment)

if fragment.startswith("*") and command is None:
command = [None for _ in range(int(fragment[1:]))]
command_ptr = 0
fragment_length = None
continue

if fragment.startswith("$") and command[command_ptr] is None:
fragment_length = int(fragment[1:])
continue

assert len(fragment) == fragment_length
command[command_ptr] = fragment
command_ptr += 1

if command_ptr < len(command):
continue

command = " ".join(command)
_logger.info("Command %s", command)
resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP)
_logger.info("Response from %s", resp)
writer.write(resp)
await writer.drain()
command = None
_logger.info("Exit handler")
185 changes: 185 additions & 0 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import logging
import re
import socket
import socketserver
import ssl
import threading

import pytest

from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection

from .ssl_utils import get_ssl_filename

_logger = logging.getLogger(__name__)


_CLIENT_NAME = "test-suite-client"
_CMD_SEP = b"\r\n"
_SUCCESS_RESP = b"+OK" + _CMD_SEP
_ERROR_RESP = b"-ERR" + _CMD_SEP
_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP}


@pytest.fixture
def tcp_address():
with socket.socket() as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()


@pytest.fixture
def uds_address(tmpdir):
return tmpdir / "uds.sock"


def test_tcp_connect(tcp_address):
host, port = tcp_address
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10)
_assert_connect(conn, tcp_address)


def test_uds_connect(uds_address):
path = str(uds_address)
conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME, socket_timeout=10)
_assert_connect(conn, path)


@pytest.mark.ssl
def test_tcp_ssl_connect(tcp_address):
host, port = tcp_address
certfile = get_ssl_filename("server-cert.pem")
keyfile = get_ssl_filename("server-key.pem")
conn = SSLConnection(
host=host,
port=port,
client_name=_CLIENT_NAME,
ssl_ca_certs=certfile,
socket_timeout=10,
)
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)


def _assert_connect(conn, server_address, certfile=None, keyfile=None):
if isinstance(server_address, str):
server = _RedisUDSServer(server_address, _RedisRequestHandler)
else:
server = _RedisTCPServer(
server_address, _RedisRequestHandler, certfile=certfile, keyfile=keyfile
)
with server as aserver:
t = threading.Thread(target=aserver.serve_forever)
t.start()
try:
aserver.wait_online()
conn.connect()
conn.disconnect()
finally:
aserver.stop()
t.join(timeout=5)


class _RedisTCPServer(socketserver.TCPServer):
def __init__(self, *args, certfile=None, keyfile=None, **kw) -> None:
self._ready_event = threading.Event()
self._stop_requested = False
self._certfile = certfile
self._keyfile = keyfile
super().__init__(*args, **kw)

def service_actions(self):
self._ready_event.set()

def wait_online(self):
self._ready_event.wait()

def stop(self):
self._stop_requested = True
self.shutdown()

def is_serving(self):
return not self._stop_requested

def get_request(self):
if self._certfile is None:
return super().get_request()
newsocket, fromaddr = self.socket.accept()
connstream = ssl.wrap_socket(
newsocket,
server_side=True,
certfile=self._certfile,
keyfile=self._keyfile,
ssl_version=ssl.PROTOCOL_TLSv1_2,
)
return connstream, fromaddr


class _RedisUDSServer(socketserver.UnixStreamServer):
def __init__(self, *args, **kw) -> None:
self._ready_event = threading.Event()
self._stop_requested = False
super().__init__(*args, **kw)

def service_actions(self):
self._ready_event.set()

def wait_online(self):
self._ready_event.wait()

def stop(self):
self._stop_requested = True
self.shutdown()

def is_serving(self):
return not self._stop_requested


class _RedisRequestHandler(socketserver.StreamRequestHandler):
def setup(self):
_logger.info("%s connected", self.client_address)

def finish(self):
_logger.info("%s disconnected", self.client_address)

def handle(self):
buffer = b""
command = None
command_ptr = None
fragment_length = None
while self.server.is_serving() or buffer:
try:
buffer += self.request.recv(1024)
except socket.timeout:
continue
if not buffer:
continue
parts = re.split(_CMD_SEP, buffer)
buffer = parts[-1]
for fragment in parts[:-1]:
fragment = fragment.decode()
_logger.info("Command fragment: %s", fragment)

if fragment.startswith("*") and command is None:
command = [None for _ in range(int(fragment[1:]))]
command_ptr = 0
fragment_length = None
continue

if fragment.startswith("$") and command[command_ptr] is None:
fragment_length = int(fragment[1:])
continue

assert len(fragment) == fragment_length
command[command_ptr] = fragment
command_ptr += 1

if command_ptr < len(command):
continue

command = " ".join(command)
_logger.info("Command %s", command)
resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP)
_logger.info("Response %s", resp)
self.request.sendall(resp)
command = None
_logger.info("Exit handler")
Loading