Skip to content

Commit

Permalink
Add integration testing for UDS case and nagle
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusSintonen committed Jun 17, 2024
1 parent aa31c2a commit 56447f4
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 23 deletions.
2 changes: 1 addition & 1 deletion httpcore/_backends/anyio.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def connect_unix_socket(
path: str,
timeout: typing.Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream: # pragma: nocover
) -> AsyncNetworkStream:
exc_map = {
TimeoutError: ConnectTimeout,
OSError: ConnectError,
Expand Down
2 changes: 1 addition & 1 deletion httpcore/_backends/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ async def connect_unix_socket(
path: str,
timeout: Optional[float] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream: # pragma: nocover
) -> AsyncNetworkStream:
exc_map: Dict[Type[Exception], Type[Exception]] = {
asyncio.TimeoutError: ConnectTimeout,
OSError: ConnectError,
Expand Down
26 changes: 15 additions & 11 deletions httpcore/_backends/trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,6 @@ async def connect_tcp(
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream:
# By default for TCP sockets, trio enables TCP_NODELAY.
# https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream
if socket_options is None:
socket_options = [] # pragma: no cover
timeout_or_inf = float("inf") if timeout is None else timeout
exc_map: ExceptionMapping = {
trio.TooSlowError: ConnectTimeout,
Expand All @@ -132,18 +128,15 @@ async def connect_tcp(
stream: trio.abc.Stream = await trio.open_tcp_stream(
host=host, port=port, local_address=local_address
)
for option in socket_options:
stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
self._set_socket_options(stream, socket_options)
return TrioStream(stream)

async def connect_unix_socket(
self,
path: str,
timeout: typing.Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream: # pragma: nocover
if socket_options is None:
socket_options = []
) -> AsyncNetworkStream:
timeout_or_inf = float("inf") if timeout is None else timeout
exc_map: ExceptionMapping = {
trio.TooSlowError: ConnectTimeout,
Expand All @@ -153,9 +146,20 @@ async def connect_unix_socket(
with map_exceptions(exc_map):
with trio.fail_after(timeout_or_inf):
stream: trio.abc.Stream = await trio.open_unix_socket(path)
for option in socket_options:
stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
self._set_socket_options(stream, socket_options)
return TrioStream(stream)

async def sleep(self, seconds: float) -> None:
await trio.sleep(seconds) # pragma: nocover

def _set_socket_options(
self,
stream: trio.abc.Stream,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> None:
# By default for TCP sockets, trio enables TCP_NODELAY.
# https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream
if not socket_options:
return
for option in socket_options:
stream.setsockopt(*option) # type: ignore[attr-defined]
37 changes: 35 additions & 2 deletions tests/_async/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import socket
import ssl
from tempfile import gettempdir

import pytest
import uvicorn
Expand Down Expand Up @@ -70,14 +72,26 @@ async def test_socket_options(
assert bool(opt) is keep_alive_enabled


@pytest.mark.anyio
async def test_socket_no_nagle(server: Server, server_url: str) -> None:
async with httpcore.AsyncConnectionPool() as pool:
response = await pool.request("GET", server_url)
assert response.status == 200

stream = response.extensions["network_stream"]
sock = stream.get_extra_info("socket")
opt = sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
assert bool(opt) is True


@pytest.mark.anyio
async def test_pool_recovers_from_connection_breakage(
server_config: uvicorn.Config, server_url: str
) -> None:
async with httpcore.AsyncConnectionPool(
max_connections=1, max_keepalive_connections=1, keepalive_expiry=10
) as pool:
with Server(config=server_config).run_in_thread():
with Server(server_config).run_in_thread():
response = await pool.request("GET", server_url)
assert response.status == 200

Expand All @@ -91,7 +105,7 @@ async def test_pool_recovers_from_connection_breakage(
stream.get_extra_info("is_readable") is True
), "Should break by coming readable"

with Server(config=server_config).run_in_thread():
with Server(server_config).run_in_thread():
assert len(pool.connections) == 1
assert pool.connections[0] is conn, "Should be the broken connection"

Expand All @@ -100,3 +114,22 @@ async def test_pool_recovers_from_connection_breakage(

assert len(pool.connections) == 1
assert pool.connections[0] is not conn, "Should be a new connection"


@pytest.mark.anyio
async def test_unix_domain_socket(server_port, server_config, server_url):
uds = f"{gettempdir()}/test_httpcore_app.sock"
if os.path.exists(uds):
os.remove(uds) # pragma: nocover

uds_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
uds_sock.bind(uds)

with Server(server_config).run_in_thread(sockets=[uds_sock]):
async with httpcore.AsyncConnectionPool(uds=uds) as pool:
response = await pool.request("GET", server_url)
assert response.status == 200
finally:
uds_sock.close()
os.remove(uds)
37 changes: 35 additions & 2 deletions tests/_sync/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import socket
import ssl
from tempfile import gettempdir

import pytest
import uvicorn
Expand Down Expand Up @@ -71,13 +73,25 @@ def test_socket_options(



def test_socket_no_nagle(server: Server, server_url: str) -> None:
with httpcore.ConnectionPool() as pool:
response = pool.request("GET", server_url)
assert response.status == 200

stream = response.extensions["network_stream"]
sock = stream.get_extra_info("socket")
opt = sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
assert bool(opt) is True



def test_pool_recovers_from_connection_breakage(
server_config: uvicorn.Config, server_url: str
) -> None:
with httpcore.ConnectionPool(
max_connections=1, max_keepalive_connections=1, keepalive_expiry=10
) as pool:
with Server(config=server_config).run_in_thread():
with Server(server_config).run_in_thread():
response = pool.request("GET", server_url)
assert response.status == 200

Expand All @@ -91,7 +105,7 @@ def test_pool_recovers_from_connection_breakage(
stream.get_extra_info("is_readable") is True
), "Should break by coming readable"

with Server(config=server_config).run_in_thread():
with Server(server_config).run_in_thread():
assert len(pool.connections) == 1
assert pool.connections[0] is conn, "Should be the broken connection"

Expand All @@ -100,3 +114,22 @@ def test_pool_recovers_from_connection_breakage(

assert len(pool.connections) == 1
assert pool.connections[0] is not conn, "Should be a new connection"



def test_unix_domain_socket(server_port, server_config, server_url):
uds = f"{gettempdir()}/test_httpcore_app.sock"
if os.path.exists(uds):
os.remove(uds) # pragma: nocover

uds_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
uds_sock.bind(uds)

with Server(server_config).run_in_thread(sockets=[uds_sock]):
with httpcore.ConnectionPool(uds=uds) as pool:
response = pool.request("GET", server_url)
assert response.status == 200
finally:
uds_sock.close()
os.remove(uds)
25 changes: 19 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import socket
import time
from contextlib import contextmanager
from threading import Thread
from typing import Generator, Iterator
from typing import Any, Awaitable, Callable, Generator, Iterator, List, Optional

import pytest
import uvicorn
Expand Down Expand Up @@ -31,12 +32,17 @@ def anyio_backend(request, monkeypatch):

class Server(uvicorn.Server):
@contextmanager
def run_in_thread(self) -> Generator[None, None, None]:
thread = Thread(target=self.run)
def run_in_thread(
self, sockets: Optional[List[socket.socket]] = None
) -> Generator[None, None, None]:
thread = Thread(target=lambda: self.run(sockets))
thread.start()
start_time = time.monotonic()
try:
while not self.started:
time.sleep(0.01)
if (time.monotonic() - start_time) > 5:
raise TimeoutError() # pragma: nocover
yield
finally:
self.should_exit = True
Expand All @@ -54,7 +60,7 @@ def server_url(server_port: int) -> str:


@pytest.fixture
def server_config(server_port: int) -> uvicorn.Config:
def server_app() -> Callable[[Any, Any, Any], Awaitable[None]]:
async def app(scope, receive, send):
assert scope["type"] == "http"
assert not (await receive()).get("more_body", False)
Expand All @@ -68,10 +74,17 @@ async def app(scope, receive, send):
await send(start)
await send(body)

return uvicorn.Config(app, port=server_port, log_level="error")
return app


@pytest.fixture
def server_config(
server_port: int, server_app: Callable[[Any, Any, Any], Awaitable[None]]
) -> uvicorn.Config:
return uvicorn.Config(server_app, port=server_port, log_level="error")


@pytest.fixture
def server(server_config: uvicorn.Config) -> Iterator[None]:
with Server(config=server_config).run_in_thread():
with Server(server_config).run_in_thread():
yield

0 comments on commit 56447f4

Please sign in to comment.