From 56447f4076d2fe31cf081dcebbea3c3d43270e36 Mon Sep 17 00:00:00 2001 From: Markus Sintonen Date: Mon, 17 Jun 2024 12:34:48 +0300 Subject: [PATCH] Add integration testing for UDS case and nagle --- httpcore/_backends/anyio.py | 2 +- httpcore/_backends/asyncio.py | 2 +- httpcore/_backends/trio.py | 26 ++++++++++++---------- tests/_async/test_integration.py | 37 ++++++++++++++++++++++++++++++-- tests/_sync/test_integration.py | 37 ++++++++++++++++++++++++++++++-- tests/conftest.py | 25 +++++++++++++++------ 6 files changed, 106 insertions(+), 23 deletions(-) diff --git a/httpcore/_backends/anyio.py b/httpcore/_backends/anyio.py index 9c9d6f10..995a3d94 100644 --- a/httpcore/_backends/anyio.py +++ b/httpcore/_backends/anyio.py @@ -127,7 +127,7 @@ async def connect_unix_socket( path: str, timeout: typing.Optional[float] = None, socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, - ) -> AsyncNetworkStream: # pragma: nocover + ) -> AsyncNetworkStream: exc_map = { TimeoutError: ConnectTimeout, OSError: ConnectError, diff --git a/httpcore/_backends/asyncio.py b/httpcore/_backends/asyncio.py index 5154b17f..312fb648 100644 --- a/httpcore/_backends/asyncio.py +++ b/httpcore/_backends/asyncio.py @@ -194,7 +194,7 @@ async def connect_unix_socket( path: str, timeout: Optional[float] = None, socket_options: Optional[Iterable[SOCKET_OPTION]] = None, - ) -> AsyncNetworkStream: # pragma: nocover + ) -> AsyncNetworkStream: exc_map: Dict[Type[Exception], Type[Exception]] = { asyncio.TimeoutError: ConnectTimeout, OSError: ConnectError, diff --git a/httpcore/_backends/trio.py b/httpcore/_backends/trio.py index b1626d28..26320e61 100644 --- a/httpcore/_backends/trio.py +++ b/httpcore/_backends/trio.py @@ -117,10 +117,6 @@ async def connect_tcp( local_address: typing.Optional[str] = None, socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, ) -> AsyncNetworkStream: - # By default for TCP sockets, trio enables TCP_NODELAY. - # https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream - if socket_options is None: - socket_options = [] # pragma: no cover timeout_or_inf = float("inf") if timeout is None else timeout exc_map: ExceptionMapping = { trio.TooSlowError: ConnectTimeout, @@ -132,8 +128,7 @@ async def connect_tcp( stream: trio.abc.Stream = await trio.open_tcp_stream( host=host, port=port, local_address=local_address ) - for option in socket_options: - stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + self._set_socket_options(stream, socket_options) return TrioStream(stream) async def connect_unix_socket( @@ -141,9 +136,7 @@ async def connect_unix_socket( path: str, timeout: typing.Optional[float] = None, socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, - ) -> AsyncNetworkStream: # pragma: nocover - if socket_options is None: - socket_options = [] + ) -> AsyncNetworkStream: timeout_or_inf = float("inf") if timeout is None else timeout exc_map: ExceptionMapping = { trio.TooSlowError: ConnectTimeout, @@ -153,9 +146,20 @@ async def connect_unix_socket( with map_exceptions(exc_map): with trio.fail_after(timeout_or_inf): stream: trio.abc.Stream = await trio.open_unix_socket(path) - for option in socket_options: - stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + self._set_socket_options(stream, socket_options) return TrioStream(stream) async def sleep(self, seconds: float) -> None: await trio.sleep(seconds) # pragma: nocover + + def _set_socket_options( + self, + stream: trio.abc.Stream, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> None: + # By default for TCP sockets, trio enables TCP_NODELAY. + # https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream + if not socket_options: + return + for option in socket_options: + stream.setsockopt(*option) # type: ignore[attr-defined] diff --git a/tests/_async/test_integration.py b/tests/_async/test_integration.py index 14c03333..797933e4 100644 --- a/tests/_async/test_integration.py +++ b/tests/_async/test_integration.py @@ -1,5 +1,7 @@ +import os import socket import ssl +from tempfile import gettempdir import pytest import uvicorn @@ -70,6 +72,18 @@ async def test_socket_options( assert bool(opt) is keep_alive_enabled +@pytest.mark.anyio +async def test_socket_no_nagle(server: Server, server_url: str) -> None: + async with httpcore.AsyncConnectionPool() as pool: + response = await pool.request("GET", server_url) + assert response.status == 200 + + stream = response.extensions["network_stream"] + sock = stream.get_extra_info("socket") + opt = sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + assert bool(opt) is True + + @pytest.mark.anyio async def test_pool_recovers_from_connection_breakage( server_config: uvicorn.Config, server_url: str @@ -77,7 +91,7 @@ async def test_pool_recovers_from_connection_breakage( async with httpcore.AsyncConnectionPool( max_connections=1, max_keepalive_connections=1, keepalive_expiry=10 ) as pool: - with Server(config=server_config).run_in_thread(): + with Server(server_config).run_in_thread(): response = await pool.request("GET", server_url) assert response.status == 200 @@ -91,7 +105,7 @@ async def test_pool_recovers_from_connection_breakage( stream.get_extra_info("is_readable") is True ), "Should break by coming readable" - with Server(config=server_config).run_in_thread(): + with Server(server_config).run_in_thread(): assert len(pool.connections) == 1 assert pool.connections[0] is conn, "Should be the broken connection" @@ -100,3 +114,22 @@ async def test_pool_recovers_from_connection_breakage( assert len(pool.connections) == 1 assert pool.connections[0] is not conn, "Should be a new connection" + + +@pytest.mark.anyio +async def test_unix_domain_socket(server_port, server_config, server_url): + uds = f"{gettempdir()}/test_httpcore_app.sock" + if os.path.exists(uds): + os.remove(uds) # pragma: nocover + + uds_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + uds_sock.bind(uds) + + with Server(server_config).run_in_thread(sockets=[uds_sock]): + async with httpcore.AsyncConnectionPool(uds=uds) as pool: + response = await pool.request("GET", server_url) + assert response.status == 200 + finally: + uds_sock.close() + os.remove(uds) diff --git a/tests/_sync/test_integration.py b/tests/_sync/test_integration.py index 080c35ca..d114f878 100644 --- a/tests/_sync/test_integration.py +++ b/tests/_sync/test_integration.py @@ -1,5 +1,7 @@ +import os import socket import ssl +from tempfile import gettempdir import pytest import uvicorn @@ -71,13 +73,25 @@ def test_socket_options( +def test_socket_no_nagle(server: Server, server_url: str) -> None: + with httpcore.ConnectionPool() as pool: + response = pool.request("GET", server_url) + assert response.status == 200 + + stream = response.extensions["network_stream"] + sock = stream.get_extra_info("socket") + opt = sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + assert bool(opt) is True + + + def test_pool_recovers_from_connection_breakage( server_config: uvicorn.Config, server_url: str ) -> None: with httpcore.ConnectionPool( max_connections=1, max_keepalive_connections=1, keepalive_expiry=10 ) as pool: - with Server(config=server_config).run_in_thread(): + with Server(server_config).run_in_thread(): response = pool.request("GET", server_url) assert response.status == 200 @@ -91,7 +105,7 @@ def test_pool_recovers_from_connection_breakage( stream.get_extra_info("is_readable") is True ), "Should break by coming readable" - with Server(config=server_config).run_in_thread(): + with Server(server_config).run_in_thread(): assert len(pool.connections) == 1 assert pool.connections[0] is conn, "Should be the broken connection" @@ -100,3 +114,22 @@ def test_pool_recovers_from_connection_breakage( assert len(pool.connections) == 1 assert pool.connections[0] is not conn, "Should be a new connection" + + + +def test_unix_domain_socket(server_port, server_config, server_url): + uds = f"{gettempdir()}/test_httpcore_app.sock" + if os.path.exists(uds): + os.remove(uds) # pragma: nocover + + uds_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + uds_sock.bind(uds) + + with Server(server_config).run_in_thread(sockets=[uds_sock]): + with httpcore.ConnectionPool(uds=uds) as pool: + response = pool.request("GET", server_url) + assert response.status == 200 + finally: + uds_sock.close() + os.remove(uds) diff --git a/tests/conftest.py b/tests/conftest.py index abc541a8..5aa2ed50 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,8 @@ +import socket import time from contextlib import contextmanager from threading import Thread -from typing import Generator, Iterator +from typing import Any, Awaitable, Callable, Generator, Iterator, List, Optional import pytest import uvicorn @@ -31,12 +32,17 @@ def anyio_backend(request, monkeypatch): class Server(uvicorn.Server): @contextmanager - def run_in_thread(self) -> Generator[None, None, None]: - thread = Thread(target=self.run) + def run_in_thread( + self, sockets: Optional[List[socket.socket]] = None + ) -> Generator[None, None, None]: + thread = Thread(target=lambda: self.run(sockets)) thread.start() + start_time = time.monotonic() try: while not self.started: time.sleep(0.01) + if (time.monotonic() - start_time) > 5: + raise TimeoutError() # pragma: nocover yield finally: self.should_exit = True @@ -54,7 +60,7 @@ def server_url(server_port: int) -> str: @pytest.fixture -def server_config(server_port: int) -> uvicorn.Config: +def server_app() -> Callable[[Any, Any, Any], Awaitable[None]]: async def app(scope, receive, send): assert scope["type"] == "http" assert not (await receive()).get("more_body", False) @@ -68,10 +74,17 @@ async def app(scope, receive, send): await send(start) await send(body) - return uvicorn.Config(app, port=server_port, log_level="error") + return app + + +@pytest.fixture +def server_config( + server_port: int, server_app: Callable[[Any, Any, Any], Awaitable[None]] +) -> uvicorn.Config: + return uvicorn.Config(server_app, port=server_port, log_level="error") @pytest.fixture def server(server_config: uvicorn.Config) -> Iterator[None]: - with Server(config=server_config).run_in_thread(): + with Server(server_config).run_in_thread(): yield