Skip to content

Commit

Permalink
Add middleware per Route/WebSocketRoute (#2349)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Dec 1, 2023
1 parent 1fd4b20 commit 164b350
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 2 deletions.
10 changes: 10 additions & 0 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def __init__(
methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
include_in_schema: bool = True,
middleware: typing.Optional[typing.Sequence[Middleware]] = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
Expand All @@ -236,6 +237,10 @@ def __init__(
# Endpoint is a class. Treat it as ASGI.
self.app = endpoint

if middleware is not None:
for cls, options in reversed(middleware):
self.app = cls(app=self.app, **options)

if methods is None:
self.methods = None
else:
Expand Down Expand Up @@ -309,6 +314,7 @@ def __init__(
endpoint: typing.Callable[..., typing.Any],
*,
name: typing.Optional[str] = None,
middleware: typing.Optional[typing.Sequence[Middleware]] = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
Expand All @@ -325,6 +331,10 @@ def __init__(
# Endpoint is a class. Treat it as ASGI.
self.app = endpoint

if middleware is not None:
for cls, options in reversed(middleware):
self.app = cls(app=self.app, **options)

self.path_regex, self.path_format, self.param_convertors = compile_path(path)

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
Expand Down
2 changes: 1 addition & 1 deletion starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ def delete( # type: ignore[override]

def websocket_connect(
self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
) -> typing.Any:
) -> "WebSocketTestSession":
url = urljoin("ws://testserver", url)
headers = kwargs.get("headers", {})
headers.setdefault("connection", "upgrade")
Expand Down
53 changes: 52 additions & 1 deletion tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,18 @@ def assert_middleware_header_route(request: Request) -> Response:
return Response()


route_with_middleware = Starlette(
routes=[
Route(
"/http",
endpoint=assert_middleware_header_route,
methods=["GET"],
middleware=[Middleware(AddHeadersMiddleware)],
),
Route("/home", homepage),
]
)

mounted_routes_with_middleware = Starlette(
routes=[
Mount(
Expand Down Expand Up @@ -960,9 +972,10 @@ def assert_middleware_header_route(request: Request) -> Response:
[
mounted_routes_with_middleware,
mounted_app_with_middleware,
route_with_middleware,
],
)
def test_mount_middleware(
def test_base_route_middleware(
test_client_factory: typing.Callable[..., TestClient],
app: Starlette,
) -> None:
Expand Down Expand Up @@ -1076,6 +1089,44 @@ async def modified_send(msg: Message) -> None:
assert "X-Mounted" in resp.headers


def test_websocket_route_middleware(
test_client_factory: typing.Callable[..., TestClient]
):
async def websocket_endpoint(session: WebSocket):
await session.accept()
await session.send_text("Hello, world!")
await session.close()

class WebsocketMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async def modified_send(msg: Message) -> None:
if msg["type"] == "websocket.accept":
msg["headers"].append((b"X-Test", b"Set by middleware"))
await send(msg)

await self.app(scope, receive, modified_send)

app = Starlette(
routes=[
WebSocketRoute(
"/ws",
endpoint=websocket_endpoint,
middleware=[Middleware(WebsocketMiddleware)],
)
]
)

client = test_client_factory(app)

with client.websocket_connect("/ws") as websocket:
text = websocket.receive_text()
assert text == "Hello, world!"
assert websocket.extra_headers == [(b"X-Test", b"Set by middleware")]


def test_route_repr() -> None:
route = Route("/welcome", endpoint=homepage)
assert (
Expand Down

0 comments on commit 164b350

Please sign in to comment.