diff --git a/httpcore/backends/mock.py b/httpcore/backends/mock.py index 9aba0ead..871a2521 100644 --- a/httpcore/backends/mock.py +++ b/httpcore/backends/mock.py @@ -1,9 +1,13 @@ import ssl +import time import typing -from typing import Optional +from typing import Optional, Type -from .._exceptions import ReadError +import anyio + +from httpcore import ReadTimeout from .base import AsyncNetworkBackend, AsyncNetworkStream, NetworkBackend, NetworkStream +from .._exceptions import ReadError class MockSSLObject: @@ -45,10 +49,24 @@ def get_extra_info(self, info: str) -> typing.Any: return MockSSLObject(http2=self._http2) if info == "ssl_object" else None +class HangingStream(MockStream): + def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes: + if self._closed: + raise ReadError("Connection closed") + time.sleep(timeout or 0.1) + raise ReadTimeout + + class MockBackend(NetworkBackend): - def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None: + def __init__( + self, + buffer: typing.List[bytes], + http2: bool = False, + resp_stream_cls: Optional[Type[NetworkStream]] = None, + ) -> None: self._buffer = buffer self._http2 = http2 + self._resp_stream_cls: Type[MockStream] = resp_stream_cls or MockStream def connect_tcp( self, @@ -57,12 +75,12 @@ def connect_tcp( timeout: Optional[float] = None, local_address: Optional[str] = None, ) -> NetworkStream: - return MockStream(list(self._buffer), http2=self._http2) + return self._resp_stream_cls(list(self._buffer), http2=self._http2) def connect_unix_socket( self, path: str, timeout: Optional[float] = None ) -> NetworkStream: - return MockStream(list(self._buffer), http2=self._http2) + return self._resp_stream_cls(list(self._buffer), http2=self._http2) def sleep(self, seconds: float) -> None: pass @@ -99,10 +117,24 @@ def get_extra_info(self, info: str) -> typing.Any: return MockSSLObject(http2=self._http2) if info == "ssl_object" else None +class AsyncHangingStream(AsyncMockStream): + async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes: + if self._closed: + raise ReadError("Connection closed") + await anyio.sleep(timeout or 0.1) + raise ReadTimeout + + class AsyncMockBackend(AsyncNetworkBackend): - def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None: + def __init__( + self, + buffer: typing.List[bytes], + http2: bool = False, + resp_stream_cls: Optional[Type[AsyncNetworkStream]] = None, + ) -> None: self._buffer = buffer self._http2 = http2 + self._resp_stream_cls: Type[AsyncMockStream] = resp_stream_cls or AsyncMockStream async def connect_tcp( self, @@ -111,12 +143,12 @@ async def connect_tcp( timeout: Optional[float] = None, local_address: Optional[str] = None, ) -> AsyncNetworkStream: - return AsyncMockStream(list(self._buffer), http2=self._http2) + return self._resp_stream_cls(list(self._buffer), http2=self._http2) async def connect_unix_socket( self, path: str, timeout: Optional[float] = None ) -> AsyncNetworkStream: - return AsyncMockStream(list(self._buffer), http2=self._http2) + return self._resp_stream_cls(list(self._buffer), http2=self._http2) async def sleep(self, seconds: float) -> None: pass diff --git a/tests/_async/test_connection_pool.py b/tests/_async/test_connection_pool.py index d2ac58a4..6c45b020 100644 --- a/tests/_async/test_connection_pool.py +++ b/tests/_async/test_connection_pool.py @@ -1,4 +1,5 @@ -from typing import List, Optional +import contextlib +from typing import List, Optional, Type import pytest import trio as concurrency @@ -7,11 +8,12 @@ AsyncConnectionPool, ConnectError, PoolTimeout, + ReadTimeout, ReadError, UnsupportedProtocol, ) from httpcore.backends.base import AsyncNetworkStream -from httpcore.backends.mock import AsyncMockBackend +from httpcore.backends.mock import AsyncMockBackend, AsyncHangingStream @pytest.mark.anyio @@ -502,6 +504,82 @@ async def test_connection_pool_timeout(): await pool.request("GET", "https://example.com/", extensions=extensions) +@pytest.mark.trio +async def test_pool_under_load(): + """ + Pool must remain operational after some peak load. + """ + network_backend = AsyncMockBackend([], resp_stream_cls=AsyncHangingStream) + + async def fetch(_pool: AsyncConnectionPool, *exceptions: Type[BaseException]): + with contextlib.suppress(*exceptions): + async with pool.stream( + "GET", + "http://a.com/", + extensions={ + "timeout": { + "connect": 0.1, + "read": 0.1, + "pool": 0.1, + "write": 0.1, + }, + }, + ) as response: + await response.aread() + + async with AsyncConnectionPool( + max_connections=1, network_backend=network_backend + ) as pool: + async with concurrency.open_nursery() as nursery: + for _ in range(300): + # Sending many requests to the same url. All of them but one will have PoolTimeout. One will + # be finished with ReadTimeout + nursery.start_soon(fetch, pool, PoolTimeout, ReadTimeout) + if pool.connections: # There is one connection in pool in "CONNECTING" state + assert pool.connections[0].is_connecting() + with pytest.raises(ReadTimeout): # ReadTimeout indicates that connection could be retrieved + await fetch(pool) + + + +@pytest.mark.trio +async def test_pool_timeout_connection_cleanup(): + """ + Test that pool cleans up connections after zero pool timeout. In case of stale + connection after timeout pool must not hang. + """ + network_backend = AsyncMockBackend( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] * 2, + ) + + async with AsyncConnectionPool( + network_backend=network_backend, max_connections=1 + ) as pool: + timeout = { + "connect": 0.1, + "read": 0.1, + "pool": 0, + "write": 0.1, + } + with contextlib.suppress(PoolTimeout): + await pool.request("GET", "https://example.com/", extensions={"timeout": timeout}) + + # wait for a considerable amount of time to make sure all requests time out + await concurrency.sleep(0.1) + + await pool.request("GET", "https://example.com/", extensions={"timeout": {**timeout, 'pool': 0.1}}) + + if pool.connections: + for conn in pool.connections: + assert not conn.is_connecting() + + @pytest.mark.anyio async def test_http11_upgrade_connection(): """ diff --git a/tests/_sync/test_connection_pool.py b/tests/_sync/test_connection_pool.py index 453b7fdc..1384bef0 100644 --- a/tests/_sync/test_connection_pool.py +++ b/tests/_sync/test_connection_pool.py @@ -1,18 +1,20 @@ -from typing import List, Optional +import contextlib +import time +from typing import List, Optional, Type import pytest -from tests import concurrency from httpcore import ( ConnectionPool, ConnectError, PoolTimeout, + ReadTimeout, ReadError, UnsupportedProtocol, ) from httpcore.backends.base import NetworkStream -from httpcore.backends.mock import MockBackend - +from httpcore.backends.mock import MockBackend, HangingStream +from tests import concurrency def test_connection_pool_with_keepalive(): @@ -503,6 +505,81 @@ def test_connection_pool_timeout(): +def test_pool_under_load(): + """ + Pool must remain operational after some peak load. + """ + network_backend = MockBackend([], resp_stream_cls=HangingStream) + + def fetch(_pool: ConnectionPool, *exceptions: Type[BaseException]): + with contextlib.suppress(*exceptions): + with pool.stream( + "GET", + "http://a.com/", + extensions={ + "timeout": { + "connect": 0.1, + "read": 0.1, + "pool": 0.1, + "write": 0.1, + }, + }, + ) as response: + response.read() + + with ConnectionPool( + max_connections=1, network_backend=network_backend + ) as pool: + with concurrency.open_nursery() as nursery: + for _ in range(300): + # Sending many requests to the same url. All of them but one will have PoolTimeout. One will + # be finished with ReadTimeout + nursery.start_soon(fetch, pool, PoolTimeout, ReadTimeout) + if pool.connections: # There is one connection in pool in "CONNECTING" state + assert pool.connections[0].is_connecting() + with pytest.raises(ReadTimeout): # ReadTimeout indicates that connection could be retrieved + fetch(pool) + + + +def test_pool_timeout_connection_cleanup(): + """ + Test that pool cleans up connections after zero pool timeout. In case of stale + connection after timeout pool must not hang. + """ + network_backend = MockBackend( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] * 2, + ) + + with ConnectionPool( + network_backend=network_backend, max_connections=2 + ) as pool: + timeout = { + "connect": 0.1, + "read": 0.1, + "pool": 0, + "write": 0.1, + } + with contextlib.suppress(PoolTimeout): + pool.request("GET", "https://example.com/", extensions={"timeout": timeout}) + + # wait for a considerable amount of time to make sure all requests time out + time.sleep(0.1) + + pool.request("GET", "https://example.com/", extensions={"timeout": {**timeout, 'pool': 0.1}}) + + if pool.connections: + for conn in pool.connections: + assert not conn.is_connecting() + + + def test_http11_upgrade_connection(): """ HTTP "101 Switching Protocols" indicates an upgraded connection.