From e40d14003a4fc1ec4d203e4c99b705ebb577221c Mon Sep 17 00:00:00 2001 From: Kelly Walker <kelly.walker@3lc.ai> Date: Tue, 6 Aug 2024 06:42:34 -0500 Subject: [PATCH] feat(integrations): Update StarliteIntegration to be more in line with new LitestarIntegration (#3384) The new LitestarIntegration was initially ported from the StarliteIntegration, but then had a thorough code review that resulted in use of type comments instead of type hints (the convention used throughout the repo), more concise code in several places, and additional/updated tests. This PR backports those improvements to the StarliteIntegration. See #3358. --------- Co-authored-by: Anton Pirker <anton.pirker@sentry.io> --- sentry_sdk/integrations/starlite.py | 113 ++++---- tests/integrations/starlite/test_starlite.py | 264 ++++++++++++------- 2 files changed, 229 insertions(+), 148 deletions(-) diff --git a/sentry_sdk/integrations/starlite.py b/sentry_sdk/integrations/starlite.py index 07259563e0..8e72751e95 100644 --- a/sentry_sdk/integrations/starlite.py +++ b/sentry_sdk/integrations/starlite.py @@ -1,6 +1,5 @@ -from typing import TYPE_CHECKING - import sentry_sdk +from sentry_sdk._types import TYPE_CHECKING from sentry_sdk.consts import OP from sentry_sdk.integrations import DidNotEnable, Integration from sentry_sdk.integrations.asgi import SentryAsgiMiddleware @@ -20,26 +19,26 @@ from starlite.routes.http import HTTPRoute # type: ignore from starlite.utils import ConnectionDataExtractor, is_async_callable, Ref # type: ignore from pydantic import BaseModel # type: ignore - - if TYPE_CHECKING: - from typing import Any, Dict, List, Optional, Union - from starlite.types import ( # type: ignore - ASGIApp, - Hint, - HTTPReceiveMessage, - HTTPScope, - Message, - Middleware, - Receive, - Scope as StarliteScope, - Send, - WebSocketReceiveMessage, - ) - from starlite import MiddlewareProtocol - from sentry_sdk._types import Event except ImportError: raise DidNotEnable("Starlite is not installed") +if TYPE_CHECKING: + from typing import Any, Optional, Union + from starlite.types import ( # type: ignore + ASGIApp, + Hint, + HTTPReceiveMessage, + HTTPScope, + Message, + Middleware, + Receive, + Scope as StarliteScope, + Send, + WebSocketReceiveMessage, + ) + from starlite import MiddlewareProtocol + from sentry_sdk._types import Event + _DEFAULT_TRANSACTION_NAME = "generic Starlite request" @@ -49,14 +48,16 @@ class StarliteIntegration(Integration): origin = f"auto.http.{identifier}" @staticmethod - def setup_once() -> None: + def setup_once(): + # type: () -> None patch_app_init() patch_middlewares() patch_http_route_handle() class SentryStarliteASGIMiddleware(SentryAsgiMiddleware): - def __init__(self, app: "ASGIApp", span_origin: str = StarliteIntegration.origin): + def __init__(self, app, span_origin=StarliteIntegration.origin): + # type: (ASGIApp, str) -> None super().__init__( app=app, unsafe_context_data=False, @@ -66,7 +67,8 @@ def __init__(self, app: "ASGIApp", span_origin: str = StarliteIntegration.origin ) -def patch_app_init() -> None: +def patch_app_init(): + # type: () -> None """ Replaces the Starlite class's `__init__` function in order to inject `after_exception` handlers and set the `SentryStarliteASGIMiddleware` as the outmost middleware in the stack. @@ -76,7 +78,9 @@ def patch_app_init() -> None: """ old__init__ = Starlite.__init__ - def injection_wrapper(self: "Starlite", *args: "Any", **kwargs: "Any") -> None: + @ensure_integration_enabled(StarliteIntegration, old__init__) + def injection_wrapper(self, *args, **kwargs): + # type: (Starlite, *Any, **Any) -> None after_exception = kwargs.pop("after_exception", []) kwargs.update( after_exception=[ @@ -90,26 +94,30 @@ def injection_wrapper(self: "Starlite", *args: "Any", **kwargs: "Any") -> None: ) SentryStarliteASGIMiddleware.__call__ = SentryStarliteASGIMiddleware._run_asgi3 # type: ignore - middleware = kwargs.pop("middleware", None) or [] + middleware = kwargs.get("middleware") or [] kwargs["middleware"] = [SentryStarliteASGIMiddleware, *middleware] old__init__(self, *args, **kwargs) Starlite.__init__ = injection_wrapper -def patch_middlewares() -> None: - old__resolve_middleware_stack = BaseRouteHandler.resolve_middleware +def patch_middlewares(): + # type: () -> None + old_resolve_middleware_stack = BaseRouteHandler.resolve_middleware - def resolve_middleware_wrapper(self: "Any") -> "List[Middleware]": + @ensure_integration_enabled(StarliteIntegration, old_resolve_middleware_stack) + def resolve_middleware_wrapper(self): + # type: (BaseRouteHandler) -> list[Middleware] return [ enable_span_for_middleware(middleware) - for middleware in old__resolve_middleware_stack(self) + for middleware in old_resolve_middleware_stack(self) ] BaseRouteHandler.resolve_middleware = resolve_middleware_wrapper -def enable_span_for_middleware(middleware: "Middleware") -> "Middleware": +def enable_span_for_middleware(middleware): + # type: (Middleware) -> Middleware if ( not hasattr(middleware, "__call__") # noqa: B004 or middleware is SentryStarliteASGIMiddleware @@ -117,16 +125,12 @@ def enable_span_for_middleware(middleware: "Middleware") -> "Middleware": return middleware if isinstance(middleware, DefineMiddleware): - old_call: "ASGIApp" = middleware.middleware.__call__ + old_call = middleware.middleware.__call__ # type: ASGIApp else: old_call = middleware.__call__ - async def _create_span_call( - self: "MiddlewareProtocol", - scope: "StarliteScope", - receive: "Receive", - send: "Send", - ) -> None: + async def _create_span_call(self, scope, receive, send): + # type: (MiddlewareProtocol, StarliteScope, Receive, Send) -> None if sentry_sdk.get_client().get_integration(StarliteIntegration) is None: return await old_call(self, scope, receive, send) @@ -139,9 +143,10 @@ async def _create_span_call( middleware_span.set_tag("starlite.middleware_name", middleware_name) # Creating spans for the "receive" callback - async def _sentry_receive( - *args: "Any", **kwargs: "Any" - ) -> "Union[HTTPReceiveMessage, WebSocketReceiveMessage]": + async def _sentry_receive(*args, **kwargs): + # type: (*Any, **Any) -> Union[HTTPReceiveMessage, WebSocketReceiveMessage] + if sentry_sdk.get_client().get_integration(StarliteIntegration) is None: + return await receive(*args, **kwargs) with sentry_sdk.start_span( op=OP.MIDDLEWARE_STARLITE_RECEIVE, description=getattr(receive, "__qualname__", str(receive)), @@ -155,7 +160,10 @@ async def _sentry_receive( new_receive = _sentry_receive if not receive_patched else receive # Creating spans for the "send" callback - async def _sentry_send(message: "Message") -> None: + async def _sentry_send(message): + # type: (Message) -> None + if sentry_sdk.get_client().get_integration(StarliteIntegration) is None: + return await send(message) with sentry_sdk.start_span( op=OP.MIDDLEWARE_STARLITE_SEND, description=getattr(send, "__qualname__", str(send)), @@ -181,19 +189,19 @@ async def _sentry_send(message: "Message") -> None: return middleware -def patch_http_route_handle() -> None: +def patch_http_route_handle(): + # type: () -> None old_handle = HTTPRoute.handle - async def handle_wrapper( - self: "HTTPRoute", scope: "HTTPScope", receive: "Receive", send: "Send" - ) -> None: + async def handle_wrapper(self, scope, receive, send): + # type: (HTTPRoute, HTTPScope, Receive, Send) -> None if sentry_sdk.get_client().get_integration(StarliteIntegration) is None: return await old_handle(self, scope, receive, send) sentry_scope = sentry_sdk.get_isolation_scope() - request: "Request[Any, Any]" = scope["app"].request_class( + request = scope["app"].request_class( scope=scope, receive=receive, send=send - ) + ) # type: Request[Any, Any] extracted_request_data = ConnectionDataExtractor( parse_body=True, parse_query=True )(request) @@ -201,7 +209,8 @@ async def handle_wrapper( request_data = await body - def event_processor(event: "Event", _: "Hint") -> "Event": + def event_processor(event, _): + # type: (Event, Hint) -> Event route_handler = scope.get("route_handler") request_info = event.get("request", {}) @@ -244,8 +253,9 @@ def event_processor(event: "Event", _: "Hint") -> "Event": HTTPRoute.handle = handle_wrapper -def retrieve_user_from_scope(scope: "StarliteScope") -> "Optional[Dict[str, Any]]": - scope_user = scope.get("user", {}) +def retrieve_user_from_scope(scope): + # type: (StarliteScope) -> Optional[dict[str, Any]] + scope_user = scope.get("user") if not scope_user: return None if isinstance(scope_user, dict): @@ -263,8 +273,9 @@ def retrieve_user_from_scope(scope: "StarliteScope") -> "Optional[Dict[str, Any] @ensure_integration_enabled(StarliteIntegration) -def exception_handler(exc: Exception, scope: "StarliteScope", _: "State") -> None: - user_info: "Optional[Dict[str, Any]]" = None +def exception_handler(exc, scope, _): + # type: (Exception, StarliteScope, State) -> None + user_info = None # type: Optional[dict[str, Any]] if should_send_default_pii(): user_info = retrieve_user_from_scope(scope) if user_info and isinstance(user_info, dict): diff --git a/tests/integrations/starlite/test_starlite.py b/tests/integrations/starlite/test_starlite.py index 45075b5199..2c3aa704f5 100644 --- a/tests/integrations/starlite/test_starlite.py +++ b/tests/integrations/starlite/test_starlite.py @@ -1,3 +1,4 @@ +from __future__ import annotations import functools import pytest @@ -13,50 +14,6 @@ from starlite.testing import TestClient -class SampleMiddleware(AbstractMiddleware): - async def __call__(self, scope, receive, send) -> None: - async def do_stuff(message): - if message["type"] == "http.response.start": - # do something here. - pass - await send(message) - - await self.app(scope, receive, do_stuff) - - -class SampleReceiveSendMiddleware(AbstractMiddleware): - async def __call__(self, scope, receive, send): - message = await receive() - assert message - assert message["type"] == "http.request" - - send_output = await send({"type": "something-unimportant"}) - assert send_output is None - - await self.app(scope, receive, send) - - -class SamplePartialReceiveSendMiddleware(AbstractMiddleware): - async def __call__(self, scope, receive, send): - message = await receive() - assert message - assert message["type"] == "http.request" - - send_output = await send({"type": "something-unimportant"}) - assert send_output is None - - async def my_receive(*args, **kwargs): - pass - - async def my_send(*args, **kwargs): - pass - - partial_receive = functools.partial(my_receive) - partial_send = functools.partial(my_send) - - await self.app(scope, partial_receive, partial_send) - - def starlite_app_factory(middleware=None, debug=True, exception_handlers=None): class MyController(Controller): path = "/controller" @@ -66,7 +23,7 @@ async def controller_error(self) -> None: raise Exception("Whoa") @get("/some_url") - async def homepage_handler() -> Dict[str, Any]: + async def homepage_handler() -> "Dict[str, Any]": 1 / 0 return {"status": "ok"} @@ -75,12 +32,12 @@ async def custom_error() -> Any: raise Exception("Too Hot") @get("/message") - async def message() -> Dict[str, Any]: + async def message() -> "Dict[str, Any]": capture_message("hi") return {"status": "ok"} @get("/message/{message_id:str}") - async def message_with_id() -> Dict[str, Any]: + async def message_with_id() -> "Dict[str, Any]": capture_message("hi") return {"status": "ok"} @@ -151,8 +108,8 @@ def test_catch_exceptions( assert str(exc) == expected_message (event,) = events - assert event["exception"]["values"][0]["mechanism"]["type"] == "starlite" assert event["transaction"] == expected_tx_name + assert event["exception"]["values"][0]["mechanism"]["type"] == "starlite" def test_middleware_spans(sentry_init, capture_events): @@ -177,40 +134,50 @@ def test_middleware_spans(sentry_init, capture_events): client = TestClient( starlite_app, raise_server_exceptions=False, base_url="http://testserver.local" ) - try: - client.get("/message") - except Exception: - pass + client.get("/message") (_, transaction_event) = events - expected = ["SessionMiddleware", "LoggingMiddleware", "RateLimitMiddleware"] + expected = {"SessionMiddleware", "LoggingMiddleware", "RateLimitMiddleware"} + found = set() + + starlite_spans = ( + span + for span in transaction_event["spans"] + if span["op"] == "middleware.starlite" + ) - idx = 0 - for span in transaction_event["spans"]: - if span["op"] == "middleware.starlite": - assert span["description"] == expected[idx] - assert span["tags"]["starlite.middleware_name"] == expected[idx] - idx += 1 + for span in starlite_spans: + assert span["description"] in expected + assert span["description"] not in found + found.add(span["description"]) + assert span["description"] == span["tags"]["starlite.middleware_name"] def test_middleware_callback_spans(sentry_init, capture_events): + class SampleMiddleware(AbstractMiddleware): + async def __call__(self, scope, receive, send) -> None: + async def do_stuff(message): + if message["type"] == "http.response.start": + # do something here. + pass + await send(message) + + await self.app(scope, receive, do_stuff) + sentry_init( traces_sample_rate=1.0, integrations=[StarliteIntegration()], ) - starlette_app = starlite_app_factory(middleware=[SampleMiddleware]) + starlite_app = starlite_app_factory(middleware=[SampleMiddleware]) events = capture_events() - client = TestClient(starlette_app, raise_server_exceptions=False) - try: - client.get("/message") - except Exception: - pass + client = TestClient(starlite_app, raise_server_exceptions=False) + client.get("/message") - (_, transaction_event) = events + (_, transaction_events) = events - expected = [ + expected_starlite_spans = [ { "op": "middleware.starlite", "description": "SampleMiddleware", @@ -227,47 +194,86 @@ def test_middleware_callback_spans(sentry_init, capture_events): "tags": {"starlite.middleware_name": "SampleMiddleware"}, }, ] - for idx, span in enumerate(transaction_event["spans"]): - assert span["op"] == expected[idx]["op"] - assert span["description"] == expected[idx]["description"] - assert span["tags"] == expected[idx]["tags"] + + def is_matching_span(expected_span, actual_span): + return ( + expected_span["op"] == actual_span["op"] + and expected_span["description"] == actual_span["description"] + and expected_span["tags"] == actual_span["tags"] + ) + + actual_starlite_spans = list( + span + for span in transaction_events["spans"] + if "middleware.starlite" in span["op"] + ) + assert len(actual_starlite_spans) == 3 + + for expected_span in expected_starlite_spans: + assert any( + is_matching_span(expected_span, actual_span) + for actual_span in actual_starlite_spans + ) def test_middleware_receive_send(sentry_init, capture_events): + class SampleReceiveSendMiddleware(AbstractMiddleware): + async def __call__(self, scope, receive, send): + message = await receive() + assert message + assert message["type"] == "http.request" + + send_output = await send({"type": "something-unimportant"}) + assert send_output is None + + await self.app(scope, receive, send) + sentry_init( traces_sample_rate=1.0, integrations=[StarliteIntegration()], ) - starlette_app = starlite_app_factory(middleware=[SampleReceiveSendMiddleware]) + starlite_app = starlite_app_factory(middleware=[SampleReceiveSendMiddleware]) - client = TestClient(starlette_app, raise_server_exceptions=False) - try: - # NOTE: the assert statements checking - # for correct behaviour are in `SampleReceiveSendMiddleware`! - client.get("/message") - except Exception: - pass + client = TestClient(starlite_app, raise_server_exceptions=False) + # See SampleReceiveSendMiddleware.__call__ above for assertions of correct behavior + client.get("/message") def test_middleware_partial_receive_send(sentry_init, capture_events): + class SamplePartialReceiveSendMiddleware(AbstractMiddleware): + async def __call__(self, scope, receive, send): + message = await receive() + assert message + assert message["type"] == "http.request" + + send_output = await send({"type": "something-unimportant"}) + assert send_output is None + + async def my_receive(*args, **kwargs): + pass + + async def my_send(*args, **kwargs): + pass + + partial_receive = functools.partial(my_receive) + partial_send = functools.partial(my_send) + + await self.app(scope, partial_receive, partial_send) + sentry_init( traces_sample_rate=1.0, integrations=[StarliteIntegration()], ) - starlette_app = starlite_app_factory( - middleware=[SamplePartialReceiveSendMiddleware] - ) + starlite_app = starlite_app_factory(middleware=[SamplePartialReceiveSendMiddleware]) events = capture_events() - client = TestClient(starlette_app, raise_server_exceptions=False) - try: - client.get("/message") - except Exception: - pass + client = TestClient(starlite_app, raise_server_exceptions=False) + # See SamplePartialReceiveSendMiddleware.__call__ above for assertions of correct behavior + client.get("/message") - (_, transaction_event) = events + (_, transaction_events) = events - expected = [ + expected_starlite_spans = [ { "op": "middleware.starlite", "description": "SamplePartialReceiveSendMiddleware", @@ -285,10 +291,25 @@ def test_middleware_partial_receive_send(sentry_init, capture_events): }, ] - for idx, span in enumerate(transaction_event["spans"]): - assert span["op"] == expected[idx]["op"] - assert span["description"].startswith(expected[idx]["description"]) - assert span["tags"] == expected[idx]["tags"] + def is_matching_span(expected_span, actual_span): + return ( + expected_span["op"] == actual_span["op"] + and actual_span["description"].startswith(expected_span["description"]) + and expected_span["tags"] == actual_span["tags"] + ) + + actual_starlite_spans = list( + span + for span in transaction_events["spans"] + if "middleware.starlite" in span["op"] + ) + assert len(actual_starlite_spans) == 3 + + for expected_span in expected_starlite_spans: + assert any( + is_matching_span(expected_span, actual_span) + for actual_span in actual_starlite_spans + ) def test_span_origin(sentry_init, capture_events): @@ -313,13 +334,62 @@ def test_span_origin(sentry_init, capture_events): client = TestClient( starlite_app, raise_server_exceptions=False, base_url="http://testserver.local" ) - try: - client.get("/message") - except Exception: - pass + client.get("/message") (_, event) = events assert event["contexts"]["trace"]["origin"] == "auto.http.starlite" for span in event["spans"]: assert span["origin"] == "auto.http.starlite" + + +@pytest.mark.parametrize( + "is_send_default_pii", + [ + True, + False, + ], + ids=[ + "send_default_pii=True", + "send_default_pii=False", + ], +) +def test_starlite_scope_user_on_exception_event( + sentry_init, capture_exceptions, capture_events, is_send_default_pii +): + class TestUserMiddleware(AbstractMiddleware): + async def __call__(self, scope, receive, send): + scope["user"] = { + "email": "lennon@thebeatles.com", + "username": "john", + "id": "1", + } + await self.app(scope, receive, send) + + sentry_init( + integrations=[StarliteIntegration()], send_default_pii=is_send_default_pii + ) + starlite_app = starlite_app_factory(middleware=[TestUserMiddleware]) + exceptions = capture_exceptions() + events = capture_events() + + # This request intentionally raises an exception + client = TestClient(starlite_app) + try: + client.get("/some_url") + except Exception: + pass + + assert len(exceptions) == 1 + assert len(events) == 1 + (event,) = events + + if is_send_default_pii: + assert "user" in event + assert event["user"] == { + "email": "lennon@thebeatles.com", + "username": "john", + "id": "1", + } + else: + assert "user" not in event