diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index 38f4a292..5d6ee440 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -1,6 +1,5 @@ import enum import logging -import random import ssl import time from types import TracebackType @@ -57,12 +56,10 @@ def __init__( origin: Origin, stream: AsyncNetworkStream, keepalive_expiry: Optional[float] = None, - socket_poll_interval_between: Tuple[float, float] = (1, 3), ) -> None: self._origin = origin self._network_stream = stream self._keepalive_expiry = keepalive_expiry - self._socket_poll_interval_between = socket_poll_interval_between self._expire_at: Optional[float] = None self._state = HTTPConnectionState.NEW self._state_lock = AsyncLock() @@ -82,6 +79,16 @@ async def handle_async_request(self, request: Request) -> Response: ) async with self._state_lock: + # If the HTTP connection is idle but the socket is readable, then the + # only valid state is that the socket is about to return b"", indicating + # a server-initiated disconnect. + server_disconnected = ( + self._state == HTTPConnectionState.IDLE + and self._network_stream.get_extra_info("is_readable") + ) + if server_disconnected: + raise ConnectionNotAvailable() + if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE): self._request_count += 1 self._state = HTTPConnectionState.ACTIVE @@ -287,29 +294,7 @@ def is_available(self) -> bool: def has_expired(self) -> bool: now = time.monotonic() - keepalive_expired = self._expire_at is not None and now > self._expire_at - if keepalive_expired: - return True - - # If the HTTP connection is idle but the socket is readable, then the - # only valid state is that the socket is about to return b"", indicating - # a server-initiated disconnect. - # Checking the readable status is relatively expensive so check it at a lower frequency. - if (now - self._network_stream_used_at) > self._socket_poll_interval(): - self._network_stream_used_at = now - server_disconnected = ( - self._state == HTTPConnectionState.IDLE - and self._network_stream.get_extra_info("is_readable") - ) - if server_disconnected: - return True - - return False - - def _socket_poll_interval(self) -> float: - # Randomize to avoid polling for all the connections at once - low, high = self._socket_poll_interval_between - return random.uniform(low, high) + return self._expire_at is not None and now > self._expire_at def is_idle(self) -> bool: return self._state == HTTPConnectionState.IDLE diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index eecfd33c..33342fb6 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -1,6 +1,5 @@ import enum import logging -import random import ssl import time from types import TracebackType @@ -57,12 +56,10 @@ def __init__( origin: Origin, stream: NetworkStream, keepalive_expiry: Optional[float] = None, - socket_poll_interval_between: Tuple[float, float] = (1, 3), ) -> None: self._origin = origin self._network_stream = stream self._keepalive_expiry = keepalive_expiry - self._socket_poll_interval_between = socket_poll_interval_between self._expire_at: Optional[float] = None self._state = HTTPConnectionState.NEW self._state_lock = Lock() @@ -82,6 +79,16 @@ def handle_request(self, request: Request) -> Response: ) with self._state_lock: + # If the HTTP connection is idle but the socket is readable, then the + # only valid state is that the socket is about to return b"", indicating + # a server-initiated disconnect. + server_disconnected = ( + self._state == HTTPConnectionState.IDLE + and self._network_stream.get_extra_info("is_readable") + ) + if server_disconnected: + raise ConnectionNotAvailable() + if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE): self._request_count += 1 self._state = HTTPConnectionState.ACTIVE @@ -287,29 +294,7 @@ def is_available(self) -> bool: def has_expired(self) -> bool: now = time.monotonic() - keepalive_expired = self._expire_at is not None and now > self._expire_at - if keepalive_expired: - return True - - # If the HTTP connection is idle but the socket is readable, then the - # only valid state is that the socket is about to return b"", indicating - # a server-initiated disconnect. - # Checking the readable status is relatively expensive so check it at a lower frequency. - if (now - self._network_stream_used_at) > self._socket_poll_interval(): - self._network_stream_used_at = now - server_disconnected = ( - self._state == HTTPConnectionState.IDLE - and self._network_stream.get_extra_info("is_readable") - ) - if server_disconnected: - return True - - return False - - def _socket_poll_interval(self) -> float: - # Randomize to avoid polling for all the connections at once - low, high = self._socket_poll_interval_between - return random.uniform(low, high) + return self._expire_at is not None and now > self._expire_at def is_idle(self) -> bool: return self._state == HTTPConnectionState.IDLE diff --git a/tests/_async/test_http11.py b/tests/_async/test_http11.py index cb275e1a..94f2febf 100644 --- a/tests/_async/test_http11.py +++ b/tests/_async/test_http11.py @@ -1,5 +1,3 @@ -from typing import Any, List - import pytest import httpcore @@ -18,10 +16,7 @@ async def test_http11_connection(): ] ) async with httpcore.AsyncHTTP11Connection( - origin=origin, - stream=stream, - keepalive_expiry=5.0, - socket_poll_interval_between=(0, 0), + origin=origin, stream=stream, keepalive_expiry=5.0 ) as conn: response = await conn.request("GET", "https://example.com/") assert response.status == 200 @@ -53,9 +48,7 @@ async def test_http11_connection_unread_response(): b"Hello, world!", ] ) - async with httpcore.AsyncHTTP11Connection( - origin=origin, stream=stream, socket_poll_interval_between=(0, 0) - ) as conn: + async with httpcore.AsyncHTTP11Connection(origin=origin, stream=stream) as conn: async with conn.stream("GET", "https://example.com/") as response: assert response.status == 200 @@ -77,9 +70,7 @@ async def test_http11_connection_with_remote_protocol_error(): """ origin = httpcore.Origin(b"https", b"example.com", 443) stream = httpcore.AsyncMockStream([b"Wait, this isn't valid HTTP!", b""]) - async with httpcore.AsyncHTTP11Connection( - origin=origin, stream=stream, socket_poll_interval_between=(0, 0) - ) as conn: + async with httpcore.AsyncHTTP11Connection(origin=origin, stream=stream) as conn: with pytest.raises(httpcore.RemoteProtocolError): await conn.request("GET", "https://example.com/") @@ -108,9 +99,7 @@ async def test_http11_connection_with_incomplete_response(): b"Hello, wor", ] ) - async with httpcore.AsyncHTTP11Connection( - origin=origin, stream=stream, socket_poll_interval_between=(0, 0) - ) as conn: + async with httpcore.AsyncHTTP11Connection(origin=origin, stream=stream) as conn: with pytest.raises(httpcore.RemoteProtocolError): await conn.request("GET", "https://example.com/") @@ -140,9 +129,7 @@ async def test_http11_connection_with_local_protocol_error(): b"Hello, world!", ] ) - async with httpcore.AsyncHTTP11Connection( - origin=origin, stream=stream, socket_poll_interval_between=(0, 0) - ) as conn: + async with httpcore.AsyncHTTP11Connection(origin=origin, stream=stream) as conn: with pytest.raises(httpcore.LocalProtocolError) as exc_info: await conn.request("GET", "https://example.com/", headers={"Host": "\0"}) @@ -158,85 +145,6 @@ async def test_http11_connection_with_local_protocol_error(): ) -@pytest.mark.anyio -async def test_http11_has_expired_checks_readable_status(): - class AsyncMockStreamReadable(httpcore.AsyncMockStream): - def __init__(self, buffer: List[bytes]) -> None: - super().__init__(buffer) - self.is_readable = False - self.checks = 0 - - def get_extra_info(self, info: str) -> Any: - if info == "is_readable": - self.checks += 1 - return self.is_readable - return super().get_extra_info(info) # pragma: nocover - - origin = httpcore.Origin(b"https", b"example.com", 443) - stream = AsyncMockStreamReadable( - [ - b"HTTP/1.1 200 OK\r\n", - b"Content-Type: plain/text\r\n", - b"Content-Length: 13\r\n", - b"\r\n", - b"Hello, world!", - ] - ) - async with httpcore.AsyncHTTP11Connection( - origin=origin, stream=stream, socket_poll_interval_between=(0, 0) - ) as conn: - response = await conn.request("GET", "https://example.com/") - assert response.status == 200 - - assert stream.checks == 0 - assert not conn.has_expired() - stream.is_readable = True - assert conn.has_expired() - assert stream.checks == 2 - - -@pytest.mark.anyio -@pytest.mark.parametrize("should_check", [True, False]) -async def test_http11_has_expired_checks_readable_status_by_interval( - monkeypatch, should_check -): - origin = httpcore.Origin(b"https", b"example.com", 443) - stream = httpcore.AsyncMockStream( - [ - b"HTTP/1.1 200 OK\r\n", - b"Content-Type: plain/text\r\n", - b"Content-Length: 13\r\n", - b"\r\n", - b"Hello, world!", - ] - ) - async with httpcore.AsyncHTTP11Connection( - origin=origin, - stream=stream, - keepalive_expiry=5.0, - socket_poll_interval_between=(0, 0) if should_check else (999, 999), - ) as conn: - orig = conn._network_stream.get_extra_info - calls = [] - - def patch_get_extra_info(attr_name: str) -> Any: - calls.append(attr_name) - return orig(attr_name) - - monkeypatch.setattr( - conn._network_stream, "get_extra_info", patch_get_extra_info - ) - - response = await conn.request("GET", "https://example.com/") - assert response.status == 200 - - assert "is_readable" not in calls - assert not conn.has_expired() - assert ( - ("is_readable" in calls) if should_check else ("is_readable" not in calls) - ) - - @pytest.mark.anyio async def test_http11_connection_handles_one_active_request(): """ diff --git a/tests/_sync/test_http11.py b/tests/_sync/test_http11.py index a870865b..f2fa28f4 100644 --- a/tests/_sync/test_http11.py +++ b/tests/_sync/test_http11.py @@ -1,5 +1,3 @@ -from typing import Any, List - import pytest import httpcore @@ -18,10 +16,7 @@ def test_http11_connection(): ] ) with httpcore.HTTP11Connection( - origin=origin, - stream=stream, - keepalive_expiry=5.0, - socket_poll_interval_between=(0, 0), + origin=origin, stream=stream, keepalive_expiry=5.0 ) as conn: response = conn.request("GET", "https://example.com/") assert response.status == 200 @@ -53,9 +48,7 @@ def test_http11_connection_unread_response(): b"Hello, world!", ] ) - with httpcore.HTTP11Connection( - origin=origin, stream=stream, socket_poll_interval_between=(0, 0) - ) as conn: + with httpcore.HTTP11Connection(origin=origin, stream=stream) as conn: with conn.stream("GET", "https://example.com/") as response: assert response.status == 200 @@ -77,9 +70,7 @@ def test_http11_connection_with_remote_protocol_error(): """ origin = httpcore.Origin(b"https", b"example.com", 443) stream = httpcore.MockStream([b"Wait, this isn't valid HTTP!", b""]) - with httpcore.HTTP11Connection( - origin=origin, stream=stream, socket_poll_interval_between=(0, 0) - ) as conn: + with httpcore.HTTP11Connection(origin=origin, stream=stream) as conn: with pytest.raises(httpcore.RemoteProtocolError): conn.request("GET", "https://example.com/") @@ -108,9 +99,7 @@ def test_http11_connection_with_incomplete_response(): b"Hello, wor", ] ) - with httpcore.HTTP11Connection( - origin=origin, stream=stream, socket_poll_interval_between=(0, 0) - ) as conn: + with httpcore.HTTP11Connection(origin=origin, stream=stream) as conn: with pytest.raises(httpcore.RemoteProtocolError): conn.request("GET", "https://example.com/") @@ -140,9 +129,7 @@ def test_http11_connection_with_local_protocol_error(): b"Hello, world!", ] ) - with httpcore.HTTP11Connection( - origin=origin, stream=stream, socket_poll_interval_between=(0, 0) - ) as conn: + with httpcore.HTTP11Connection(origin=origin, stream=stream) as conn: with pytest.raises(httpcore.LocalProtocolError) as exc_info: conn.request("GET", "https://example.com/", headers={"Host": "\0"}) @@ -159,85 +146,6 @@ def test_http11_connection_with_local_protocol_error(): -def test_http11_has_expired_checks_readable_status(): - class MockStreamReadable(httpcore.MockStream): - def __init__(self, buffer: List[bytes]) -> None: - super().__init__(buffer) - self.is_readable = False - self.checks = 0 - - def get_extra_info(self, info: str) -> Any: - if info == "is_readable": - self.checks += 1 - return self.is_readable - return super().get_extra_info(info) # pragma: nocover - - origin = httpcore.Origin(b"https", b"example.com", 443) - stream = MockStreamReadable( - [ - b"HTTP/1.1 200 OK\r\n", - b"Content-Type: plain/text\r\n", - b"Content-Length: 13\r\n", - b"\r\n", - b"Hello, world!", - ] - ) - with httpcore.HTTP11Connection( - origin=origin, stream=stream, socket_poll_interval_between=(0, 0) - ) as conn: - response = conn.request("GET", "https://example.com/") - assert response.status == 200 - - assert stream.checks == 0 - assert not conn.has_expired() - stream.is_readable = True - assert conn.has_expired() - assert stream.checks == 2 - - - -@pytest.mark.parametrize("should_check", [True, False]) -def test_http11_has_expired_checks_readable_status_by_interval( - monkeypatch, should_check -): - origin = httpcore.Origin(b"https", b"example.com", 443) - stream = httpcore.MockStream( - [ - b"HTTP/1.1 200 OK\r\n", - b"Content-Type: plain/text\r\n", - b"Content-Length: 13\r\n", - b"\r\n", - b"Hello, world!", - ] - ) - with httpcore.HTTP11Connection( - origin=origin, - stream=stream, - keepalive_expiry=5.0, - socket_poll_interval_between=(0, 0) if should_check else (999, 999), - ) as conn: - orig = conn._network_stream.get_extra_info - calls = [] - - def patch_get_extra_info(attr_name: str) -> Any: - calls.append(attr_name) - return orig(attr_name) - - monkeypatch.setattr( - conn._network_stream, "get_extra_info", patch_get_extra_info - ) - - response = conn.request("GET", "https://example.com/") - assert response.status == 200 - - assert "is_readable" not in calls - assert not conn.has_expired() - assert ( - ("is_readable" in calls) if should_check else ("is_readable" not in calls) - ) - - - def test_http11_connection_handles_one_active_request(): """ Attempting to send a request while one is already in-flight will raise