Skip to content

Commit

Permalink
Tests for pool poisoning (encode#550)
Browse files Browse the repository at this point in the history
  • Loading branch information
igor.stulikov authored and igor.stulikov committed Jan 29, 2023
1 parent a406468 commit 7e288a4
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 14 deletions.
48 changes: 40 additions & 8 deletions httpcore/backends/mock.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import ssl
import time
import typing
from typing import Optional
from typing import Optional, Type

from .._exceptions import ReadError
import anyio

from httpcore import ReadTimeout
from .base import AsyncNetworkBackend, AsyncNetworkStream, NetworkBackend, NetworkStream
from .._exceptions import ReadError


class MockSSLObject:
Expand Down Expand Up @@ -45,10 +49,24 @@ def get_extra_info(self, info: str) -> typing.Any:
return MockSSLObject(http2=self._http2) if info == "ssl_object" else None


class HangingStream(MockStream):
def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
if self._closed:
raise ReadError("Connection closed")
time.sleep(timeout or 0.1)
raise ReadTimeout


class MockBackend(NetworkBackend):
def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
def __init__(
self,
buffer: typing.List[bytes],
http2: bool = False,
resp_stream_cls: Optional[Type[NetworkStream]] = None,
) -> None:
self._buffer = buffer
self._http2 = http2
self._resp_stream_cls: Type[MockStream] = resp_stream_cls or MockStream

def connect_tcp(
self,
Expand All @@ -57,12 +75,12 @@ def connect_tcp(
timeout: Optional[float] = None,
local_address: Optional[str] = None,
) -> NetworkStream:
return MockStream(list(self._buffer), http2=self._http2)
return self._resp_stream_cls(list(self._buffer), http2=self._http2)

def connect_unix_socket(
self, path: str, timeout: Optional[float] = None
) -> NetworkStream:
return MockStream(list(self._buffer), http2=self._http2)
return self._resp_stream_cls(list(self._buffer), http2=self._http2)

def sleep(self, seconds: float) -> None:
pass
Expand Down Expand Up @@ -99,10 +117,24 @@ def get_extra_info(self, info: str) -> typing.Any:
return MockSSLObject(http2=self._http2) if info == "ssl_object" else None


class AsyncHangingStream(AsyncMockStream):
async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
if self._closed:
raise ReadError("Connection closed")
await anyio.sleep(timeout or 0.1)
raise ReadTimeout


class AsyncMockBackend(AsyncNetworkBackend):
def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
def __init__(
self,
buffer: typing.List[bytes],
http2: bool = False,
resp_stream_cls: Optional[Type[AsyncNetworkStream]] = None,
) -> None:
self._buffer = buffer
self._http2 = http2
self._resp_stream_cls: Type[AsyncMockStream] = resp_stream_cls or AsyncMockStream

async def connect_tcp(
self,
Expand All @@ -111,12 +143,12 @@ async def connect_tcp(
timeout: Optional[float] = None,
local_address: Optional[str] = None,
) -> AsyncNetworkStream:
return AsyncMockStream(list(self._buffer), http2=self._http2)
return self._resp_stream_cls(list(self._buffer), http2=self._http2)

async def connect_unix_socket(
self, path: str, timeout: Optional[float] = None
) -> AsyncNetworkStream:
return AsyncMockStream(list(self._buffer), http2=self._http2)
return self._resp_stream_cls(list(self._buffer), http2=self._http2)

async def sleep(self, seconds: float) -> None:
pass
81 changes: 79 additions & 2 deletions tests/_async/test_connection_pool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Optional
import contextlib
from typing import List, Optional, Type

import pytest
import trio as concurrency
Expand All @@ -7,11 +8,12 @@
AsyncConnectionPool,
ConnectError,
PoolTimeout,
ReadTimeout,
ReadError,
UnsupportedProtocol,
)
from httpcore.backends.base import AsyncNetworkStream
from httpcore.backends.mock import AsyncMockBackend
from httpcore.backends.mock import AsyncMockBackend, AsyncHangingStream


@pytest.mark.anyio
Expand Down Expand Up @@ -502,6 +504,43 @@ async def test_connection_pool_timeout():
await pool.request("GET", "https://example.com/", extensions=extensions)


@pytest.mark.trio
async def test_pool_under_load():
"""
Pool must remain operational after some peak load.
"""
network_backend = AsyncMockBackend([], resp_stream_cls=AsyncHangingStream)

async def fetch(_pool: AsyncConnectionPool, *exceptions: Type[BaseException]):
with contextlib.suppress(*exceptions):
async with pool.stream(
"GET",
"http://a.com/",
extensions={
"timeout": {
"connect": 0.1,
"read": 0.1,
"pool": 0.1,
"write": 0.1,
},
},
) as response:
await response.aread()

async with AsyncConnectionPool(
max_connections=1, network_backend=network_backend
) as pool:
async with concurrency.open_nursery() as nursery:
for _ in range(300):
# Sending many requests to the same url. All of them but one will have PoolTimeout. One will
# be finished with ReadTimeout
nursery.start_soon(fetch, pool, PoolTimeout, ReadTimeout)
if pool.connections: # There is one connection in pool in "CONNECTING" state
assert pool.connections[0].is_connecting()
with pytest.raises(ReadTimeout): # ReadTimeout indicates that connection could be retrieved
await fetch(pool)


@pytest.mark.anyio
async def test_http11_upgrade_connection():
"""
Expand Down Expand Up @@ -534,3 +573,41 @@ async def test_http11_upgrade_connection():
network_stream = response.extensions["network_stream"]
content = await network_stream.read(max_bytes=1024)
assert content == b"..."


@pytest.mark.trio
async def test_pool_timeout_connection_cleanup():
"""
Test that pool cleans up connections after zero pool timeout. In case of stale
connection after timeout pool must not hang.
"""
network_backend = AsyncMockBackend(
[
b"HTTP/1.1 200 OK\r\n",
b"Content-Type: plain/text\r\n",
b"Content-Length: 13\r\n",
b"\r\n",
b"Hello, world!",
] * 2,
)

async with AsyncConnectionPool(
network_backend=network_backend, max_connections=1
) as pool:
timeout = {
"connect": 0.1,
"read": 0.1,
"pool": 0,
"write": 0.1,
}
with contextlib.suppress(PoolTimeout):
await pool.request("GET", "https://example.com/", extensions={"timeout": timeout})

# wait for a considerable amount of time to make sure all requests time out
await concurrency.sleep(0.1)

await pool.request("GET", "https://example.com/", extensions={"timeout": {**timeout, 'pool': 0.1}})

if pool.connections:
for conn in pool.connections:
assert not conn.is_connecting()
85 changes: 81 additions & 4 deletions tests/_sync/test_connection_pool.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from typing import List, Optional
import contextlib
import time
from typing import List, Optional, Type

import pytest
from tests import concurrency

from httpcore import (
ConnectionPool,
ConnectError,
PoolTimeout,
ReadTimeout,
ReadError,
UnsupportedProtocol,
)
from httpcore.backends.base import NetworkStream
from httpcore.backends.mock import MockBackend

from httpcore.backends.mock import MockBackend, HangingStream
from tests import concurrency


def test_connection_pool_with_keepalive():
Expand Down Expand Up @@ -503,6 +505,81 @@ def test_connection_pool_timeout():



def test_pool_under_load():
"""
Pool must remain operational after some peak load.
"""
network_backend = MockBackend([], resp_stream_cls=HangingStream)

def fetch(_pool: ConnectionPool, *exceptions: Type[BaseException]):
with contextlib.suppress(*exceptions):
with pool.stream(
"GET",
"http://a.com/",
extensions={
"timeout": {
"connect": 0.1,
"read": 0.1,
"pool": 0.1,
"write": 0.1,
},
},
) as response:
response.read()

with ConnectionPool(
max_connections=1, network_backend=network_backend
) as pool:
with concurrency.open_nursery() as nursery:
for _ in range(300):
# Sending many requests to the same url. All of them but one will have PoolTimeout. One will
# be finished with ReadTimeout
nursery.start_soon(fetch, pool, PoolTimeout, ReadTimeout)
if pool.connections: # There is one connection in pool in "CONNECTING" state
assert pool.connections[0].is_connecting()
with pytest.raises(ReadTimeout): # ReadTimeout indicates that connection could be retrieved
fetch(pool)



def test_pool_timeout_connection_cleanup():
"""
Test that pool cleans up connections after zero pool timeout. In case of stale
connection after timeout pool must not hang.
"""
network_backend = MockBackend(
[
b"HTTP/1.1 200 OK\r\n",
b"Content-Type: plain/text\r\n",
b"Content-Length: 13\r\n",
b"\r\n",
b"Hello, world!",
] * 2,
)

with ConnectionPool(
network_backend=network_backend, max_connections=2
) as pool:
timeout = {
"connect": 0.1,
"read": 0.1,
"pool": 0,
"write": 0.1,
}
with contextlib.suppress(PoolTimeout):
pool.request("GET", "https://example.com/", extensions={"timeout": timeout})

# wait for a considerable amount of time to make sure all requests time out
time.sleep(0.1)

pool.request("GET", "https://example.com/", extensions={"timeout": {**timeout, 'pool': 0.1}})

if pool.connections:
for conn in pool.connections:
assert not conn.is_connecting()



def test_http11_upgrade_connection():
"""
HTTP "101 Switching Protocols" indicates an upgraded connection.
Expand Down

0 comments on commit 7e288a4

Please sign in to comment.