Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CORS headers not set on exceptions #1821

Merged
merged 22 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions connexion/middleware/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from connexion.middleware.response_validation import ResponseValidationMiddleware
from connexion.middleware.routing import RoutingMiddleware
from connexion.middleware.security import SecurityMiddleware
from connexion.middleware.server_error import ServerErrorMiddleware
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
from connexion.options import SwaggerUIOptions
from connexion.resolver import Resolver
Expand Down Expand Up @@ -92,6 +93,17 @@ def replace(self, **changes) -> "_Options":
class MiddlewarePosition(enum.Enum):
"""Positions to insert a middleware"""

BEFORE_EXCEPTION = ExceptionMiddleware
"""Add before the :class:`ExceptionMiddleware`. This is useful if you want your changes to
affect the way exceptions are handled, such as a custom error handler.

Be mindful that security has not yet been applied at this stage.
Additionally, the inserted middleware is positioned before the RoutingMiddleware, so you cannot
leverage any routing information yet and should implement your middleware to work globally
instead of on an operation level.

Usefull for CORS middleware which should be applied before the exception middleware.
nielsbox marked this conversation as resolved.
Show resolved Hide resolved
"""
nielsbox marked this conversation as resolved.
Show resolved Hide resolved
BEFORE_SWAGGER = SwaggerUIMiddleware
"""Add before the :class:`SwaggerUIMiddleware`. This is useful if you want your changes to
affect the Swagger UI, such as a path altering middleware that should also alter the paths
Expand Down Expand Up @@ -164,6 +176,7 @@ class ConnexionMiddleware:
provided application."""

default_middlewares = [
ServerErrorMiddleware,
ExceptionMiddleware,
SwaggerUIMiddleware,
RoutingMiddleware,
Expand Down
19 changes: 19 additions & 0 deletions connexion/middleware/server_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import logging

from starlette.middleware.errors import (
ServerErrorMiddleware as StarletteServerErrorMiddleware,
)
from starlette.types import ASGIApp, Receive, Scope, Send

logger = logging.getLogger(__name__)


class ServerErrorMiddleware(StarletteServerErrorMiddleware):
Ruwann marked this conversation as resolved.
Show resolved Hide resolved
"""Subclass of starlette ServerErrorMiddleware to change handling of Unhandled Server
exceptions to existing connexion behavior."""

def __init__(self, next_app: ASGIApp):
super().__init__(next_app)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await super().__call__(scope, receive, send)
6 changes: 3 additions & 3 deletions docs/cookbook.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Starlette. You can add it to your application, ideally in front of the ``Routing

app.add_middleware(
CORSMiddleware,
position=MiddlewarePosition.BEFORE_ROUTING,
position=MiddlewarePosition.BEFORE_EXCEPTION,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
Expand Down Expand Up @@ -62,7 +62,7 @@ Starlette. You can add it to your application, ideally in front of the ``Routing

app.add_middleware(
CORSMiddleware,
position=MiddlewarePosition.BEFORE_ROUTING,
position=MiddlewarePosition.BEFORE_EXCEPTION,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
Expand Down Expand Up @@ -96,7 +96,7 @@ Starlette. You can add it to your application, ideally in front of the ``Routing

app.add_middleware(
CORSMiddleware,
position=MiddlewarePosition.BEFORE_ROUTING,
position=MiddlewarePosition.BEFORE_EXCEPTION,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
Expand Down
21 changes: 21 additions & 0 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import logging
from functools import partial

import pytest
from connexion import ConnexionMiddleware
from connexion.middleware.server_error import ServerErrorMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.types import Receive, Scope, Send

from conftest import FIXTURES_FOLDER, OPENAPI3_SPEC, build_app_from_fixture
Expand All @@ -20,6 +24,23 @@ def simple_openapi_app(app_class):
)


@pytest.fixture(scope="session")
def cors_openapi_app(app_class):
middlewares = [*ConnexionMiddleware.default_middlewares]
cors_middleware = partial(CORSMiddleware, allow_origins=["http://localhost"])
# CORS should always be directly after ServerErrorMiddleware
server_error_idx = middlewares.index(ServerErrorMiddleware)
middlewares.insert(server_error_idx + 1, cors_middleware)
nielsbox marked this conversation as resolved.
Show resolved Hide resolved

return build_app_from_fixture(
"simple",
app_class=app_class,
spec_file=OPENAPI3_SPEC,
middlewares=middlewares,
validate_responses=True,
)


@pytest.fixture(scope="session")
def reverse_proxied_app(spec, app_class):
class ReverseProxied:
Expand Down
44 changes: 44 additions & 0 deletions tests/api/test_cors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import json


def test_cors_valid(cors_openapi_app):
app_client = cors_openapi_app.test_client()
origin = "http://localhost"
response = app_client.post("/v1.0/goodday/dan", data={}, headers={"Origin": origin})
assert response.status_code == 201
assert "Access-Control-Allow-Origin" in response.headers
assert origin == response.headers["Access-Control-Allow-Origin"]


def test_cors_invalid(cors_openapi_app):
app_client = cors_openapi_app.test_client()
response = app_client.options(
"/v1.0/goodday/dan",
headers={"Origin": "http://0.0.0.0", "Access-Control-Request-Method": "POST"},
)
assert response.status_code == 400
assert "Access-Control-Allow-Origin" not in response.headers


def test_cors_validation_error(cors_openapi_app):
app_client = cors_openapi_app.test_client()
origin = "http://localhost"
response = app_client.post(
"/v1.0/body-not-allowed-additional-properties",
data={},
headers={"Origin": origin},
)
assert response.status_code == 400
assert "Access-Control-Allow-Origin" in response.headers
assert origin == response.headers["Access-Control-Allow-Origin"]


def test_cors_server_error(cors_openapi_app):
app_client = cors_openapi_app.test_client()
origin = "http://localhost"
response = app_client.post(
"/v1.0/goodday/noheader", data={}, headers={"Origin": origin}
)
assert response.status_code == 500
assert "Access-Control-Allow-Origin" in response.headers
assert origin == response.headers["Access-Control-Allow-Origin"]
Loading