Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise ClientDisconnected on send() when client disconnected #2220

Merged
merged 6 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 40 additions & 56 deletions tests/protocols/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from uvicorn.lifespan.on import LifespanOn
from uvicorn.main import ServerState
from uvicorn.protocols.http.h11_impl import H11Protocol
from uvicorn.protocols.utils import ClientDisconnected

try:
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
Expand Down Expand Up @@ -369,9 +370,7 @@ async def test_close(http_protocol_cls: HTTPProtocol):


@pytest.mark.anyio
async def test_chunked_encoding(
http_protocol_cls: HTTPProtocol,
):
async def test_chunked_encoding(http_protocol_cls: HTTPProtocol):
app = Response(
b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}
)
Expand All @@ -385,9 +384,7 @@ async def test_chunked_encoding(


@pytest.mark.anyio
async def test_chunked_encoding_empty_body(
http_protocol_cls: HTTPProtocol,
):
async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol):
app = Response(
b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}
)
Expand Down Expand Up @@ -416,9 +413,7 @@ async def test_chunked_encoding_head_request(


@pytest.mark.anyio
async def test_pipelined_requests(
http_protocol_cls: HTTPProtocol,
):
async def test_pipelined_requests(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, http_protocol_cls)
Expand All @@ -440,9 +435,7 @@ async def test_pipelined_requests(


@pytest.mark.anyio
async def test_undersized_request(
http_protocol_cls: HTTPProtocol,
):
async def test_undersized_request(http_protocol_cls: HTTPProtocol):
app = Response(b"xxx", headers={"content-length": "10"})

protocol = get_connected_protocol(app, http_protocol_cls)
Expand All @@ -452,9 +445,7 @@ async def test_undersized_request(


@pytest.mark.anyio
async def test_oversized_request(
http_protocol_cls: HTTPProtocol,
):
async def test_oversized_request(http_protocol_cls: HTTPProtocol):
app = Response(b"xxx" * 20, headers={"content-length": "10"})

protocol = get_connected_protocol(app, http_protocol_cls)
Expand All @@ -464,9 +455,7 @@ async def test_oversized_request(


@pytest.mark.anyio
async def test_large_post_request(
http_protocol_cls: HTTPProtocol,
):
async def test_large_post_request(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, http_protocol_cls)
Expand All @@ -486,9 +475,7 @@ async def test_invalid_http(http_protocol_cls: HTTPProtocol):


@pytest.mark.anyio
async def test_app_exception(
http_protocol_cls: HTTPProtocol,
):
async def test_app_exception(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
raise Exception()

Expand All @@ -500,9 +487,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_exception_during_response(
http_protocol_cls: HTTPProtocol,
):
async def test_exception_during_response(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.body", "body": b"1", "more_body": True})
Expand All @@ -516,9 +501,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_no_response_returned(
http_protocol_cls: HTTPProtocol,
):
async def test_no_response_returned(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
...

Expand All @@ -530,9 +513,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_partial_response_returned(
http_protocol_cls: HTTPProtocol,
):
async def test_partial_response_returned(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})

Expand All @@ -544,9 +525,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_duplicate_start_message(
http_protocol_cls: HTTPProtocol,
):
async def test_duplicate_start_message(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.start", "status": 200})
Expand All @@ -559,9 +538,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_missing_start_message(
http_protocol_cls: HTTPProtocol,
):
async def test_missing_start_message(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.body", "body": b""})

Expand All @@ -573,9 +550,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_message_after_body_complete(
http_protocol_cls: HTTPProtocol,
):
async def test_message_after_body_complete(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.body", "body": b""})
Expand All @@ -589,9 +564,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_value_returned(
http_protocol_cls: HTTPProtocol,
):
async def test_value_returned(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.body", "body": b""})
Expand All @@ -605,9 +578,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_early_disconnect(
http_protocol_cls: HTTPProtocol,
):
async def test_early_disconnect(http_protocol_cls: HTTPProtocol):
got_disconnect_event = False

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
Expand All @@ -629,9 +600,26 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_early_response(
http_protocol_cls: HTTPProtocol,
):
async def test_disconnect_on_send(http_protocol_cls: HTTPProtocol) -> None:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test added.

got_disconnected = False

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
try:
await send({"type": "http.response.start", "status": 200})
except ClientDisconnected:
nonlocal got_disconnected
got_disconnected = True

protocol = get_connected_protocol(app, http_protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
protocol.connection_lost(None)
await protocol.loop.run_one()
assert got_disconnected


@pytest.mark.anyio
async def test_early_response(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, http_protocol_cls)
Expand All @@ -643,9 +631,7 @@ async def test_early_response(


@pytest.mark.anyio
async def test_read_after_response(
http_protocol_cls: HTTPProtocol,
):
async def test_read_after_response(http_protocol_cls: HTTPProtocol):
message_after_response = None

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
Expand All @@ -663,9 +649,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_http10_request(
http_protocol_cls: HTTPProtocol,
):
async def test_http10_request(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "http"
content = "Version: %s" % scope["http_version"]
Expand Down Expand Up @@ -876,8 +860,8 @@ async def asgi(receive: ASGIReceiveCallable, send: ASGISendCallable):
@pytest.mark.parametrize(
"asgi2or3_app, expected_scopes",
[
(asgi3app, {"version": "3.0", "spec_version": "2.3"}),
(asgi2app, {"version": "2.0", "spec_version": "2.3"}),
(asgi3app, {"version": "3.0", "spec_version": "2.4"}),
(asgi2app, {"version": "2.0", "spec_version": "2.4"}),
],
)
async def test_scopes(
Expand Down
17 changes: 10 additions & 7 deletions uvicorn/protocols/http/h11_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
service_unavailable,
)
from uvicorn.protocols.utils import (
ClientDisconnected,
get_client_addr,
get_local_addr,
get_path_with_query_string,
Expand Down Expand Up @@ -205,7 +206,7 @@ def handle_events(self) -> None:
"type": "http",
"asgi": {
"version": self.config.asgi_version,
"spec_version": "2.3",
"spec_version": "2.4",
},
"http_version": event.http_version.decode("ascii"),
"server": self.server,
Expand Down Expand Up @@ -412,6 +413,8 @@ async def run_asgi(self, app: "ASGI3Application") -> None:
result = await app( # type: ignore[func-returns-value]
self.scope, self.receive, self.send
)
except ClientDisconnected:
pass
except BaseException as exc:
msg = "Exception in ASGI application\n"
self.logger.error(msg, exc_info=exc)
Expand All @@ -436,7 +439,7 @@ async def run_asgi(self, app: "ASGI3Application") -> None:
self.on_response = lambda: None

async def send_500_response(self) -> None:
response_start_event: "HTTPResponseStartEvent" = {
response_start_event: HTTPResponseStartEvent = {
"type": "http.response.start",
"status": 500,
"headers": [
Expand All @@ -445,22 +448,22 @@ async def send_500_response(self) -> None:
],
}
await self.send(response_start_event)
response_body_event: "HTTPResponseBodyEvent" = {
response_body_event: HTTPResponseBodyEvent = {
"type": "http.response.body",
"body": b"Internal Server Error",
"more_body": False,
}
await self.send(response_body_event)

# ASGI interface
async def send(self, message: "ASGISendEvent") -> None:
async def send(self, message: ASGISendEvent) -> None:
message_type = message["type"]

if self.flow.write_paused and not self.disconnected:
await self.flow.drain()

if self.disconnected:
return
raise ClientDisconnected

if not self.response_started:
# Sending response status line and headers
Expand Down Expand Up @@ -527,7 +530,7 @@ async def send(self, message: "ASGISendEvent") -> None:
self.transport.close()
self.on_response()

async def receive(self) -> "ASGIReceiveEvent":
async def receive(self) -> ASGIReceiveEvent:
if self.waiting_for_100_continue and not self.transport.is_closing():
headers: list[tuple[str, str]] = []
event = h11.InformationalResponse(
Expand All @@ -545,7 +548,7 @@ async def receive(self) -> "ASGIReceiveEvent":
if self.disconnected or self.response_complete:
return {"type": "http.disconnect"}

message: "HTTPRequestEvent" = {
message: HTTPRequestEvent = {
"type": "http.request",
"body": self.body,
"more_body": self.more_body,
Expand Down
Loading