diff --git a/connexion/middleware/main.py b/connexion/middleware/main.py index b0e1a682b..4b5d84225 100644 --- a/connexion/middleware/main.py +++ b/connexion/middleware/main.py @@ -308,6 +308,9 @@ def _build_middleware_stack(self) -> t.Tuple[ASGIApp, t.Iterable[ASGIApp]]: app = middleware(app) # type: ignore apps.append(app) + # We sort the APIs by base path so that the most specific APIs are registered first. + # This is due to the way Starlette matches routes. + self.apis = utils.sort_apis_by_basepath(self.apis) for app in apps: if isinstance(app, SpecMiddleware): for api in self.apis: diff --git a/connexion/utils.py b/connexion/utils.py index 2bb8c0cb9..550b12dd2 100644 --- a/connexion/utils.py +++ b/connexion/utils.py @@ -12,9 +12,13 @@ import typing as t import yaml +from starlette.routing import compile_path from connexion.exceptions import TypeValidationError +if t.TYPE_CHECKING: + from connexion.middleware.main import API + def boolean(s): """ @@ -423,3 +427,63 @@ def inspect_function_arguments(function: t.Callable) -> t.Tuple[t.List[str], boo ] has_kwargs = any(p.kind == p.VAR_KEYWORD for p in parameters.values()) return list(bound_arguments), has_kwargs + + +T = t.TypeVar("T") + + +@t.overload +def sort_routes(routes: t.List[str], *, key: None = None) -> t.List[str]: + ... + + +@t.overload +def sort_routes(routes: t.List[T], *, key: t.Callable[[T], str]) -> t.List[T]: + ... + + +def sort_routes(routes, *, key=None): + """Sorts a list of routes from most specific to least specific. + + See Starlette routing documentation and implementation as this function + is aimed to sort according to that logic. + - https://www.starlette.io/routing/#route-priority + + The only difference is that a `path` component is appended to each route + such that `/` is less specific than `/basepath` while they are technically + not comparable. + This is because it is also done by the `Mount` class internally: + https://github.com/encode/starlette/blob/1c1043ca0ab7126419948b27f9d0a78270fd74e6/starlette/routing.py#L388 + + For example, from most to least specific: + - /users/me + - /users/{username}/projects/{project} + - /users/{username} + + :param routes: List of routes to sort + :param key: Function to extract the path from a route if it is not a string + + :return: List of routes sorted from most specific to least specific + """ + + class SortableRoute: + def __init__(self, path: str) -> None: + self.path = path.rstrip("/") + if not self.path.endswith("/{path:path}"): + self.path += "/{path:path}" + self.path_regex, _, _ = compile_path(self.path) + + def __lt__(self, other: "SortableRoute") -> bool: + return bool(other.path_regex.match(self.path)) + + return sorted(routes, key=lambda r: SortableRoute(key(r) if key else r)) + + +def sort_apis_by_basepath(apis: t.List["API"]) -> t.List["API"]: + """Sorts a list of APIs by basepath. + + :param apis: List of APIs to sort + + :return: List of APIs sorted by basepath + """ + return sort_routes(apis, key=lambda api: api.base_path or "/") diff --git a/pyproject.toml b/pyproject.toml index 70030476e..59ff224d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,3 +97,10 @@ asyncio_mode = "auto" [tool.isort] profile = "black" + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "if t.TYPE_CHECKING:", + "@t.overload", +] diff --git a/tests/test_utils.py b/tests/test_utils.py index 0ea0635e3..adfa1b115 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -72,3 +72,74 @@ def test_is_json_mimetype(): "application/vnd.com.myEntreprise.v6+json; charset=UTF-8" ) assert not utils.is_json_mimetype("text/html") + + +def test_sort_routes(): + routes = ["/users/me", "/users/{username}"] + expected = ["/users/me", "/users/{username}"] + assert utils.sort_routes(routes) == expected + + routes = ["/{path:path}", "/basepath/{path:path}"] + expected = ["/basepath/{path:path}", "/{path:path}"] + assert utils.sort_routes(routes) == expected + + routes = ["/", "/basepath"] + expected = ["/basepath", "/"] + assert utils.sort_routes(routes) == expected + + routes = ["/basepath/{path:path}", "/basepath/v2/{path:path}"] + expected = ["/basepath/v2/{path:path}", "/basepath/{path:path}"] + assert utils.sort_routes(routes) == expected + + routes = ["/basepath", "/basepath/v2"] + expected = ["/basepath/v2", "/basepath"] + assert utils.sort_routes(routes) == expected + + routes = ["/users/{username}", "/users/me"] + expected = ["/users/me", "/users/{username}"] + assert utils.sort_routes(routes) == expected + + routes = [ + "/users/{username}", + "/users/me", + "/users/{username}/items", + "/users/{username}/items/{item}", + ] + expected = [ + "/users/me", + "/users/{username}/items/{item}", + "/users/{username}/items", + "/users/{username}", + ] + assert utils.sort_routes(routes) == expected + + routes = [ + "/users/{username}", + "/users/me", + "/users/{username}/items/{item}", + "/users/{username}/items/special", + ] + expected = [ + "/users/me", + "/users/{username}/items/special", + "/users/{username}/items/{item}", + "/users/{username}", + ] + assert utils.sort_routes(routes) == expected + + +def test_sort_apis_by_basepath(): + api1 = MagicMock(base_path="/") + api2 = MagicMock(base_path="/basepath") + assert utils.sort_apis_by_basepath([api1, api2]) == [api2, api1] + + api3 = MagicMock(base_path="/basepath/v2") + assert utils.sort_apis_by_basepath([api1, api2, api3]) == [api3, api2, api1] + + api4 = MagicMock(base_path="/healthz") + assert utils.sort_apis_by_basepath([api1, api2, api3, api4]) == [ + api3, + api2, + api4, + api1, + ]