Skip to content

Commit

Permalink
Fix issue with middleware args passing (#2752)
Browse files Browse the repository at this point in the history
  • Loading branch information
uriyyo authored Nov 14, 2024
1 parent c2e3a39 commit 427a8dc
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 15 deletions.
6 changes: 3 additions & 3 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing_extensions import ParamSpec

from starlette.datastructures import State, URLPath
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware import Middleware, _MiddlewareFactory
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
Expand Down Expand Up @@ -96,7 +96,7 @@ def build_middleware_stack(self) -> ASGIApp:

app = self.router
for cls, args, kwargs in reversed(middleware):
app = cls(app=app, *args, **kwargs)
app = cls(app, *args, **kwargs)
return app

@property
Expand All @@ -123,7 +123,7 @@ def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:

def add_middleware(
self,
middleware_class: type[_MiddlewareClass[P]],
middleware_class: _MiddlewareFactory[P],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
Expand Down
13 changes: 6 additions & 7 deletions starlette/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,19 @@
else: # pragma: no cover
from typing_extensions import ParamSpec

from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp

P = ParamSpec("P")


class _MiddlewareClass(Protocol[P]):
def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None: ... # pragma: no cover

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ... # pragma: no cover
class _MiddlewareFactory(Protocol[P]):
def __call__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover


class Middleware:
def __init__(
self,
cls: type[_MiddlewareClass[P]],
cls: _MiddlewareFactory[P],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
Expand All @@ -38,5 +36,6 @@ def __repr__(self) -> str:
class_name = self.__class__.__name__
args_strings = [f"{value!r}" for value in self.args]
option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
args_repr = ", ".join([self.cls.__name__] + args_strings + option_strings)
name = getattr(self.cls, "__name__", "")
args_repr = ", ".join([name] + args_strings + option_strings)
return f"{class_name}({args_repr})"
6 changes: 3 additions & 3 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def __init__(

if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)
self.app = cls(self.app, *args, **kwargs)

if methods is None:
self.methods = None
Expand Down Expand Up @@ -328,7 +328,7 @@ def __init__(

if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)
self.app = cls(self.app, *args, **kwargs)

self.path_regex, self.path_format, self.param_convertors = compile_path(path)

Expand Down Expand Up @@ -388,7 +388,7 @@ def __init__(
self.app = self._base_app
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)
self.app = cls(self.app, *args, **kwargs)
self.name = name
self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")

Expand Down
4 changes: 2 additions & 2 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware import Middleware, _MiddlewareFactory
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import ClientDisconnect, Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
Expand Down Expand Up @@ -232,7 +232,7 @@ async def dispatch(
)
def test_contextvars(
test_client_factory: TestClientFactory,
middleware_cls: type[_MiddlewareClass[Any]],
middleware_cls: _MiddlewareFactory[Any],
) -> None:
# this has to be an async endpoint because Starlette calls run_in_threadpool
# on sync endpoints which has it's own set of peculiarities w.r.t propagating
Expand Down
44 changes: 44 additions & 0 deletions tests/test_applications.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
from contextlib import asynccontextmanager
from pathlib import Path
Expand Down Expand Up @@ -533,6 +535,48 @@ def get_app() -> ASGIApp:
assert SimpleInitializableMiddleware.counter == 2


def test_middleware_args(test_client_factory: TestClientFactory) -> None:
calls: list[str] = []

class MiddlewareWithArgs:
def __init__(self, app: ASGIApp, arg: str) -> None:
self.app = app
self.arg = arg

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
calls.append(self.arg)
await self.app(scope, receive, send)

app = Starlette()
app.add_middleware(MiddlewareWithArgs, "foo")
app.add_middleware(MiddlewareWithArgs, "bar")

with test_client_factory(app):
pass

assert calls == ["bar", "foo"]


def test_middleware_factory(test_client_factory: TestClientFactory) -> None:
calls: list[str] = []

def _middleware_factory(app: ASGIApp, arg: str) -> ASGIApp:
async def _app(scope: Scope, receive: Receive, send: Send) -> None:
calls.append(arg)
await app(scope, receive, send)

return _app

app = Starlette()
app.add_middleware(_middleware_factory, arg="foo")
app.add_middleware(_middleware_factory, arg="bar")

with test_client_factory(app):
pass

assert calls == ["bar", "foo"]


def test_lifespan_app_subclass() -> None:
# This test exists to make sure that subclasses of Starlette
# (like FastAPI) are compatible with the types hints for Lifespan
Expand Down

0 comments on commit 427a8dc

Please sign in to comment.