From 4631efb6386ba6b11ba3706a6dca747f86ae6c6c Mon Sep 17 00:00:00 2001 From: Giovanni Barillari Date: Mon, 2 Sep 2024 18:26:54 +0200 Subject: [PATCH] Add support for ASGI `pathsend` extension in `BaseHTTPMiddleware` --- docs/middleware.md | 1 - starlette/middleware/base.py | 18 ++++++++++-- tests/middleware/test_base.py | 52 ++++++++++++++++++++++++++++++++++- 3 files changed, 66 insertions(+), 5 deletions(-) diff --git a/docs/middleware.md b/docs/middleware.md index 23f0eeb84..9e5601819 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -264,7 +264,6 @@ around explicitly, rather than mutating the middleware instance. Currently, the `BaseHTTPMiddleware` has some known limitations: - Using `BaseHTTPMiddleware` will prevent changes to [`contextlib.ContextVar`](https://docs.python.org/3/library/contextvars.html#contextvars.ContextVar)s from propagating upwards. That is, if you set a value for a `ContextVar` in your endpoint and try to read it from a middleware you will find that the value is not the same value you set in your endpoint (see [this test](https://github.com/encode/starlette/blob/621abc747a6604825190b93467918a0ec6456a24/tests/middleware/test_base.py#L192-L223) for an example of this behavior). -- Using `BaseHTTPMiddleware` will prevent [ASGI pathsend extension](https://asgi.readthedocs.io/en/latest/extensions.html#path-send) to work properly. Thus, if you run your Starlette application with a server implementing this extension, routes returning [FileResponse](responses.md#fileresponse) should avoid the usage of this middleware. To overcome these limitations, use [pure ASGI middleware](#pure-asgi-middleware), as shown below. diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 2ac6f7f7f..57e652436 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -7,11 +7,13 @@ from starlette._utils import collapse_excgroups from starlette.requests import ClientDisconnect, Request -from starlette.responses import AsyncContentStream, Response +from starlette.responses import Response from starlette.types import ASGIApp, Message, Receive, Scope, Send RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]] +BodyStreamGenerator = typing.AsyncGenerator[typing.Union[bytes, typing.MutableMapping[str, typing.Any]], None] +AsyncContentStream = typing.AsyncIterable[typing.Union[str, bytes, memoryview, typing.MutableMapping[str, typing.Any]]] T = typing.TypeVar("T") @@ -165,9 +167,12 @@ async def coro() -> None: assert message["type"] == "http.response.start" - async def body_stream() -> typing.AsyncGenerator[bytes, None]: + async def body_stream() -> BodyStreamGenerator: async with recv_stream: async for message in recv_stream: + if message["type"] == "http.response.pathsend": + yield message + break assert message["type"] == "http.response.body" body = message.get("body", b"") if body: @@ -218,7 +223,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: } ) + should_close_body = True async for chunk in self.body_iterator: + if isinstance(chunk, dict): + # We got an ASGI message which is not response body (eg: pathsend) + should_close_body = False + await send(chunk) + continue await send({"type": "http.response.body", "body": chunk, "more_body": True}) - await send({"type": "http.response.body", "body": b"", "more_body": False}) + if should_close_body: + await send({"type": "http.response.body", "body": b"", "more_body": False}) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 225038650..a2bd58b60 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -2,6 +2,7 @@ import contextvars from contextlib import AsyncExitStack +from pathlib import Path from typing import ( Any, AsyncGenerator, @@ -18,7 +19,7 @@ from starlette.middleware import Middleware, _MiddlewareClass from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import ClientDisconnect, Request -from starlette.responses import PlainTextResponse, Response, StreamingResponse +from starlette.responses import FileResponse, PlainTextResponse, Response, StreamingResponse from starlette.routing import Route, WebSocketRoute from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -1132,3 +1133,52 @@ async def send(message: Message) -> None: {"type": "http.response.body", "body": b"good!", "more_body": True}, {"type": "http.response.body", "body": b"", "more_body": False}, ] + + +@pytest.mark.anyio +async def test_asgi_pathsend_events(tmpdir: Path) -> None: + path = tmpdir / "example.txt" + with path.open("w") as file: + file.write("") + + request_body_sent = False + response_complete = anyio.Event() + events: list[Message] = [] + + async def endpoint_with_pathsend(_: Request) -> FileResponse: + return FileResponse(path) + + async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response: + return await call_next(request) + + app = Starlette( + middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)], + routes=[Route("/", endpoint_with_pathsend)], + ) + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/", + "extensions": {"http.response.pathsend": {}}, + } + + async def receive() -> Message: + nonlocal request_body_sent + if not request_body_sent: + request_body_sent = True + return {"type": "http.request", "body": b"", "more_body": False} + await response_complete.wait() + return {"type": "http.disconnect"} + + async def send(message: Message) -> None: + events.append(message) + if message["type"] == "http.response.pathsend": + response_complete.set() + + await app(scope, receive, send) + + assert len(events) == 2 + assert events[0]["type"] == "http.response.start" + assert events[1]["type"] == "http.response.pathsend"