Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusSintonen committed Jun 15, 2024
1 parent 4254b13 commit 8b99090
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 246 deletions.
37 changes: 11 additions & 26 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import enum
import logging
import random
import ssl
import time
from types import TracebackType
Expand Down Expand Up @@ -57,12 +56,10 @@ def __init__(
origin: Origin,
stream: AsyncNetworkStream,
keepalive_expiry: Optional[float] = None,
socket_poll_interval_between: Tuple[float, float] = (1, 3),
) -> None:
self._origin = origin
self._network_stream = stream
self._keepalive_expiry = keepalive_expiry
self._socket_poll_interval_between = socket_poll_interval_between
self._expire_at: Optional[float] = None
self._state = HTTPConnectionState.NEW
self._state_lock = AsyncLock()
Expand All @@ -82,6 +79,16 @@ async def handle_async_request(self, request: Request) -> Response:
)

async with self._state_lock:
# If the HTTP connection is idle but the socket is readable, then the
# only valid state is that the socket is about to return b"", indicating
# a server-initiated disconnect.
server_disconnected = (
self._state == HTTPConnectionState.IDLE
and self._network_stream.get_extra_info("is_readable")
)
if server_disconnected:
raise ConnectionNotAvailable()

if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE):
self._request_count += 1
self._state = HTTPConnectionState.ACTIVE
Expand Down Expand Up @@ -287,29 +294,7 @@ def is_available(self) -> bool:

def has_expired(self) -> bool:
now = time.monotonic()
keepalive_expired = self._expire_at is not None and now > self._expire_at
if keepalive_expired:
return True

# If the HTTP connection is idle but the socket is readable, then the
# only valid state is that the socket is about to return b"", indicating
# a server-initiated disconnect.
# Checking the readable status is relatively expensive so check it at a lower frequency.
if (now - self._network_stream_used_at) > self._socket_poll_interval():
self._network_stream_used_at = now
server_disconnected = (
self._state == HTTPConnectionState.IDLE
and self._network_stream.get_extra_info("is_readable")
)
if server_disconnected:
return True

return False

def _socket_poll_interval(self) -> float:
# Randomize to avoid polling for all the connections at once
low, high = self._socket_poll_interval_between
return random.uniform(low, high)
return self._expire_at is not None and now > self._expire_at

def is_idle(self) -> bool:
return self._state == HTTPConnectionState.IDLE
Expand Down
37 changes: 11 additions & 26 deletions httpcore/_sync/http11.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import enum
import logging
import random
import ssl
import time
from types import TracebackType
Expand Down Expand Up @@ -57,12 +56,10 @@ def __init__(
origin: Origin,
stream: NetworkStream,
keepalive_expiry: Optional[float] = None,
socket_poll_interval_between: Tuple[float, float] = (1, 3),
) -> None:
self._origin = origin
self._network_stream = stream
self._keepalive_expiry = keepalive_expiry
self._socket_poll_interval_between = socket_poll_interval_between
self._expire_at: Optional[float] = None
self._state = HTTPConnectionState.NEW
self._state_lock = Lock()
Expand All @@ -82,6 +79,16 @@ def handle_request(self, request: Request) -> Response:
)

with self._state_lock:
# If the HTTP connection is idle but the socket is readable, then the
# only valid state is that the socket is about to return b"", indicating
# a server-initiated disconnect.
server_disconnected = (
self._state == HTTPConnectionState.IDLE
and self._network_stream.get_extra_info("is_readable")
)
if server_disconnected:
raise ConnectionNotAvailable()

if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE):
self._request_count += 1
self._state = HTTPConnectionState.ACTIVE
Expand Down Expand Up @@ -287,29 +294,7 @@ def is_available(self) -> bool:

def has_expired(self) -> bool:
now = time.monotonic()
keepalive_expired = self._expire_at is not None and now > self._expire_at
if keepalive_expired:
return True

# If the HTTP connection is idle but the socket is readable, then the
# only valid state is that the socket is about to return b"", indicating
# a server-initiated disconnect.
# Checking the readable status is relatively expensive so check it at a lower frequency.
if (now - self._network_stream_used_at) > self._socket_poll_interval():
self._network_stream_used_at = now
server_disconnected = (
self._state == HTTPConnectionState.IDLE
and self._network_stream.get_extra_info("is_readable")
)
if server_disconnected:
return True

return False

def _socket_poll_interval(self) -> float:
# Randomize to avoid polling for all the connections at once
low, high = self._socket_poll_interval_between
return random.uniform(low, high)
return self._expire_at is not None and now > self._expire_at

def is_idle(self) -> bool:
return self._state == HTTPConnectionState.IDLE
Expand Down
102 changes: 5 additions & 97 deletions tests/_async/test_http11.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Any, List

import pytest

import httpcore
Expand All @@ -18,10 +16,7 @@ async def test_http11_connection():
]
)
async with httpcore.AsyncHTTP11Connection(
origin=origin,
stream=stream,
keepalive_expiry=5.0,
socket_poll_interval_between=(0, 0),
origin=origin, stream=stream, keepalive_expiry=5.0
) as conn:
response = await conn.request("GET", "https://example.com/")
assert response.status == 200
Expand Down Expand Up @@ -53,9 +48,7 @@ async def test_http11_connection_unread_response():
b"Hello, world!",
]
)
async with httpcore.AsyncHTTP11Connection(
origin=origin, stream=stream, socket_poll_interval_between=(0, 0)
) as conn:
async with httpcore.AsyncHTTP11Connection(origin=origin, stream=stream) as conn:
async with conn.stream("GET", "https://example.com/") as response:
assert response.status == 200

Expand All @@ -77,9 +70,7 @@ async def test_http11_connection_with_remote_protocol_error():
"""
origin = httpcore.Origin(b"https", b"example.com", 443)
stream = httpcore.AsyncMockStream([b"Wait, this isn't valid HTTP!", b""])
async with httpcore.AsyncHTTP11Connection(
origin=origin, stream=stream, socket_poll_interval_between=(0, 0)
) as conn:
async with httpcore.AsyncHTTP11Connection(origin=origin, stream=stream) as conn:
with pytest.raises(httpcore.RemoteProtocolError):
await conn.request("GET", "https://example.com/")

Expand Down Expand Up @@ -108,9 +99,7 @@ async def test_http11_connection_with_incomplete_response():
b"Hello, wor",
]
)
async with httpcore.AsyncHTTP11Connection(
origin=origin, stream=stream, socket_poll_interval_between=(0, 0)
) as conn:
async with httpcore.AsyncHTTP11Connection(origin=origin, stream=stream) as conn:
with pytest.raises(httpcore.RemoteProtocolError):
await conn.request("GET", "https://example.com/")

Expand Down Expand Up @@ -140,9 +129,7 @@ async def test_http11_connection_with_local_protocol_error():
b"Hello, world!",
]
)
async with httpcore.AsyncHTTP11Connection(
origin=origin, stream=stream, socket_poll_interval_between=(0, 0)
) as conn:
async with httpcore.AsyncHTTP11Connection(origin=origin, stream=stream) as conn:
with pytest.raises(httpcore.LocalProtocolError) as exc_info:
await conn.request("GET", "https://example.com/", headers={"Host": "\0"})

Expand All @@ -158,85 +145,6 @@ async def test_http11_connection_with_local_protocol_error():
)


@pytest.mark.anyio
async def test_http11_has_expired_checks_readable_status():
class AsyncMockStreamReadable(httpcore.AsyncMockStream):
def __init__(self, buffer: List[bytes]) -> None:
super().__init__(buffer)
self.is_readable = False
self.checks = 0

def get_extra_info(self, info: str) -> Any:
if info == "is_readable":
self.checks += 1
return self.is_readable
return super().get_extra_info(info) # pragma: nocover

origin = httpcore.Origin(b"https", b"example.com", 443)
stream = AsyncMockStreamReadable(
[
b"HTTP/1.1 200 OK\r\n",
b"Content-Type: plain/text\r\n",
b"Content-Length: 13\r\n",
b"\r\n",
b"Hello, world!",
]
)
async with httpcore.AsyncHTTP11Connection(
origin=origin, stream=stream, socket_poll_interval_between=(0, 0)
) as conn:
response = await conn.request("GET", "https://example.com/")
assert response.status == 200

assert stream.checks == 0
assert not conn.has_expired()
stream.is_readable = True
assert conn.has_expired()
assert stream.checks == 2


@pytest.mark.anyio
@pytest.mark.parametrize("should_check", [True, False])
async def test_http11_has_expired_checks_readable_status_by_interval(
monkeypatch, should_check
):
origin = httpcore.Origin(b"https", b"example.com", 443)
stream = httpcore.AsyncMockStream(
[
b"HTTP/1.1 200 OK\r\n",
b"Content-Type: plain/text\r\n",
b"Content-Length: 13\r\n",
b"\r\n",
b"Hello, world!",
]
)
async with httpcore.AsyncHTTP11Connection(
origin=origin,
stream=stream,
keepalive_expiry=5.0,
socket_poll_interval_between=(0, 0) if should_check else (999, 999),
) as conn:
orig = conn._network_stream.get_extra_info
calls = []

def patch_get_extra_info(attr_name: str) -> Any:
calls.append(attr_name)
return orig(attr_name)

monkeypatch.setattr(
conn._network_stream, "get_extra_info", patch_get_extra_info
)

response = await conn.request("GET", "https://example.com/")
assert response.status == 200

assert "is_readable" not in calls
assert not conn.has_expired()
assert (
("is_readable" in calls) if should_check else ("is_readable" not in calls)
)


@pytest.mark.anyio
async def test_http11_connection_handles_one_active_request():
"""
Expand Down
Loading

0 comments on commit 8b99090

Please sign in to comment.