diff --git a/tests/conftest.py b/tests/conftest.py index 5f4b6dd14d8..e33a99d1d66 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,12 @@ import asyncio -import hashlib -import pathlib -import shutil +import os +import socket import ssl import sys -import tempfile -import uuid +from hashlib import md5, sha256 +from pathlib import Path +from tempfile import TemporaryDirectory +from uuid import uuid4 import pytest @@ -24,15 +25,14 @@ pytest_plugins = ["aiohttp.pytest_plugin", "pytester"] -@pytest.fixture -def shorttmpdir(): - # Provides a temporary directory with a shorter file system path than the - # tmpdir fixture. - tmpdir = pathlib.Path(tempfile.mkdtemp()) - yield tmpdir - # str(tmpdir) is required, Python 3.5 doesn't have __fspath__ - # concept - shutil.rmtree(str(tmpdir), ignore_errors=True) +IS_HPUX = sys.platform.startswith("hp-ux") +"""Specifies whether the current runtime is HP-UX.""" +IS_LINUX = sys.platform.startswith("linux") +"""Specifies whether the current runtime is HP-UX.""" +IS_UNIX = hasattr(socket, "AF_UNIX") +"""Specifies whether the current runtime is *NIX.""" + +needs_unix = pytest.mark.skipif(not IS_UNIX, reason="requires UNIX sockets") @pytest.fixture @@ -85,12 +85,91 @@ def tls_certificate_pem_bytes(tls_certificate): @pytest.fixture def tls_certificate_fingerprint_sha256(tls_certificate_pem_bytes): tls_cert_der = ssl.PEM_cert_to_DER_cert(tls_certificate_pem_bytes.decode()) - return hashlib.sha256(tls_cert_der).digest() + return sha256(tls_cert_der).digest() + + +@pytest.fixture +def unix_sockname(tmp_path, tmp_path_factory): + """Generate an fs path to the UNIX domain socket for testing. + + N.B. Different OS kernels have different fs path length limitations + for it. For Linux, it's 108, for HP-UX it's 92 (or higher) depending + on its version. For for most of the BSDs (Open, Free, macOS) it's + mostly 104 but sometimes it can be down to 100. + + Ref: https://github.com/aio-libs/aiohttp/issues/3572 + """ + if not IS_UNIX: + pytest.skip("requires UNIX sockets") + + max_sock_len = 92 if IS_HPUX else 108 if IS_LINUX else 100 + """Amount of bytes allocated for the UNIX socket path by OS kernel. + + Ref: https://unix.stackexchange.com/a/367012/27133 + """ + + sock_file_name = "unix.sock" + unique_prefix = f"{uuid4()!s}-" + unique_prefix_len = len(unique_prefix.encode()) + + root_tmp_dir = Path("/tmp").resolve() + os_tmp_dir = Path(os.getenv("TMPDIR", "/tmp")).resolve() + original_base_tmp_path = Path( + str(tmp_path_factory.getbasetemp()), + ).resolve() + + original_base_tmp_path_hash = md5( + str(original_base_tmp_path).encode(), + ).hexdigest() + + def make_tmp_dir(base_tmp_dir): + return TemporaryDirectory( + dir=str(base_tmp_dir), + prefix="pt-", + suffix=f"-{original_base_tmp_path_hash!s}", + ) + + def assert_sock_fits(sock_path): + sock_path_len = len(sock_path.encode()) + # exit-check to verify that it's correct and simplify debugging + # in the future + assert sock_path_len <= max_sock_len, ( + "Suggested UNIX socket ({sock_path}) is {sock_path_len} bytes " + "long but the current kernel only has {max_sock_len} bytes " + "allocated to hold it so it must be shorter. " + "See https://github.com/aio-libs/aiohttp/issues/3572 " + "for more info." + ).format_map(locals()) + + paths = original_base_tmp_path, os_tmp_dir, root_tmp_dir + unique_paths = [p for n, p in enumerate(paths) if p not in paths[:n]] + paths_num = len(unique_paths) + + for num, tmp_dir_path in enumerate(paths, 1): + with make_tmp_dir(tmp_dir_path) as tmpd: + tmpd = Path(tmpd).resolve() + sock_path = str(tmpd / sock_file_name) + sock_path_len = len(sock_path.encode()) + + if num >= paths_num: + # exit-check to verify that it's correct and simplify + # debugging in the future + assert_sock_fits(sock_path) + + if sock_path_len <= max_sock_len: + if max_sock_len - sock_path_len >= unique_prefix_len: + # If we're lucky to have extra space in the path, + # let's also make it more unique + sock_path = str(tmpd / "".join((unique_prefix, sock_file_name))) + # Double-checking it: + assert_sock_fits(sock_path) + yield sock_path + return @pytest.fixture def pipe_name(): - name = rf"\\.\pipe\{uuid.uuid4().hex}" + name = rf"\\.\pipe\{uuid4().hex}" return name diff --git a/tests/test_connector.py b/tests/test_connector.py index f2b14648557..130f92fe130 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -12,6 +12,7 @@ from unittest import mock import pytest +from conftest import needs_unix from yarl import URL import aiohttp @@ -43,12 +44,6 @@ def ssl_key(): return ConnectionKey("localhost", 80, True, None, None, None, None) -@pytest.fixture -def unix_sockname(shorttmpdir): - sock_path = shorttmpdir / "socket.sock" - return str(sock_path) - - @pytest.fixture def unix_server(loop, unix_sockname): runners = [] @@ -1918,7 +1913,7 @@ async def handler(request): assert r.status == 200 -@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires unix socket") +@needs_unix async def test_unix_connector_not_found(loop) -> None: connector = aiohttp.UnixConnector("/" + uuid.uuid4().hex, loop=loop) @@ -1927,7 +1922,7 @@ async def test_unix_connector_not_found(loop) -> None: await connector.connect(req, None, ClientTimeout()) -@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires unix socket") +@needs_unix async def test_unix_connector_permission(loop) -> None: loop.create_unix_connection = make_mocked_coro(raise_exception=PermissionError()) connector = aiohttp.UnixConnector("/" + uuid.uuid4().hex, loop=loop) @@ -2086,7 +2081,6 @@ async def handler(request): await conn.close() -@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires UNIX sockets") async def test_unix_connector(unix_server, unix_sockname) -> None: async def handler(request): return web.Response() diff --git a/tests/test_run_app.py b/tests/test_run_app.py index 44b27132605..1a2b9083fed 100644 --- a/tests/test_run_app.py +++ b/tests/test_run_app.py @@ -12,6 +12,7 @@ from uuid import uuid4 import pytest +from conftest import IS_UNIX, needs_unix from aiohttp import web from aiohttp.helpers import PY_37 @@ -19,8 +20,7 @@ from aiohttp.web_runner import BaseRunner # Test for features of OS' socket support -_has_unix_domain_socks = hasattr(socket, "AF_UNIX") -if _has_unix_domain_socks: +if IS_UNIX: _abstract_path_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) try: _abstract_path_sock.bind(b"\x00" + uuid4().hex.encode("ascii")) # type: ignore @@ -37,10 +37,7 @@ skip_if_no_abstract_paths = pytest.mark.skipif( _abstract_path_failed, reason="Linux-style abstract paths are not supported." ) -skip_if_no_unix_socks = pytest.mark.skipif( - not _has_unix_domain_socks, reason="Unix domain sockets are not supported" -) -del _has_unix_domain_socks, _abstract_path_failed +del IS_UNIX, _abstract_path_failed HAS_IPV6 = socket.has_ipv6 if HAS_IPV6: @@ -509,38 +506,38 @@ def test_run_app_custom_backlog_unix(patched_loop) -> None: ) -@skip_if_no_unix_socks -def test_run_app_http_unix_socket(patched_loop, shorttmpdir) -> None: +def test_run_app_http_unix_socket(patched_loop, unix_sockname) -> None: app = web.Application() - sock_path = str(shorttmpdir / "socket.sock") printer = mock.Mock(wraps=stopper(patched_loop)) - web.run_app(app, path=sock_path, print=printer, loop=patched_loop) + web.run_app(app, path=unix_sockname, print=printer, loop=patched_loop) patched_loop.create_unix_server.assert_called_with( - mock.ANY, sock_path, ssl=None, backlog=128 + mock.ANY, unix_sockname, ssl=None, backlog=128 ) - assert f"http://unix:{sock_path}:" in printer.call_args[0][0] + assert f"http://unix:{unix_sockname}:" in printer.call_args[0][0] -@skip_if_no_unix_socks -def test_run_app_https_unix_socket(patched_loop, shorttmpdir) -> None: +def test_run_app_https_unix_socket(patched_loop, unix_sockname) -> None: app = web.Application() - sock_path = str(shorttmpdir / "socket.sock") ssl_context = ssl.create_default_context() printer = mock.Mock(wraps=stopper(patched_loop)) web.run_app( - app, path=sock_path, ssl_context=ssl_context, print=printer, loop=patched_loop + app, + path=unix_sockname, + ssl_context=ssl_context, + print=printer, + loop=patched_loop, ) patched_loop.create_unix_server.assert_called_with( - mock.ANY, sock_path, ssl=ssl_context, backlog=128 + mock.ANY, unix_sockname, ssl=ssl_context, backlog=128 ) - assert f"https://unix:{sock_path}:" in printer.call_args[0][0] + assert f"https://unix:{unix_sockname}:" in printer.call_args[0][0] -@skip_if_no_unix_socks +@needs_unix @skip_if_no_abstract_paths def test_run_app_abstract_linux_socket(patched_loop) -> None: sock_path = b"\x00" + uuid4().hex.encode("ascii") @@ -592,7 +589,7 @@ def test_run_app_preexisting_inet6_socket(patched_loop) -> None: assert f"http://[::]:{port}" in printer.call_args[0][0] -@skip_if_no_unix_socks +@needs_unix def test_run_app_preexisting_unix_socket(patched_loop, mocker) -> None: app = web.Application() diff --git a/tests/test_web_runner.py b/tests/test_web_runner.py index 8c08a5f5fbd..b8a738e8454 100644 --- a/tests/test_web_runner.py +++ b/tests/test_web_runner.py @@ -105,21 +105,17 @@ def test_non_app() -> None: web.AppRunner(object()) -@pytest.mark.skipif( - platform.system() == "Windows", reason="Unix socket support is required" -) -async def test_addresses(make_runner, shorttmpdir) -> None: +async def test_addresses(make_runner, unix_sockname) -> None: _sock = get_unused_port_socket("127.0.0.1") runner = make_runner() await runner.setup() tcp = web.SockSite(runner, _sock) await tcp.start() - path = str(shorttmpdir / "tmp.sock") - unix = web.UnixSite(runner, path) + unix = web.UnixSite(runner, unix_sockname) await unix.start() actual_addrs = runner.addresses expected_host, expected_post = _sock.getsockname()[:2] - assert actual_addrs == [(expected_host, expected_post), path] + assert actual_addrs == [(expected_host, expected_post), unix_sockname] @pytest.mark.skipif(