-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
tests: add 'connect' tests for all Redis connection classes
- Loading branch information
1 parent
0f0050c
commit 1344266
Showing
1 changed file
with
166 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
import logging | ||
import re | ||
import socket | ||
import ssl | ||
import threading | ||
|
||
import pytest | ||
|
||
from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection | ||
|
||
from .ssl_certificates import get_ssl_filename | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
|
||
_CLIENT_NAME = "test-suite-client" | ||
_CMD_SEP = b"\r\n" | ||
_SUCCESS_RESP = b"+OK" + _CMD_SEP | ||
_ERROR_RESP = b"-ERR" + _CMD_SEP | ||
_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} | ||
|
||
|
||
@pytest.fixture | ||
def tcp_address(): | ||
with socket.socket() as sock: | ||
sock.bind(("127.0.0.1", 0)) | ||
return sock.getsockname() | ||
|
||
|
||
@pytest.fixture | ||
def uds_address(tmpdir): | ||
return tmpdir / "uds.sock" | ||
|
||
|
||
def test_tcp_connect(tcp_address): | ||
host, port = tcp_address | ||
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME) | ||
_assert_connect(conn, tcp_address) | ||
|
||
|
||
def test_uds_connect(uds_address): | ||
path = str(uds_address) | ||
conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME) | ||
_assert_connect(conn, path) | ||
|
||
|
||
@pytest.mark.ssl | ||
def test_tcp_ssl_connect(tcp_address): | ||
host, port = tcp_address | ||
certfile = get_ssl_filename("server-cert.pem") | ||
keyfile = get_ssl_filename("server-key.pem") | ||
conn = SSLConnection( | ||
host=host, port=port, client_name=_CLIENT_NAME, ssl_ca_certs=certfile | ||
) | ||
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) | ||
|
||
|
||
def _assert_connect(conn, server_address, certfile=None, keyfile=None): | ||
ready = threading.Event() | ||
stop = threading.Event() | ||
t = threading.Thread( | ||
target=_redis_mock_server, | ||
args=(server_address, ready, stop), | ||
kwargs={"certfile": certfile, "keyfile": keyfile}, | ||
) | ||
t.start() | ||
try: | ||
ready.wait() | ||
conn.connect() | ||
conn.disconnect() | ||
finally: | ||
stop.set() | ||
t.join(timeout=5) | ||
|
||
|
||
def _redis_mock_server(server_address, ready, stop, certfile=None, keyfile=None): | ||
try: | ||
if isinstance(server_address, str): | ||
family = socket.AF_UNIX | ||
mockname = "Redis mock server (UDS)" | ||
elif certfile: | ||
family = socket.AF_INET | ||
mockname = "Redis mock server (TCP-SSL)" | ||
else: | ||
family = socket.AF_INET | ||
mockname = "Redis mock server (TCP)" | ||
|
||
with socket.socket(family, socket.SOCK_STREAM) as s: | ||
s.bind(server_address) | ||
s.listen(1) | ||
s.settimeout(0.1) | ||
|
||
if certfile: | ||
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) | ||
context.minimum_version = ssl.TLSVersion.TLSv1_2 | ||
context.load_cert_chain(certfile=certfile, keyfile=keyfile) | ||
|
||
_logger.info("Start %s: %s", mockname, server_address) | ||
ready.set() | ||
|
||
# Wait for a client connection | ||
while not stop.is_set(): | ||
try: | ||
sconn, _ = s.accept() | ||
sconn.settimeout(0.1) | ||
break | ||
except socket.timeout: | ||
pass | ||
if stop.is_set(): | ||
_logger.info("Exit %s: %s", mockname, server_address) | ||
return | ||
|
||
# Handle commands from the client | ||
with sconn: | ||
if certfile: | ||
with context.wrap_socket(sconn, server_side=True) as wconn: | ||
_redis_mock_server_handle(wconn, stop, mockname) | ||
else: | ||
_redis_mock_server_handle(sconn, stop, mockname) | ||
_logger.info("Exit %s: %s", mockname, server_address) | ||
except BaseException as e: | ||
_logger.exception("Error in %s: %s", mockname, e) | ||
raise | ||
|
||
|
||
def _redis_mock_server_handle(conn, stop, mockname): | ||
buffer = b"" | ||
command = None | ||
command_ptr = None | ||
fragment_length = None | ||
while not stop.is_set() or buffer: | ||
try: | ||
buffer += conn.recv(1024) | ||
except socket.timeout: | ||
continue | ||
if not buffer: | ||
continue | ||
parts = re.split(_CMD_SEP, buffer) | ||
buffer = parts[-1] | ||
for fragment in parts[:-1]: | ||
fragment = fragment.decode() | ||
_logger.info("Command fragment in %s: %s", mockname, fragment) | ||
|
||
if fragment.startswith("*") and command is None: | ||
command = [None for _ in range(int(fragment[1:]))] | ||
command_ptr = 0 | ||
fragment_length = None | ||
continue | ||
|
||
if fragment.startswith("$") and command[command_ptr] is None: | ||
fragment_length = int(fragment[1:]) | ||
continue | ||
|
||
assert len(fragment) == fragment_length | ||
command[command_ptr] = fragment | ||
command_ptr += 1 | ||
|
||
if command_ptr < len(command): | ||
continue | ||
|
||
command = " ".join(command) | ||
_logger.info("Command in %s: %s", mockname, command) | ||
resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) | ||
_logger.info("Response from %s: %s", mockname, resp) | ||
conn.sendall(resp) | ||
command = None |