-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add 'connect' tests for all Redis connection classes
- Loading branch information
1 parent
b167df0
commit f11ddae
Showing
1 changed file
with
176 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
import logging | ||
import re | ||
import socket | ||
import ssl | ||
import threading | ||
|
||
import pytest | ||
|
||
from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection | ||
|
||
from .ssl_certificates import get_ssl_certificate | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
|
||
_CLIENT_NAME = "test-suite-client" | ||
_CMD_SEP = b"\r\n" | ||
_SUCCESS_RESP = b"+OK" + _CMD_SEP | ||
_ERROR_RESP = b"-ERR" + _CMD_SEP | ||
_COMMANDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} | ||
|
||
|
||
@pytest.fixture | ||
def tcp_address(): | ||
with socket.socket() as sock: | ||
sock.bind(("127.0.0.1", 0)) | ||
return sock.getsockname() | ||
|
||
|
||
@pytest.fixture | ||
def uds_address(tmpdir): | ||
return tmpdir / "uds.sock" | ||
|
||
|
||
def test_tcp_connect(tcp_address): | ||
host, port = tcp_address | ||
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME) | ||
_assert_connect(conn, tcp_address) | ||
|
||
|
||
def test_uds_connect(uds_address): | ||
path = str(uds_address) | ||
conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME) | ||
_assert_connect(conn, path) | ||
|
||
|
||
@pytest.mark.ssl | ||
def test_tcp_ssl_connect(tcp_address): | ||
host, port = tcp_address | ||
certfile = get_ssl_certificate("server-cert.pem") | ||
keyfile = get_ssl_certificate("server-key.pem") | ||
conn = SSLConnection( | ||
host=host, | ||
port=port, | ||
client_name=_CLIENT_NAME, | ||
ssl_certfile=certfile, | ||
ssl_keyfile=keyfile, | ||
ssl_ca_certs=certfile, | ||
) | ||
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) | ||
|
||
|
||
def _assert_connect(conn, server_address, certfile=None, keyfile=None): | ||
ready = threading.Event() | ||
stop = threading.Event() | ||
t = threading.Thread( | ||
target=_redis_mock_server, | ||
args=(server_address, ready, stop), | ||
kwargs={"certfile": certfile, "keyfile": keyfile}, | ||
) | ||
t.start() | ||
try: | ||
ready.wait() | ||
conn.connect() | ||
conn.disconnect() | ||
finally: | ||
stop.set() | ||
t.join(timeout=5) | ||
|
||
|
||
def _redis_mock_server(server_address, ready, stop, certfile=None, keyfile=None): | ||
try: | ||
if isinstance(server_address, str): | ||
family = socket.AF_UNIX | ||
mockname = "Redis mock server (UDS)" | ||
elif certfile: | ||
family = socket.AF_INET | ||
mockname = "Redis mock server (TCP-SSL)" | ||
else: | ||
family = socket.AF_INET | ||
mockname = "Redis mock server (TCP)" | ||
|
||
with socket.socket(family, socket.SOCK_STREAM) as s: | ||
s.bind(server_address) | ||
s.listen(1) | ||
s.settimeout(0.1) | ||
|
||
if certfile: | ||
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) | ||
context.minimum_version = ssl.TLSVersion.TLSv1_2 | ||
context.load_cert_chain(certfile=certfile, keyfile=keyfile) | ||
|
||
_logger.info("Start %s: %s", mockname, server_address) | ||
ready.set() | ||
|
||
# Wait a client connection | ||
while not stop.is_set(): | ||
try: | ||
sconn, _ = s.accept() | ||
sconn.settimeout(0.1) | ||
break | ||
except socket.timeout: | ||
pass | ||
if stop.is_set(): | ||
_logger.info("Exit %s: %s", mockname, server_address) | ||
return | ||
|
||
# Receive commands from the client | ||
with sconn: | ||
if certfile: | ||
conn = context.wrap_socket(sconn, server_side=True) | ||
else: | ||
conn = sconn | ||
try: | ||
buffer = b"" | ||
command = None | ||
command_ptr = None | ||
fragment_length = None | ||
while not stop.is_set() or buffer: | ||
try: | ||
buffer += conn.recv(1024) | ||
except socket.timeout: | ||
continue | ||
if not buffer: | ||
continue | ||
parts = re.split(_CMD_SEP, buffer) | ||
buffer = parts[-1] | ||
for fragment in parts[:-1]: | ||
fragment = fragment.decode() | ||
_logger.info( | ||
"Command fragment in %s: %s", mockname, fragment | ||
) | ||
|
||
if fragment.startswith("*") and command is None: | ||
command = [None for _ in range(int(fragment[1:]))] | ||
command_ptr = 0 | ||
fragment_length = None | ||
continue | ||
|
||
if ( | ||
fragment.startswith("$") | ||
and command[command_ptr] is None | ||
): | ||
fragment_length = int(fragment[1:]) | ||
continue | ||
|
||
assert len(fragment) == fragment_length | ||
command[command_ptr] = fragment | ||
command_ptr += 1 | ||
|
||
if command_ptr < len(command): | ||
continue | ||
|
||
command = " ".join(command) | ||
_logger.info("Command in %s: %s", mockname, command) | ||
resp = _COMMANDS.get(command, _ERROR_RESP) | ||
_logger.info("Response from %s: %s", mockname, resp) | ||
conn.sendall(resp) | ||
command = None | ||
finally: | ||
if certfile: | ||
conn.close() | ||
_logger.info("Exit %s: %s", mockname, server_address) | ||
except BaseException as e: | ||
_logger.exception("Error in %s: %s", mockname, e) | ||
raise |