From 402b4cabe05f80f14e4139a3b7700f0cdec68d89 Mon Sep 17 00:00:00 2001 From: Simon Gurcke Date: Tue, 15 Aug 2023 15:54:03 +1000 Subject: [PATCH] Refactor --- starlette_apitally/__init__.py | 4 - starlette_apitally/app_info.py | 65 ------- starlette_apitally/client.py | 11 +- starlette_apitally/fastapi.py | 28 +-- starlette_apitally/keys.py | 9 +- starlette_apitally/middleware.py | 87 ---------- starlette_apitally/starlette.py | 189 ++++++++++++++++++++ tests/conftest.py | 108 ------------ tests/test_app_info.py | 25 --- tests/test_client.py | 7 +- tests/test_fastapi.py | 76 ++++++++ tests/test_keys.py | 8 + tests/test_middleware.py | 87 ---------- tests/test_starlette.py | 288 +++++++++++++++++++++++++++++++ 14 files changed, 593 insertions(+), 399 deletions(-) delete mode 100644 starlette_apitally/app_info.py delete mode 100644 starlette_apitally/middleware.py create mode 100644 starlette_apitally/starlette.py delete mode 100644 tests/test_app_info.py create mode 100644 tests/test_fastapi.py delete mode 100644 tests/test_middleware.py create mode 100644 tests/test_starlette.py diff --git a/starlette_apitally/__init__.py b/starlette_apitally/__init__.py index bc78783..6c8e6b9 100644 --- a/starlette_apitally/__init__.py +++ b/starlette_apitally/__init__.py @@ -1,5 +1 @@ -from starlette_apitally.middleware import ApitallyMiddleware - - -__all__ = ["ApitallyMiddleware"] __version__ = "0.0.0" diff --git a/starlette_apitally/app_info.py b/starlette_apitally/app_info.py deleted file mode 100644 index bf60c4a..0000000 --- a/starlette_apitally/app_info.py +++ /dev/null @@ -1,65 +0,0 @@ -import sys -from typing import Any, Dict, List, Optional - -import starlette -from httpx import HTTPStatusError -from starlette.routing import BaseRoute, Router -from starlette.schemas import EndpointInfo, SchemaGenerator -from starlette.testclient import TestClient -from starlette.types import ASGIApp - -import starlette_apitally - - -def get_app_info(app: ASGIApp, app_version: Optional[str], openapi_url: Optional[str]) -> Dict[str, Any]: - app_info: Dict[str, Any] = {} - if openapi := get_openapi(app, openapi_url): - app_info["openapi"] = openapi - elif endpoints := get_endpoint_info(app): - app_info["paths"] = [{"path": endpoint.path, "method": endpoint.http_method} for endpoint in endpoints] - app_info["versions"] = get_versions(app_version) - app_info["client"] = "starlette-apitally" - return app_info - - -def get_openapi(app: ASGIApp, openapi_url: Optional[str]) -> Optional[str]: - if not openapi_url: - return None - try: - client = TestClient(app, raise_server_exceptions=False) - response = client.get(openapi_url) - response.raise_for_status() - return response.text - except HTTPStatusError: - return None - - -def get_endpoint_info(app: ASGIApp) -> List[EndpointInfo]: - routes = get_routes(app) - schemas = SchemaGenerator({}) - return schemas.get_endpoints(routes) - - -def get_routes(app: ASGIApp) -> List[BaseRoute]: - if isinstance(app, Router): - return app.routes - elif hasattr(app, "app"): - return get_routes(app.app) - return [] - - -def get_versions(app_version: Optional[str] = None) -> Dict[str, str]: - versions = { - "python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", - "starlette-apitally": starlette_apitally.__version__, - "starlette": starlette.__version__, - } - try: - import fastapi - - versions["fastapi"] = fastapi.__version__ - except (ImportError, AttributeError): - pass - if app_version: - versions["app"] = app_version - return versions diff --git a/starlette_apitally/client.py b/starlette_apitally/client.py index 0b4d0d0..ac2ed18 100644 --- a/starlette_apitally/client.py +++ b/starlette_apitally/client.py @@ -9,9 +9,7 @@ import backoff import httpx -from starlette.types import ASGIApp -from starlette_apitally.app_info import get_app_info from starlette_apitally.keys import Keys from starlette_apitally.requests import Requests @@ -59,6 +57,12 @@ def __init__(self, client_id: str, env: str, enable_keys: bool = False, send_eve self._stop_sync_loop = False self.start_sync_loop() + @classmethod + def get_instance(cls) -> ApitallyClient: + if cls._instance is None: + raise RuntimeError("Apitally client not initialized") + return cls._instance + def get_http_client(self) -> httpx.AsyncClient: base_url = f"{HUB_BASE_URL}/{HUB_VERSION}/{self.client_id}/{self.env}" return httpx.AsyncClient(base_url=base_url) @@ -83,8 +87,7 @@ async def _run_sync_loop(self) -> None: def stop_sync_loop(self) -> None: self._stop_sync_loop = True - def send_app_info(self, app: ASGIApp, app_version: Optional[str], openapi_url: Optional[str]) -> None: - app_info = get_app_info(app, app_version, openapi_url) + def send_app_info(self, app_info: Dict[str, Any]) -> None: payload = { "instance_uuid": self.instance_uuid, "message_uuid": str(uuid4()), diff --git a/starlette_apitally/fastapi.py b/starlette_apitally/fastapi.py index 8051319..d7588cd 100644 --- a/starlette_apitally/fastapi.py +++ b/starlette_apitally/fastapi.py @@ -3,11 +3,17 @@ from fastapi.exceptions import HTTPException from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.requests import Request +from fastapi.security import SecurityScopes from fastapi.security.base import SecurityBase from fastapi.security.utils import get_authorization_scheme_param from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN -from starlette_apitally.keys import Key, Keys +from starlette_apitally.client import ApitallyClient +from starlette_apitally.keys import Key +from starlette_apitally.starlette import ApitallyMiddleware + + +__all__ = ["ApitallyMiddleware", "Key", "api_key_auth"] class AuthorizationAPIKeyHeader(SecurityBase): @@ -20,7 +26,7 @@ def __init__(self, *, auto_error: bool = True): self.scheme_name = "Authorization header with ApiKey scheme" self.auto_error = auto_error - async def __call__(self, request: Request) -> Optional[Key]: + async def __call__(self, request: Request, security_scopes: SecurityScopes) -> Optional[Key]: authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if not authorization or scheme.lower() != "apikey": @@ -31,23 +37,19 @@ async def __call__(self, request: Request) -> Optional[Key]: headers={"WWW-Authenticate": "ApiKey"}, ) else: - return None - keys = self._get_keys() - key = keys.get(param) + return None # pragma: no cover + key = ApitallyClient.get_instance().keys.get(param) if key is None and self.auto_error: raise HTTPException( status_code=HTTP_403_FORBIDDEN, detail="Invalid API key", ) + if key is not None and self.auto_error and not key.check_scopes(security_scopes.scopes): + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail="Permission denied", + ) return key - def _get_keys(self) -> Keys: - from starlette_apitally.client import ApitallyClient - - client = ApitallyClient._instance - if client is None: - raise RuntimeError("ApitallyClient not initialized") - return client.keys - api_key_auth = AuthorizationAPIKeyHeader() diff --git a/starlette_apitally/keys.py b/starlette_apitally/keys.py index 50555ad..e30c5be 100644 --- a/starlette_apitally/keys.py +++ b/starlette_apitally/keys.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime, timedelta from hashlib import scrypt from typing import Any, Dict, List, Optional, Set @@ -9,16 +9,23 @@ @dataclass(frozen=True) class Key: key_id: int + name: str = "" + scopes: List[str] = field(default_factory=list) expires_at: Optional[datetime] = None @property def is_expired(self) -> bool: return self.expires_at is not None and self.expires_at < datetime.now() + def check_scopes(self, scopes: List[str]) -> bool: + return all(scope in self.scopes for scope in scopes) + @classmethod def from_dict(cls, data: Dict[str, Any]) -> Key: return cls( key_id=data["key_id"], + name=data.get("name", ""), + scopes=data.get("scopes", []), expires_at=( datetime.now() + timedelta(seconds=data["expires_in_seconds"]) if data["expires_in_seconds"] is not None diff --git a/starlette_apitally/middleware.py b/starlette_apitally/middleware.py deleted file mode 100644 index 9a77a65..0000000 --- a/starlette_apitally/middleware.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import annotations - -import re -import time -from typing import TYPE_CHECKING, Optional, Tuple -from uuid import UUID - -from starlette.background import BackgroundTask -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.routing import Match -from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR -from starlette.types import ASGIApp - -from starlette_apitally.client import ApitallyClient - - -if TYPE_CHECKING: - from starlette.middleware.base import RequestResponseEndpoint - from starlette.requests import Request - from starlette.responses import Response - - -class ApitallyMiddleware(BaseHTTPMiddleware): - def __init__( - self, - app: ASGIApp, - client_id: str, - env: str = "default", - app_version: Optional[str] = None, - enable_keys: bool = False, - send_every: float = 60, - filter_unhandled_paths: bool = True, - openapi_url: Optional[str] = "/openapi.json", - ) -> None: - try: - UUID(client_id) - except ValueError: - raise ValueError(f"invalid client_id '{client_id}' (expected hexadecimal UUID format)") - if re.match(r"^[\w-]{1,32}$", env) is None: - raise ValueError(f"invalid env '{env}' (expected 1-32 alphanumeric lowercase characters and hyphens only)") - if app_version is not None and len(app_version) > 32: - raise ValueError(f"invalid app_version '{app_version}' (expected 1-32 characters)") - if send_every < 10: - raise ValueError("send_every has to be greater or equal to 10 seconds") - - self.filter_unhandled_paths = filter_unhandled_paths - self.client = ApitallyClient(client_id=client_id, env=env, enable_keys=enable_keys, send_every=send_every) - self.client.send_app_info(app=app, app_version=app_version, openapi_url=openapi_url) - super().__init__(app) - - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: - try: - start_time = time.perf_counter() - response = await call_next(request) - except BaseException as e: - self.log_request( - request=request, - status_code=HTTP_500_INTERNAL_SERVER_ERROR, - response_time=time.perf_counter() - start_time, - ) - raise e from None - else: - response.background = BackgroundTask( - self.log_request, - request=request, - status_code=response.status_code, - response_time=time.perf_counter() - start_time, - ) - return response - - def log_request(self, request: Request, status_code: int, response_time: float) -> None: - path_template, is_handled_path = self.get_path_template(request) - if is_handled_path or not self.filter_unhandled_paths: - self.client.requests.log_request( - method=request.method, - path=path_template, - status_code=status_code, - response_time=response_time, - ) - - @staticmethod - def get_path_template(request: Request) -> Tuple[str, bool]: - for route in request.app.routes: - match, _ = route.matches(request.scope) - if match == Match.FULL: - return route.path, True - return request.url.path, False diff --git a/starlette_apitally/starlette.py b/starlette_apitally/starlette.py new file mode 100644 index 0000000..330ad3b --- /dev/null +++ b/starlette_apitally/starlette.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +import re +import sys +import time +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from uuid import UUID + +import starlette +from httpx import HTTPStatusError +from starlette.authentication import ( + AuthCredentials, + AuthenticationBackend, + AuthenticationError, + BaseUser, +) +from starlette.background import BackgroundTask +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import HTTPConnection +from starlette.routing import BaseRoute, Match, Router +from starlette.schemas import EndpointInfo, SchemaGenerator +from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR +from starlette.testclient import TestClient +from starlette.types import ASGIApp + +import starlette_apitally +from starlette_apitally.client import ApitallyClient +from starlette_apitally.keys import Key + + +if TYPE_CHECKING: + from starlette.middleware.base import RequestResponseEndpoint + from starlette.requests import Request + from starlette.responses import Response + + +__all__ = ["ApitallyMiddleware", "ApitallyKeysBackend"] + + +class ApitallyMiddleware(BaseHTTPMiddleware): + def __init__( + self, + app: ASGIApp, + client_id: str, + env: str = "default", + app_version: Optional[str] = None, + enable_keys: bool = False, + send_every: float = 60, + filter_unhandled_paths: bool = True, + openapi_url: Optional[str] = "/openapi.json", + ) -> None: + try: + UUID(client_id) + except ValueError: + raise ValueError(f"invalid client_id '{client_id}' (expected hexadecimal UUID format)") + if re.match(r"^[\w-]{1,32}$", env) is None: + raise ValueError(f"invalid env '{env}' (expected 1-32 alphanumeric lowercase characters and hyphens only)") + if app_version is not None and len(app_version) > 32: + raise ValueError(f"invalid app_version '{app_version}' (expected 1-32 characters)") + if send_every < 10: + raise ValueError("send_every has to be greater or equal to 10 seconds") + + self.filter_unhandled_paths = filter_unhandled_paths + self.client = ApitallyClient(client_id=client_id, env=env, enable_keys=enable_keys, send_every=send_every) + self.client.send_app_info(app_info=_get_app_info(app, app_version, openapi_url)) + super().__init__(app) + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + try: + start_time = time.perf_counter() + response = await call_next(request) + except BaseException as e: + self.log_request( + request=request, + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + response_time=time.perf_counter() - start_time, + ) + raise e from None + else: + response.background = BackgroundTask( + self.log_request, + request=request, + status_code=response.status_code, + response_time=time.perf_counter() - start_time, + ) + return response + + def log_request(self, request: Request, status_code: int, response_time: float) -> None: + path_template, is_handled_path = self.get_path_template(request) + if is_handled_path or not self.filter_unhandled_paths: + self.client.requests.log_request( + method=request.method, + path=path_template, + status_code=status_code, + response_time=response_time, + ) + + @staticmethod + def get_path_template(request: Request) -> Tuple[str, bool]: + for route in request.app.routes: + match, _ = route.matches(request.scope) + if match == Match.FULL: + return route.path, True + return request.url.path, False + + +class ApitallyKeysBackend(AuthenticationBackend): + async def authenticate(self, conn: HTTPConnection) -> Optional[Tuple[AuthCredentials, BaseUser]]: + if "Authorization" not in conn.headers: + return None + auth = conn.headers["Authorization"] + scheme, _, credentials = auth.partition(" ") + if scheme.lower() != "apikey": + return None + key = ApitallyClient.get_instance().keys.get(credentials) + if key is None: + raise AuthenticationError("Invalid API key") + return AuthCredentials(["authenticated"] + key.scopes), ApitallyKeyUser(key) + + +class ApitallyKeyUser(BaseUser): + def __init__(self, key: Key) -> None: + self.key = key + + @property + def is_authenticated(self) -> bool: + return True + + @property + def display_name(self) -> str: + return self.key.name + + @property + def identity(self) -> str: + return str(self.key.key_id) + + +def _get_app_info(app: ASGIApp, app_version: Optional[str], openapi_url: Optional[str]) -> Dict[str, Any]: + app_info: Dict[str, Any] = {} + if openapi := _get_openapi(app, openapi_url): + app_info["openapi"] = openapi + elif endpoints := _get_endpoint_info(app): + app_info["paths"] = [{"path": endpoint.path, "method": endpoint.http_method} for endpoint in endpoints] + app_info["versions"] = _get_versions(app_version) + app_info["client"] = "starlette-apitally" + return app_info + + +def _get_openapi(app: ASGIApp, openapi_url: Optional[str]) -> Optional[str]: + if not openapi_url: + return None + try: + client = TestClient(app, raise_server_exceptions=False) + response = client.get(openapi_url) + response.raise_for_status() + return response.text + except HTTPStatusError: + return None + + +def _get_endpoint_info(app: ASGIApp) -> List[EndpointInfo]: + routes = _get_routes(app) + schemas = SchemaGenerator({}) + return schemas.get_endpoints(routes) + + +def _get_routes(app: ASGIApp) -> List[BaseRoute]: + if isinstance(app, Router): + return app.routes + elif hasattr(app, "app"): + return _get_routes(app.app) + return [] + + +def _get_versions(app_version: Optional[str]) -> Dict[str, str]: + versions = { + "python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", + "starlette-apitally": starlette_apitally.__version__, + "starlette": starlette.__version__, + } + try: + import fastapi + + versions["fastapi"] = fastapi.__version__ + except (ImportError, AttributeError): # pragma: no cover + pass + if app_version: + versions["app"] = app_version + return versions diff --git a/tests/conftest.py b/tests/conftest.py index 6326f35..fd6a2ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,20 +1,6 @@ -from __future__ import annotations - -import asyncio import os -from asyncio import AbstractEventLoop -from importlib.util import find_spec -from typing import TYPE_CHECKING, Iterator -from unittest.mock import MagicMock import pytest -from pytest import FixtureRequest -from pytest_mock import MockerFixture -from starlette.background import BackgroundTasks # import here to avoid pydantic error - - -if TYPE_CHECKING: - from starlette.applications import Starlette if os.getenv("PYTEST_RAISE", "0") != "0": @@ -26,97 +12,3 @@ def pytest_exception_interact(call): @pytest.hookimpl(tryfirst=True) def pytest_internalerror(excinfo): raise excinfo.value - - -@pytest.fixture(scope="session") -def client_id() -> str: - return "76b5cb91-a0a4-4ea0-a894-57d2b9fcb2c9" - - -@pytest.fixture(scope="module") -def event_loop() -> Iterator[AbstractEventLoop]: - policy = asyncio.get_event_loop_policy() - loop = policy.new_event_loop() - yield loop - loop.close() - - -@pytest.fixture( - scope="module", - params=["starlette", "fastapi"] if find_spec("fastapi") is not None else ["starlette"], -) -async def app(request: FixtureRequest, client_id: str, module_mocker: MockerFixture) -> Starlette: - module_mocker.patch("starlette_apitally.client.ApitallyClient.start_sync_loop") - module_mocker.patch("starlette_apitally.client.ApitallyClient.send_app_info") - if request.param == "starlette": - return get_starlette_app(client_id) - elif request.param == "fastapi": - return get_fastapi_app(client_id) - raise NotImplementedError - - -def get_starlette_app(client_id: str) -> Starlette: - from starlette.applications import Starlette - from starlette.background import BackgroundTask, BackgroundTasks - from starlette.requests import Request - from starlette.responses import PlainTextResponse - from starlette.routing import Route - - from starlette_apitally.middleware import ApitallyMiddleware - - background_task_mock = MagicMock() - - def foo(request: Request): - return PlainTextResponse("foo", background=BackgroundTasks([BackgroundTask(background_task_mock)])) - - def foo_bar(request: Request): - return PlainTextResponse(f"foo: {request.path_params['bar']}", background=BackgroundTask(background_task_mock)) - - def bar(request: Request): - return PlainTextResponse("bar") - - def baz(request: Request): - raise ValueError("baz") - - routes = [ - Route("/foo/", foo), - Route("/foo/{bar}/", foo_bar), - Route("/bar/", bar, methods=["POST"]), - Route("/baz/", baz, methods=["POST"]), - ] - app = Starlette(routes=routes) - app.add_middleware(ApitallyMiddleware, client_id=client_id) - app.state.background_task_mock = background_task_mock - return app - - -def get_fastapi_app(client_id: str) -> Starlette: - from fastapi import FastAPI - - from starlette_apitally.middleware import ApitallyMiddleware - - background_task_mock = MagicMock() - - app = FastAPI(title="Test App", description="A simple test app.", version="1.2.3") - app.add_middleware(ApitallyMiddleware, client_id=client_id) - app.state.background_task_mock = background_task_mock - - @app.get("/foo/") - def foo(background_tasks: BackgroundTasks): - background_tasks.add_task(background_task_mock) - return "foo" - - @app.get("/foo/{bar}/") - def foo_bar(bar: str, background_tasks: BackgroundTasks): - background_tasks.add_task(background_task_mock) - return f"foo: {bar}" - - @app.post("/bar/") - def bar(): - return "bar" - - @app.post("/baz/") - def baz(): - raise ValueError("baz") - - return app diff --git a/tests/test_app_info.py b/tests/test_app_info.py deleted file mode 100644 index 0533e00..0000000 --- a/tests/test_app_info.py +++ /dev/null @@ -1,25 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from pytest_mock import MockerFixture - - -if TYPE_CHECKING: - from starlette.applications import Starlette - - -def test_get_app_info(app: Starlette, mocker: MockerFixture): - from starlette_apitally.app_info import get_app_info - - mocker.patch("starlette_apitally.middleware.ApitallyClient") - if app.middleware_stack is None: - app.middleware_stack = app.build_middleware_stack() - - app_info = get_app_info(app=app.middleware_stack, app_version=None, openapi_url="/openapi.json") - assert ("paths" in app_info) != ("openapi" in app_info) # only one, not both - - app_info = get_app_info(app=app.middleware_stack, app_version="1.2.3", openapi_url=None) - assert len(app_info["paths"]) == 4 - assert len(app_info["versions"]) > 1 - app_info["versions"]["app"] == "1.2.3" diff --git a/tests/test_client.py b/tests/test_client.py index 84f13b4..56a4119 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -3,7 +3,6 @@ import asyncio import json from typing import TYPE_CHECKING, AsyncIterator -from unittest.mock import MagicMock import pytest from pytest_httpx import HTTPXMock @@ -55,14 +54,12 @@ async def test_send_requests_data(client: ApitallyClient, httpx_mock: HTTPXMock) assert request_data["requests"][0]["request_count"] == 2 -async def test_send_app_info(client: ApitallyClient, httpx_mock: HTTPXMock, mocker: MockerFixture): +async def test_send_app_info(client: ApitallyClient, httpx_mock: HTTPXMock): from starlette_apitally.client import HUB_BASE_URL, HUB_VERSION - app_mock = MagicMock() httpx_mock.add_response() app_info = {"paths": [], "client_version": "1.0.0", "starlette_version": "0.28.0", "python_version": "3.11.4"} - mocker.patch("starlette_apitally.client.get_app_info", return_value=app_info) - client.send_app_info(app=app_mock, app_version="1.2.3", openapi_url="/openapi.json") + client.send_app_info(app_info=app_info) await asyncio.sleep(0.01) requests = httpx_mock.get_requests(url=f"{HUB_BASE_URL}/{HUB_VERSION}/{CLIENT_ID}/{ENV}/info") diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py new file mode 100644 index 0000000..ce302b9 --- /dev/null +++ b/tests/test_fastapi.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from pytest_mock import MockerFixture + + +if TYPE_CHECKING: + from fastapi import FastAPI + + +@pytest.fixture() +def app_with_auth() -> FastAPI: + from fastapi import Depends, FastAPI, Security + + from starlette_apitally.fastapi import Key, api_key_auth + + app = FastAPI() + + @app.get("/foo/") + def foo(key: Key = Security(api_key_auth, scopes=["foo"])): + return "foo" + + @app.get("/bar/") + def bar(key: Key = Security(api_key_auth, scopes=["bar"])): + return "bar" + + @app.get("/baz/", dependencies=[Depends(api_key_auth)]) + def baz(): + return "baz" + + return app + + +def test_api_key_auth(app_with_auth: FastAPI, mocker: MockerFixture): + from starlette.testclient import TestClient + + from starlette_apitally.keys import Key, Keys + + client = TestClient(app_with_auth) + keys = Keys() + keys.salt = "54fd2b80dbfeb87d924affbc91b77c76" + keys.keys = { + "bcf46e16814691991c8ed756a7ca3f9cef5644d4f55cd5aaaa5ab4ab4f809208": Key( + key_id=1, + name="Test key", + scopes=["foo"], + ) + } + headers = {"Authorization": "ApiKey 7ll40FB.DuHxzQQuGQU4xgvYvTpmnii7K365j9VI"} + mock = mocker.patch("starlette_apitally.fastapi.ApitallyClient.get_instance") + mock.return_value.keys = keys + + # Unauthenticated + response = client.get("/foo") + assert response.status_code == 401 + + response = client.get("/baz") + assert response.status_code == 401 + + # Invalid API key + response = client.get("/foo", headers={"Authorization": "ApiKey invalid"}) + assert response.status_code == 403 + + # Valid API key with required scope + response = client.get("/foo", headers=headers) + assert response.status_code == 200 + + # Valid API key, no scope required + response = client.get("/baz", headers=headers) + assert response.status_code == 200 + + # Valid API key without required scope + response = client.get("/bar", headers=headers) + assert response.status_code == 403 diff --git a/tests/test_keys.py b/tests/test_keys.py index caf0480..3bf1a5d 100644 --- a/tests/test_keys.py +++ b/tests/test_keys.py @@ -15,10 +15,13 @@ def test_keys(): { "bcf46e16814691991c8ed756a7ca3f9cef5644d4f55cd5aaaa5ab4ab4f809208": { "key_id": 1, + "name": "Test key 1", + "scopes": ["test"], "expires_in_seconds": 60, }, "ba05534cd4af03497416ef9db0a149a1234a4ded7d37a8bc3cde43f3ed56484a": { "key_id": 2, + "name": "Test key 2", "expires_in_seconds": 0, }, } @@ -27,7 +30,9 @@ def test_keys(): # Key with bcf46e16814691991c8ed756a7ca3f9cef5644d4f55cd5aaaa5ab4ab4f809208 is valid key = keys.get("7ll40FB.DuHxzQQuGQU4xgvYvTpmnii7K365j9VI") assert key.key_id == 1 + assert key.name == "Test key 1" assert key.expires_at is not None + assert key.check_scopes(["test"]) # Key with hash ba05534cd4af03497416ef9db0a149a1234a4ded7d37a8bc3cde43f3ed56484a is expired key = keys.get("We6Yr7Z.fzj8t8TuYcTB9uOnpc2P7l4qlysIlT8q") @@ -36,3 +41,6 @@ def test_keys(): # Key does not exist key = keys.get("F9vNgPM.fiXFjMxmSn1TZeuyIm0CxF7gfmfrjKSZ") assert key is None + + used_key_ids = keys.get_and_reset_used_key_ids() + assert used_key_ids == [1] diff --git a/tests/test_middleware.py b/tests/test_middleware.py deleted file mode 100644 index 7508a8e..0000000 --- a/tests/test_middleware.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture - - -if TYPE_CHECKING: - from starlette.applications import Starlette - - -def test_param_validation(app: Starlette, client_id: str): - from starlette_apitally.client import ApitallyClient - from starlette_apitally.middleware import ApitallyMiddleware - - ApitallyClient._instance = None - - with pytest.raises(ValueError): - ApitallyMiddleware(app, client_id="76b5zb91-a0a4-4ea0-a894-57d2b9fcb2c9") - with pytest.raises(ValueError): - ApitallyMiddleware(app, client_id=client_id, env="invalid.string") - with pytest.raises(ValueError): - ApitallyMiddleware(app, client_id=client_id, app_version="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") - with pytest.raises(ValueError): - ApitallyMiddleware(app, client_id=client_id, send_every=1) - - -def test_success(app: Starlette, mocker: MockerFixture): - from starlette.testclient import TestClient - - mock = mocker.patch("starlette_apitally.requests.Requests.log_request") - client = TestClient(app) - background_task_mock: MagicMock = app.state.background_task_mock # type: ignore[attr-defined] - - response = client.get("/foo/") - assert response.status_code == 200 - background_task_mock.assert_called_once() - mock.assert_called_once() - assert mock.call_args is not None - assert mock.call_args.kwargs["method"] == "GET" - assert mock.call_args.kwargs["path"] == "/foo/" - assert mock.call_args.kwargs["status_code"] == 200 - assert mock.call_args.kwargs["response_time"] > 0 - - response = client.get("/foo/123/") - assert response.status_code == 200 - assert background_task_mock.call_count == 2 - assert mock.call_count == 2 - assert mock.call_args is not None - assert mock.call_args.kwargs["path"] == "/foo/{bar}/" - - response = client.post("/bar/") - assert response.status_code == 200 - assert mock.call_count == 3 - assert mock.call_args is not None - assert mock.call_args.kwargs["method"] == "POST" - - -def test_error(app: Starlette, mocker: MockerFixture): - from starlette.testclient import TestClient - - mocker.patch("starlette_apitally.client.ApitallyClient.send_app_info") - mock = mocker.patch("starlette_apitally.requests.Requests.log_request") - client = TestClient(app, raise_server_exceptions=False) - - response = client.post("/baz/") - assert response.status_code == 500 - mock.assert_called_once() - assert mock.call_args is not None - assert mock.call_args.kwargs["method"] == "POST" - assert mock.call_args.kwargs["path"] == "/baz/" - assert mock.call_args.kwargs["status_code"] == 500 - assert mock.call_args.kwargs["response_time"] > 0 - - -def test_unhandled(app: Starlette, mocker: MockerFixture): - from starlette.testclient import TestClient - - mocker.patch("starlette_apitally.client.ApitallyClient.send_app_info") - mock = mocker.patch("starlette_apitally.requests.Requests.log_request") - client = TestClient(app) - - response = client.post("/xxx/") - assert response.status_code == 404 - mock.assert_not_called() diff --git a/tests/test_starlette.py b/tests/test_starlette.py new file mode 100644 index 0000000..1c8333a --- /dev/null +++ b/tests/test_starlette.py @@ -0,0 +1,288 @@ +from __future__ import annotations + +import asyncio +from asyncio import AbstractEventLoop +from importlib.util import find_spec +from typing import TYPE_CHECKING, Iterator +from unittest.mock import MagicMock + +import pytest +from pytest import FixtureRequest +from pytest_mock import MockerFixture +from starlette.background import BackgroundTasks # import here to avoid pydantic error + + +if TYPE_CHECKING: + from starlette.applications import Starlette + + +CLIENT_ID = "76b5cb91-a0a4-4ea0-a894-57d2b9fcb2c9" +ENV = "default" + + +@pytest.fixture(scope="module") +def event_loop() -> Iterator[AbstractEventLoop]: + policy = asyncio.get_event_loop_policy() + loop = policy.new_event_loop() + yield loop + loop.close() + + +@pytest.fixture( + scope="module", + params=["starlette", "fastapi"] if find_spec("fastapi") is not None else ["starlette"], +) +async def app(request: FixtureRequest, module_mocker: MockerFixture) -> Starlette: + module_mocker.patch("starlette_apitally.client.ApitallyClient.start_sync_loop") + module_mocker.patch("starlette_apitally.client.ApitallyClient.send_app_info") + if request.param == "starlette": + return get_starlette_app() + elif request.param == "fastapi": + return get_fastapi_app() + raise NotImplementedError + + +@pytest.fixture() +def app_with_auth() -> Starlette: + from starlette.applications import Starlette + from starlette.authentication import requires + from starlette.middleware.authentication import AuthenticationMiddleware + from starlette.requests import Request + from starlette.responses import JSONResponse, PlainTextResponse + from starlette.routing import Route + + from starlette_apitally.starlette import ApitallyKeysBackend + + @requires(["authenticated", "foo"]) + def foo(request: Request): + assert request.user.is_authenticated + return PlainTextResponse("foo") + + @requires(["authenticated", "bar"]) + def bar(request: Request): + return PlainTextResponse("bar") + + @requires("authenticated") + def baz(request: Request): + return JSONResponse( + { + "key_id": int(request.user.identity), + "key_name": request.user.display_name, + "key_scopes": request.auth.scopes, + } + ) + + routes = [ + Route("/foo/", foo), + Route("/bar/", bar), + Route("/baz/", baz), + ] + app = Starlette(routes=routes) + app.add_middleware(AuthenticationMiddleware, backend=ApitallyKeysBackend()) + return app + + +def get_starlette_app() -> Starlette: + from starlette.applications import Starlette + from starlette.background import BackgroundTask, BackgroundTasks + from starlette.requests import Request + from starlette.responses import PlainTextResponse + from starlette.routing import Route + + from starlette_apitally.starlette import ApitallyMiddleware + + background_task_mock = MagicMock() + + def foo(request: Request): + return PlainTextResponse("foo", background=BackgroundTasks([BackgroundTask(background_task_mock)])) + + def foo_bar(request: Request): + return PlainTextResponse(f"foo: {request.path_params['bar']}", background=BackgroundTask(background_task_mock)) + + def bar(request: Request): + return PlainTextResponse("bar") + + def baz(request: Request): + raise ValueError("baz") + + routes = [ + Route("/foo/", foo), + Route("/foo/{bar}/", foo_bar), + Route("/bar/", bar, methods=["POST"]), + Route("/baz/", baz, methods=["POST"]), + ] + app = Starlette(routes=routes) + app.add_middleware(ApitallyMiddleware, client_id=CLIENT_ID, env=ENV) + app.state.background_task_mock = background_task_mock + return app + + +def get_fastapi_app() -> Starlette: + from fastapi import FastAPI + + from starlette_apitally.fastapi import ApitallyMiddleware + + background_task_mock = MagicMock() + + app = FastAPI(title="Test App", description="A simple test app.", version="1.2.3") + app.add_middleware(ApitallyMiddleware, client_id=CLIENT_ID, env=ENV) + app.state.background_task_mock = background_task_mock + + @app.get("/foo/") + def foo(background_tasks: BackgroundTasks): + background_tasks.add_task(background_task_mock) + return "foo" + + @app.get("/foo/{bar}/") + def foo_bar(bar: str, background_tasks: BackgroundTasks): + background_tasks.add_task(background_task_mock) + return f"foo: {bar}" + + @app.post("/bar/") + def bar(): + return "bar" + + @app.post("/baz/") + def baz(): + raise ValueError("baz") + + return app + + +def test_middleware_param_validation(app: Starlette): + from starlette_apitally.client import ApitallyClient + from starlette_apitally.starlette import ApitallyMiddleware + + ApitallyClient._instance = None + + with pytest.raises(ValueError): + ApitallyMiddleware(app, client_id="76b5zb91-a0a4-4ea0-a894-57d2b9fcb2c9") + with pytest.raises(ValueError): + ApitallyMiddleware(app, client_id=CLIENT_ID, env="invalid.string") + with pytest.raises(ValueError): + ApitallyMiddleware(app, client_id=CLIENT_ID, app_version="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") + with pytest.raises(ValueError): + ApitallyMiddleware(app, client_id=CLIENT_ID, send_every=1) + + +def test_middleware_requests_ok(app: Starlette, mocker: MockerFixture): + from starlette.testclient import TestClient + + mock = mocker.patch("starlette_apitally.requests.Requests.log_request") + client = TestClient(app) + background_task_mock: MagicMock = app.state.background_task_mock # type: ignore[attr-defined] + + response = client.get("/foo/") + assert response.status_code == 200 + background_task_mock.assert_called_once() + mock.assert_called_once() + assert mock.call_args is not None + assert mock.call_args.kwargs["method"] == "GET" + assert mock.call_args.kwargs["path"] == "/foo/" + assert mock.call_args.kwargs["status_code"] == 200 + assert mock.call_args.kwargs["response_time"] > 0 + + response = client.get("/foo/123/") + assert response.status_code == 200 + assert background_task_mock.call_count == 2 + assert mock.call_count == 2 + assert mock.call_args is not None + assert mock.call_args.kwargs["path"] == "/foo/{bar}/" + + response = client.post("/bar/") + assert response.status_code == 200 + assert mock.call_count == 3 + assert mock.call_args is not None + assert mock.call_args.kwargs["method"] == "POST" + + +def test_middleware_requests_error(app: Starlette, mocker: MockerFixture): + from starlette.testclient import TestClient + + mocker.patch("starlette_apitally.client.ApitallyClient.send_app_info") + mock = mocker.patch("starlette_apitally.requests.Requests.log_request") + client = TestClient(app, raise_server_exceptions=False) + + response = client.post("/baz/") + assert response.status_code == 500 + mock.assert_called_once() + assert mock.call_args is not None + assert mock.call_args.kwargs["method"] == "POST" + assert mock.call_args.kwargs["path"] == "/baz/" + assert mock.call_args.kwargs["status_code"] == 500 + assert mock.call_args.kwargs["response_time"] > 0 + + +def test_middleware_requests_unhandled(app: Starlette, mocker: MockerFixture): + from starlette.testclient import TestClient + + mocker.patch("starlette_apitally.client.ApitallyClient.send_app_info") + mock = mocker.patch("starlette_apitally.requests.Requests.log_request") + client = TestClient(app) + + response = client.post("/xxx/") + assert response.status_code == 404 + mock.assert_not_called() + + +def test_keys_auth_backend(app_with_auth: Starlette, mocker: MockerFixture): + from starlette.testclient import TestClient + + from starlette_apitally.keys import Key, Keys + + client = TestClient(app_with_auth) + keys = Keys() + keys.salt = "54fd2b80dbfeb87d924affbc91b77c76" + keys.keys = { + "bcf46e16814691991c8ed756a7ca3f9cef5644d4f55cd5aaaa5ab4ab4f809208": Key( + key_id=1, + name="Test key", + scopes=["foo"], + ) + } + headers = {"Authorization": "ApiKey 7ll40FB.DuHxzQQuGQU4xgvYvTpmnii7K365j9VI"} + mock = mocker.patch("starlette_apitally.fastapi.ApitallyClient.get_instance") + mock.return_value.keys = keys + + # Unauthenticated + response = client.get("/foo") + assert response.status_code == 403 + + response = client.get("/baz") + assert response.status_code == 403 + + # Invalid API key + response = client.get("/foo", headers={"Authorization": "ApiKey invalid"}) + assert response.status_code == 400 + + # Valid API key with required scope + response = client.get("/foo", headers=headers) + assert response.status_code == 200 + + # Valid API key, no scope required + response = client.get("/baz", headers=headers) + assert response.status_code == 200 + response_data = response.json() + assert response_data["key_id"] == 1 + assert response_data["key_name"] == "Test key" + assert response_data["key_scopes"] == ["authenticated", "foo"] + + # Valid API key without required scope + response = client.get("/bar", headers=headers) + assert response.status_code == 403 + + +def test_get_app_info(app: Starlette, mocker: MockerFixture): + from starlette_apitally.starlette import _get_app_info + + mocker.patch("starlette_apitally.starlette.ApitallyClient") + if app.middleware_stack is None: + app.middleware_stack = app.build_middleware_stack() + + app_info = _get_app_info(app=app.middleware_stack, app_version=None, openapi_url="/openapi.json") + assert ("paths" in app_info) != ("openapi" in app_info) # only one, not both + + app_info = _get_app_info(app=app.middleware_stack, app_version="1.2.3", openapi_url=None) + assert len(app_info["paths"]) == 4 + assert len(app_info["versions"]) > 1 + app_info["versions"]["app"] == "1.2.3"