Skip to content

Commit

Permalink
Re-work on cleanup design.
Browse files Browse the repository at this point in the history
  • Loading branch information
T-256 committed Sep 18, 2023
1 parent 6071ea7 commit 6105e0c
Show file tree
Hide file tree
Showing 7 changed files with 272 additions and 236 deletions.
44 changes: 20 additions & 24 deletions httpcore/_async/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,31 +72,27 @@ async def handle_async_request(self, request: Request) -> Response:

async with self._request_lock:
if self._connection is None:
try:
stream = await self._connect(request)

ssl_object = stream.get_extra_info("ssl_object")
http2_negotiated = (
ssl_object is not None
and ssl_object.selected_alpn_protocol() == "h2"
stream = await self._connect(request)

ssl_object = stream.get_extra_info("ssl_object")
http2_negotiated = (
ssl_object is not None
and ssl_object.selected_alpn_protocol() == "h2"
)
if http2_negotiated or (self._http2 and not self._http1):
from .http2 import AsyncHTTP2Connection

self._connection = AsyncHTTP2Connection(
origin=self._origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
else:
self._connection = AsyncHTTP11Connection(
origin=self._origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
if http2_negotiated or (self._http2 and not self._http1):
from .http2 import AsyncHTTP2Connection

self._connection = AsyncHTTP2Connection(
origin=self._origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
else:
self._connection = AsyncHTTP11Connection(
origin=self._origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
except Exception as exc:
self._connect_failed = True
raise exc
elif not self._connection.is_available():
raise ConnectionNotAvailable()

Expand Down
17 changes: 2 additions & 15 deletions httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,17 +229,8 @@ async def handle_async_request(self, request: Request) -> Response:
while True:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("pool", None)
try:
connection = await status.wait_for_connection(timeout=timeout)
except BaseException as exc:
# If we timeout here, or if the task is cancelled, then make
# sure to remove the request from the queue before bubbling
# up the exception.
async with self._pool_lock:
# Ensure only remove when task exists.
if status in self._requests:
self._requests.remove(status)
raise exc

connection = await status.wait_for_connection(timeout=timeout)

try:
response = await connection.handle_async_request(request)
Expand All @@ -256,10 +247,6 @@ async def handle_async_request(self, request: Request) -> Response:
# status so that the request becomes queued again.
status.unset_connection()
await self._attempt_to_acquire_connection(status)
except BaseException as exc:
with AsyncShieldCancellation():
await self.response_closed(status)
raise exc
else:
break

Expand Down
107 changes: 52 additions & 55 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@ def __init__(
max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE,
)

async def _cleanup(self, request: Request, exc: BaseException) -> BaseException:
exc = super()._cleanup(request, exc)

with AsyncShieldCancellation():
async with Trace("response_closed", logger, request) as trace:
await self._response_closed()

return exc

async def handle_async_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
raise RuntimeError(
Expand All @@ -83,54 +92,48 @@ async def handle_async_request(self, request: Request) -> Response:
else:
raise ConnectionNotAvailable()

kwargs = {"request": request}
try:
kwargs = {"request": request}
try:
async with Trace(
"send_request_headers", logger, request, kwargs
) as trace:
await self._send_request_headers(**kwargs)
async with Trace("send_request_body", logger, request, kwargs) as trace:
await self._send_request_body(**kwargs)
except WriteError:
# If we get a write error while we're writing the request,
# then we supress this error and move on to attempting to
# read the response. Servers can sometimes close the request
# pre-emptively and then respond with a well formed HTTP
# error response.
pass

async with Trace(
"receive_response_headers", logger, request, kwargs
"send_request_headers", logger, request, kwargs
) as trace:
(
http_version,
status,
reason_phrase,
headers,
) = await self._receive_response_headers(**kwargs)
trace.return_value = (
http_version,
status,
reason_phrase,
headers,
)

return Response(
status=status,
headers=headers,
content=HTTP11ConnectionByteStream(self, request),
extensions={
"http_version": http_version,
"reason_phrase": reason_phrase,
"network_stream": self._network_stream,
},
await self._send_request_headers(**kwargs)
async with Trace("send_request_body", logger, request, kwargs) as trace:
await self._send_request_body(**kwargs)
except WriteError:
# If we get a write error while we're writing the request,
# then we supress this error and move on to attempting to
# read the response. Servers can sometimes close the request
# pre-emptively and then respond with a well formed HTTP
# error response.
pass

async with Trace(
"receive_response_headers", logger, request, kwargs
) as trace:
(
http_version,
status,
reason_phrase,
headers,
) = await self._receive_response_headers(**kwargs)
trace.return_value = (
http_version,
status,
reason_phrase,
headers,
)
except BaseException as exc:
with AsyncShieldCancellation():
async with Trace("response_closed", logger, request) as trace:
await self._response_closed()
raise exc

return Response(
status=status,
headers=headers,
content=HTTP11ConnectionByteStream(self, request),
extensions={
"http_version": http_version,
"reason_phrase": reason_phrase,
"network_stream": self._network_stream,
},
)

# Sending the request...

Expand Down Expand Up @@ -324,17 +327,11 @@ def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None:

async def __aiter__(self) -> AsyncIterator[bytes]:
kwargs = {"request": self._request}
try:
async with Trace("receive_response_body", logger, self._request, kwargs):
async for chunk in self._connection._receive_response_body(**kwargs):
yield chunk
except BaseException as exc:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
with AsyncShieldCancellation():
await self.aclose()
raise exc

async with Trace("receive_response_body", logger, self._request, kwargs):
async for chunk in self._connection._receive_response_body(**kwargs):
yield chunk


async def aclose(self) -> None:
if not self._closed:
Expand Down
133 changes: 64 additions & 69 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,37 @@ def __init__(
self._read_exception: typing.Optional[Exception] = None
self._write_exception: typing.Optional[Exception] = None

async def _cleanup(self, request: Request, exc: BaseException) -> BaseException:
exc = super()._cleanup(request, exc)

if self._sent_connection_init:
with AsyncShieldCancellation():
await self.aclose()
return exc

with AsyncShieldCancellation():
kwargs = {"stream_id": stream_id}
async with Trace("response_closed", logger, request, kwargs):
await self._response_closed(stream_id=stream_id)

if isinstance(exc, h2.exceptions.ProtocolError):
# One case where h2 can raise a protocol error is when a
# closed frame has been seen by the state machine.
#
# This happens when one stream is reading, and encounters
# a GOAWAY event. Other flows of control may then raise
# a protocol error at any point they interact with the 'h2_state'.
#
# In this case we'll have stored the event, and should raise
# it as a RemoteProtocolError.
if self._connection_terminated: # pragma: nocover
return RemoteProtocolError(self._connection_terminated)
# If h2 raises a protocol error in some other state then we
# must somehow have made a protocol violation.
return LocalProtocolError(exc) # pragma: nocover

return exc

async def handle_async_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
# This cannot occur in normal operation, since the connection pool
Expand All @@ -103,14 +134,9 @@ async def handle_async_request(self, request: Request) -> Response:

async with self._init_lock:
if not self._sent_connection_init:
try:
kwargs = {"request": request}
async with Trace("send_connection_init", logger, request, kwargs):
await self._send_connection_init(**kwargs)
except BaseException as exc:
with AsyncShieldCancellation():
await self.aclose()
raise exc
kwargs = {"request": request}
async with Trace("send_connection_init", logger, request, kwargs):
await self._send_connection_init(**kwargs)

self._sent_connection_init = True

Expand All @@ -136,53 +162,29 @@ async def handle_async_request(self, request: Request) -> Response:
self._request_count -= 1
raise ConnectionNotAvailable()

try:
kwargs = {"request": request, "stream_id": stream_id}
async with Trace("send_request_headers", logger, request, kwargs):
await self._send_request_headers(request=request, stream_id=stream_id)
async with Trace("send_request_body", logger, request, kwargs):
await self._send_request_body(request=request, stream_id=stream_id)
async with Trace(
"receive_response_headers", logger, request, kwargs
) as trace:
status, headers = await self._receive_response(
request=request, stream_id=stream_id
)
trace.return_value = (status, headers)

return Response(
status=status,
headers=headers,
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
extensions={
"http_version": b"HTTP/2",
"network_stream": self._network_stream,
"stream_id": stream_id,
},
kwargs = {"request": request, "stream_id": stream_id}
async with Trace("send_request_headers", logger, request, kwargs):
await self._send_request_headers(request=request, stream_id=stream_id)
async with Trace("send_request_body", logger, request, kwargs):
await self._send_request_body(request=request, stream_id=stream_id)
async with Trace(
"receive_response_headers", logger, request, kwargs
) as trace:
status, headers = await self._receive_response(
request=request, stream_id=stream_id
)
except BaseException as exc: # noqa: PIE786
with AsyncShieldCancellation():
kwargs = {"stream_id": stream_id}
async with Trace("response_closed", logger, request, kwargs):
await self._response_closed(stream_id=stream_id)

if isinstance(exc, h2.exceptions.ProtocolError):
# One case where h2 can raise a protocol error is when a
# closed frame has been seen by the state machine.
#
# This happens when one stream is reading, and encounters
# a GOAWAY event. Other flows of control may then raise
# a protocol error at any point they interact with the 'h2_state'.
#
# In this case we'll have stored the event, and should raise
# it as a RemoteProtocolError.
if self._connection_terminated: # pragma: nocover
raise RemoteProtocolError(self._connection_terminated)
# If h2 raises a protocol error in some other state then we
# must somehow have made a protocol violation.
raise LocalProtocolError(exc) # pragma: nocover

raise exc
trace.return_value = (status, headers)

return Response(
status=status,
headers=headers,
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
extensions={
"http_version": b"HTTP/2",
"network_stream": self._network_stream,
"stream_id": stream_id,
},
)

async def _send_connection_init(self, request: Request) -> None:
"""
Expand Down Expand Up @@ -438,7 +440,7 @@ async def _read_incoming_data(
data = await self._network_stream.read(self.READ_NUM_BYTES, timeout)
if data == b"":
raise RemoteProtocolError("Server disconnected")
except Exception as exc:
except BaseException as exc:
# If we get a network error we should:
#
# 1. Save the exception and just raise it immediately on any future reads.
Expand Down Expand Up @@ -467,7 +469,7 @@ async def _write_outgoing_data(self, request: Request) -> None:

try:
await self._network_stream.write(data_to_send, timeout)
except Exception as exc: # pragma: nocover
except BaseException as exc: # pragma: nocover
# If we get a network error we should:
#
# 1. Save the exception and just raise it immediately on any future write.
Expand Down Expand Up @@ -567,19 +569,12 @@ def __init__(

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
kwargs = {"request": self._request, "stream_id": self._stream_id}
try:
async with Trace("receive_response_body", logger, self._request, kwargs):
async for chunk in self._connection._receive_response_body(
request=self._request, stream_id=self._stream_id
):
yield chunk
except BaseException as exc:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
with AsyncShieldCancellation():
await self.aclose()
raise exc

async with Trace("receive_response_body", logger, self._request, kwargs):
async for chunk in self._connection._receive_response_body(
request=self._request, stream_id=self._stream_id
):
yield chunk

async def aclose(self) -> None:
if not self._closed:
Expand Down
Loading

0 comments on commit 6105e0c

Please sign in to comment.