From e40d14003a4fc1ec4d203e4c99b705ebb577221c Mon Sep 17 00:00:00 2001
From: Kelly Walker <kelly.walker@3lc.ai>
Date: Tue, 6 Aug 2024 06:42:34 -0500
Subject: [PATCH] feat(integrations): Update StarliteIntegration to be more in
 line with new LitestarIntegration (#3384)

The new LitestarIntegration was initially ported from the StarliteIntegration, but then had a thorough code review that resulted in use of type comments instead of type hints (the convention used throughout the repo), more concise code in several places, and additional/updated tests. This PR backports those improvements to the StarliteIntegration. See #3358.

---------

Co-authored-by: Anton Pirker <anton.pirker@sentry.io>
---
 sentry_sdk/integrations/starlite.py          | 113 ++++----
 tests/integrations/starlite/test_starlite.py | 264 ++++++++++++-------
 2 files changed, 229 insertions(+), 148 deletions(-)

diff --git a/sentry_sdk/integrations/starlite.py b/sentry_sdk/integrations/starlite.py
index 07259563e0..8e72751e95 100644
--- a/sentry_sdk/integrations/starlite.py
+++ b/sentry_sdk/integrations/starlite.py
@@ -1,6 +1,5 @@
-from typing import TYPE_CHECKING
-
 import sentry_sdk
+from sentry_sdk._types import TYPE_CHECKING
 from sentry_sdk.consts import OP
 from sentry_sdk.integrations import DidNotEnable, Integration
 from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
@@ -20,26 +19,26 @@
     from starlite.routes.http import HTTPRoute  # type: ignore
     from starlite.utils import ConnectionDataExtractor, is_async_callable, Ref  # type: ignore
     from pydantic import BaseModel  # type: ignore
-
-    if TYPE_CHECKING:
-        from typing import Any, Dict, List, Optional, Union
-        from starlite.types import (  # type: ignore
-            ASGIApp,
-            Hint,
-            HTTPReceiveMessage,
-            HTTPScope,
-            Message,
-            Middleware,
-            Receive,
-            Scope as StarliteScope,
-            Send,
-            WebSocketReceiveMessage,
-        )
-        from starlite import MiddlewareProtocol
-        from sentry_sdk._types import Event
 except ImportError:
     raise DidNotEnable("Starlite is not installed")
 
+if TYPE_CHECKING:
+    from typing import Any, Optional, Union
+    from starlite.types import (  # type: ignore
+        ASGIApp,
+        Hint,
+        HTTPReceiveMessage,
+        HTTPScope,
+        Message,
+        Middleware,
+        Receive,
+        Scope as StarliteScope,
+        Send,
+        WebSocketReceiveMessage,
+    )
+    from starlite import MiddlewareProtocol
+    from sentry_sdk._types import Event
+
 
 _DEFAULT_TRANSACTION_NAME = "generic Starlite request"
 
@@ -49,14 +48,16 @@ class StarliteIntegration(Integration):
     origin = f"auto.http.{identifier}"
 
     @staticmethod
-    def setup_once() -> None:
+    def setup_once():
+        # type: () -> None
         patch_app_init()
         patch_middlewares()
         patch_http_route_handle()
 
 
 class SentryStarliteASGIMiddleware(SentryAsgiMiddleware):
-    def __init__(self, app: "ASGIApp", span_origin: str = StarliteIntegration.origin):
+    def __init__(self, app, span_origin=StarliteIntegration.origin):
+        # type: (ASGIApp, str) -> None
         super().__init__(
             app=app,
             unsafe_context_data=False,
@@ -66,7 +67,8 @@ def __init__(self, app: "ASGIApp", span_origin: str = StarliteIntegration.origin
         )
 
 
-def patch_app_init() -> None:
+def patch_app_init():
+    # type: () -> None
     """
     Replaces the Starlite class's `__init__` function in order to inject `after_exception` handlers and set the
     `SentryStarliteASGIMiddleware` as the outmost middleware in the stack.
@@ -76,7 +78,9 @@ def patch_app_init() -> None:
     """
     old__init__ = Starlite.__init__
 
-    def injection_wrapper(self: "Starlite", *args: "Any", **kwargs: "Any") -> None:
+    @ensure_integration_enabled(StarliteIntegration, old__init__)
+    def injection_wrapper(self, *args, **kwargs):
+        # type: (Starlite, *Any, **Any) -> None
         after_exception = kwargs.pop("after_exception", [])
         kwargs.update(
             after_exception=[
@@ -90,26 +94,30 @@ def injection_wrapper(self: "Starlite", *args: "Any", **kwargs: "Any") -> None:
         )
 
         SentryStarliteASGIMiddleware.__call__ = SentryStarliteASGIMiddleware._run_asgi3  # type: ignore
-        middleware = kwargs.pop("middleware", None) or []
+        middleware = kwargs.get("middleware") or []
         kwargs["middleware"] = [SentryStarliteASGIMiddleware, *middleware]
         old__init__(self, *args, **kwargs)
 
     Starlite.__init__ = injection_wrapper
 
 
-def patch_middlewares() -> None:
-    old__resolve_middleware_stack = BaseRouteHandler.resolve_middleware
+def patch_middlewares():
+    # type: () -> None
+    old_resolve_middleware_stack = BaseRouteHandler.resolve_middleware
 
-    def resolve_middleware_wrapper(self: "Any") -> "List[Middleware]":
+    @ensure_integration_enabled(StarliteIntegration, old_resolve_middleware_stack)
+    def resolve_middleware_wrapper(self):
+        # type: (BaseRouteHandler) -> list[Middleware]
         return [
             enable_span_for_middleware(middleware)
-            for middleware in old__resolve_middleware_stack(self)
+            for middleware in old_resolve_middleware_stack(self)
         ]
 
     BaseRouteHandler.resolve_middleware = resolve_middleware_wrapper
 
 
-def enable_span_for_middleware(middleware: "Middleware") -> "Middleware":
+def enable_span_for_middleware(middleware):
+    # type: (Middleware) -> Middleware
     if (
         not hasattr(middleware, "__call__")  # noqa: B004
         or middleware is SentryStarliteASGIMiddleware
@@ -117,16 +125,12 @@ def enable_span_for_middleware(middleware: "Middleware") -> "Middleware":
         return middleware
 
     if isinstance(middleware, DefineMiddleware):
-        old_call: "ASGIApp" = middleware.middleware.__call__
+        old_call = middleware.middleware.__call__  # type: ASGIApp
     else:
         old_call = middleware.__call__
 
-    async def _create_span_call(
-        self: "MiddlewareProtocol",
-        scope: "StarliteScope",
-        receive: "Receive",
-        send: "Send",
-    ) -> None:
+    async def _create_span_call(self, scope, receive, send):
+        # type: (MiddlewareProtocol, StarliteScope, Receive, Send) -> None
         if sentry_sdk.get_client().get_integration(StarliteIntegration) is None:
             return await old_call(self, scope, receive, send)
 
@@ -139,9 +143,10 @@ async def _create_span_call(
             middleware_span.set_tag("starlite.middleware_name", middleware_name)
 
             # Creating spans for the "receive" callback
-            async def _sentry_receive(
-                *args: "Any", **kwargs: "Any"
-            ) -> "Union[HTTPReceiveMessage, WebSocketReceiveMessage]":
+            async def _sentry_receive(*args, **kwargs):
+                # type: (*Any, **Any) -> Union[HTTPReceiveMessage, WebSocketReceiveMessage]
+                if sentry_sdk.get_client().get_integration(StarliteIntegration) is None:
+                    return await receive(*args, **kwargs)
                 with sentry_sdk.start_span(
                     op=OP.MIDDLEWARE_STARLITE_RECEIVE,
                     description=getattr(receive, "__qualname__", str(receive)),
@@ -155,7 +160,10 @@ async def _sentry_receive(
             new_receive = _sentry_receive if not receive_patched else receive
 
             # Creating spans for the "send" callback
-            async def _sentry_send(message: "Message") -> None:
+            async def _sentry_send(message):
+                # type: (Message) -> None
+                if sentry_sdk.get_client().get_integration(StarliteIntegration) is None:
+                    return await send(message)
                 with sentry_sdk.start_span(
                     op=OP.MIDDLEWARE_STARLITE_SEND,
                     description=getattr(send, "__qualname__", str(send)),
@@ -181,19 +189,19 @@ async def _sentry_send(message: "Message") -> None:
     return middleware
 
 
-def patch_http_route_handle() -> None:
+def patch_http_route_handle():
+    # type: () -> None
     old_handle = HTTPRoute.handle
 
-    async def handle_wrapper(
-        self: "HTTPRoute", scope: "HTTPScope", receive: "Receive", send: "Send"
-    ) -> None:
+    async def handle_wrapper(self, scope, receive, send):
+        # type: (HTTPRoute, HTTPScope, Receive, Send) -> None
         if sentry_sdk.get_client().get_integration(StarliteIntegration) is None:
             return await old_handle(self, scope, receive, send)
 
         sentry_scope = sentry_sdk.get_isolation_scope()
-        request: "Request[Any, Any]" = scope["app"].request_class(
+        request = scope["app"].request_class(
             scope=scope, receive=receive, send=send
-        )
+        )  # type: Request[Any, Any]
         extracted_request_data = ConnectionDataExtractor(
             parse_body=True, parse_query=True
         )(request)
@@ -201,7 +209,8 @@ async def handle_wrapper(
 
         request_data = await body
 
-        def event_processor(event: "Event", _: "Hint") -> "Event":
+        def event_processor(event, _):
+            # type: (Event, Hint) -> Event
             route_handler = scope.get("route_handler")
 
             request_info = event.get("request", {})
@@ -244,8 +253,9 @@ def event_processor(event: "Event", _: "Hint") -> "Event":
     HTTPRoute.handle = handle_wrapper
 
 
-def retrieve_user_from_scope(scope: "StarliteScope") -> "Optional[Dict[str, Any]]":
-    scope_user = scope.get("user", {})
+def retrieve_user_from_scope(scope):
+    # type: (StarliteScope) -> Optional[dict[str, Any]]
+    scope_user = scope.get("user")
     if not scope_user:
         return None
     if isinstance(scope_user, dict):
@@ -263,8 +273,9 @@ def retrieve_user_from_scope(scope: "StarliteScope") -> "Optional[Dict[str, Any]
 
 
 @ensure_integration_enabled(StarliteIntegration)
-def exception_handler(exc: Exception, scope: "StarliteScope", _: "State") -> None:
-    user_info: "Optional[Dict[str, Any]]" = None
+def exception_handler(exc, scope, _):
+    # type: (Exception, StarliteScope, State) -> None
+    user_info = None  # type: Optional[dict[str, Any]]
     if should_send_default_pii():
         user_info = retrieve_user_from_scope(scope)
     if user_info and isinstance(user_info, dict):
diff --git a/tests/integrations/starlite/test_starlite.py b/tests/integrations/starlite/test_starlite.py
index 45075b5199..2c3aa704f5 100644
--- a/tests/integrations/starlite/test_starlite.py
+++ b/tests/integrations/starlite/test_starlite.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 import functools
 
 import pytest
@@ -13,50 +14,6 @@
 from starlite.testing import TestClient
 
 
-class SampleMiddleware(AbstractMiddleware):
-    async def __call__(self, scope, receive, send) -> None:
-        async def do_stuff(message):
-            if message["type"] == "http.response.start":
-                # do something here.
-                pass
-            await send(message)
-
-        await self.app(scope, receive, do_stuff)
-
-
-class SampleReceiveSendMiddleware(AbstractMiddleware):
-    async def __call__(self, scope, receive, send):
-        message = await receive()
-        assert message
-        assert message["type"] == "http.request"
-
-        send_output = await send({"type": "something-unimportant"})
-        assert send_output is None
-
-        await self.app(scope, receive, send)
-
-
-class SamplePartialReceiveSendMiddleware(AbstractMiddleware):
-    async def __call__(self, scope, receive, send):
-        message = await receive()
-        assert message
-        assert message["type"] == "http.request"
-
-        send_output = await send({"type": "something-unimportant"})
-        assert send_output is None
-
-        async def my_receive(*args, **kwargs):
-            pass
-
-        async def my_send(*args, **kwargs):
-            pass
-
-        partial_receive = functools.partial(my_receive)
-        partial_send = functools.partial(my_send)
-
-        await self.app(scope, partial_receive, partial_send)
-
-
 def starlite_app_factory(middleware=None, debug=True, exception_handlers=None):
     class MyController(Controller):
         path = "/controller"
@@ -66,7 +23,7 @@ async def controller_error(self) -> None:
             raise Exception("Whoa")
 
     @get("/some_url")
-    async def homepage_handler() -> Dict[str, Any]:
+    async def homepage_handler() -> "Dict[str, Any]":
         1 / 0
         return {"status": "ok"}
 
@@ -75,12 +32,12 @@ async def custom_error() -> Any:
         raise Exception("Too Hot")
 
     @get("/message")
-    async def message() -> Dict[str, Any]:
+    async def message() -> "Dict[str, Any]":
         capture_message("hi")
         return {"status": "ok"}
 
     @get("/message/{message_id:str}")
-    async def message_with_id() -> Dict[str, Any]:
+    async def message_with_id() -> "Dict[str, Any]":
         capture_message("hi")
         return {"status": "ok"}
 
@@ -151,8 +108,8 @@ def test_catch_exceptions(
     assert str(exc) == expected_message
 
     (event,) = events
-    assert event["exception"]["values"][0]["mechanism"]["type"] == "starlite"
     assert event["transaction"] == expected_tx_name
+    assert event["exception"]["values"][0]["mechanism"]["type"] == "starlite"
 
 
 def test_middleware_spans(sentry_init, capture_events):
@@ -177,40 +134,50 @@ def test_middleware_spans(sentry_init, capture_events):
     client = TestClient(
         starlite_app, raise_server_exceptions=False, base_url="http://testserver.local"
     )
-    try:
-        client.get("/message")
-    except Exception:
-        pass
+    client.get("/message")
 
     (_, transaction_event) = events
 
-    expected = ["SessionMiddleware", "LoggingMiddleware", "RateLimitMiddleware"]
+    expected = {"SessionMiddleware", "LoggingMiddleware", "RateLimitMiddleware"}
+    found = set()
+
+    starlite_spans = (
+        span
+        for span in transaction_event["spans"]
+        if span["op"] == "middleware.starlite"
+    )
 
-    idx = 0
-    for span in transaction_event["spans"]:
-        if span["op"] == "middleware.starlite":
-            assert span["description"] == expected[idx]
-            assert span["tags"]["starlite.middleware_name"] == expected[idx]
-            idx += 1
+    for span in starlite_spans:
+        assert span["description"] in expected
+        assert span["description"] not in found
+        found.add(span["description"])
+        assert span["description"] == span["tags"]["starlite.middleware_name"]
 
 
 def test_middleware_callback_spans(sentry_init, capture_events):
+    class SampleMiddleware(AbstractMiddleware):
+        async def __call__(self, scope, receive, send) -> None:
+            async def do_stuff(message):
+                if message["type"] == "http.response.start":
+                    # do something here.
+                    pass
+                await send(message)
+
+            await self.app(scope, receive, do_stuff)
+
     sentry_init(
         traces_sample_rate=1.0,
         integrations=[StarliteIntegration()],
     )
-    starlette_app = starlite_app_factory(middleware=[SampleMiddleware])
+    starlite_app = starlite_app_factory(middleware=[SampleMiddleware])
     events = capture_events()
 
-    client = TestClient(starlette_app, raise_server_exceptions=False)
-    try:
-        client.get("/message")
-    except Exception:
-        pass
+    client = TestClient(starlite_app, raise_server_exceptions=False)
+    client.get("/message")
 
-    (_, transaction_event) = events
+    (_, transaction_events) = events
 
-    expected = [
+    expected_starlite_spans = [
         {
             "op": "middleware.starlite",
             "description": "SampleMiddleware",
@@ -227,47 +194,86 @@ def test_middleware_callback_spans(sentry_init, capture_events):
             "tags": {"starlite.middleware_name": "SampleMiddleware"},
         },
     ]
-    for idx, span in enumerate(transaction_event["spans"]):
-        assert span["op"] == expected[idx]["op"]
-        assert span["description"] == expected[idx]["description"]
-        assert span["tags"] == expected[idx]["tags"]
+
+    def is_matching_span(expected_span, actual_span):
+        return (
+            expected_span["op"] == actual_span["op"]
+            and expected_span["description"] == actual_span["description"]
+            and expected_span["tags"] == actual_span["tags"]
+        )
+
+    actual_starlite_spans = list(
+        span
+        for span in transaction_events["spans"]
+        if "middleware.starlite" in span["op"]
+    )
+    assert len(actual_starlite_spans) == 3
+
+    for expected_span in expected_starlite_spans:
+        assert any(
+            is_matching_span(expected_span, actual_span)
+            for actual_span in actual_starlite_spans
+        )
 
 
 def test_middleware_receive_send(sentry_init, capture_events):
+    class SampleReceiveSendMiddleware(AbstractMiddleware):
+        async def __call__(self, scope, receive, send):
+            message = await receive()
+            assert message
+            assert message["type"] == "http.request"
+
+            send_output = await send({"type": "something-unimportant"})
+            assert send_output is None
+
+            await self.app(scope, receive, send)
+
     sentry_init(
         traces_sample_rate=1.0,
         integrations=[StarliteIntegration()],
     )
-    starlette_app = starlite_app_factory(middleware=[SampleReceiveSendMiddleware])
+    starlite_app = starlite_app_factory(middleware=[SampleReceiveSendMiddleware])
 
-    client = TestClient(starlette_app, raise_server_exceptions=False)
-    try:
-        # NOTE: the assert statements checking
-        # for correct behaviour are in `SampleReceiveSendMiddleware`!
-        client.get("/message")
-    except Exception:
-        pass
+    client = TestClient(starlite_app, raise_server_exceptions=False)
+    # See SampleReceiveSendMiddleware.__call__ above for assertions of correct behavior
+    client.get("/message")
 
 
 def test_middleware_partial_receive_send(sentry_init, capture_events):
+    class SamplePartialReceiveSendMiddleware(AbstractMiddleware):
+        async def __call__(self, scope, receive, send):
+            message = await receive()
+            assert message
+            assert message["type"] == "http.request"
+
+            send_output = await send({"type": "something-unimportant"})
+            assert send_output is None
+
+            async def my_receive(*args, **kwargs):
+                pass
+
+            async def my_send(*args, **kwargs):
+                pass
+
+            partial_receive = functools.partial(my_receive)
+            partial_send = functools.partial(my_send)
+
+            await self.app(scope, partial_receive, partial_send)
+
     sentry_init(
         traces_sample_rate=1.0,
         integrations=[StarliteIntegration()],
     )
-    starlette_app = starlite_app_factory(
-        middleware=[SamplePartialReceiveSendMiddleware]
-    )
+    starlite_app = starlite_app_factory(middleware=[SamplePartialReceiveSendMiddleware])
     events = capture_events()
 
-    client = TestClient(starlette_app, raise_server_exceptions=False)
-    try:
-        client.get("/message")
-    except Exception:
-        pass
+    client = TestClient(starlite_app, raise_server_exceptions=False)
+    # See SamplePartialReceiveSendMiddleware.__call__ above for assertions of correct behavior
+    client.get("/message")
 
-    (_, transaction_event) = events
+    (_, transaction_events) = events
 
-    expected = [
+    expected_starlite_spans = [
         {
             "op": "middleware.starlite",
             "description": "SamplePartialReceiveSendMiddleware",
@@ -285,10 +291,25 @@ def test_middleware_partial_receive_send(sentry_init, capture_events):
         },
     ]
 
-    for idx, span in enumerate(transaction_event["spans"]):
-        assert span["op"] == expected[idx]["op"]
-        assert span["description"].startswith(expected[idx]["description"])
-        assert span["tags"] == expected[idx]["tags"]
+    def is_matching_span(expected_span, actual_span):
+        return (
+            expected_span["op"] == actual_span["op"]
+            and actual_span["description"].startswith(expected_span["description"])
+            and expected_span["tags"] == actual_span["tags"]
+        )
+
+    actual_starlite_spans = list(
+        span
+        for span in transaction_events["spans"]
+        if "middleware.starlite" in span["op"]
+    )
+    assert len(actual_starlite_spans) == 3
+
+    for expected_span in expected_starlite_spans:
+        assert any(
+            is_matching_span(expected_span, actual_span)
+            for actual_span in actual_starlite_spans
+        )
 
 
 def test_span_origin(sentry_init, capture_events):
@@ -313,13 +334,62 @@ def test_span_origin(sentry_init, capture_events):
     client = TestClient(
         starlite_app, raise_server_exceptions=False, base_url="http://testserver.local"
     )
-    try:
-        client.get("/message")
-    except Exception:
-        pass
+    client.get("/message")
 
     (_, event) = events
 
     assert event["contexts"]["trace"]["origin"] == "auto.http.starlite"
     for span in event["spans"]:
         assert span["origin"] == "auto.http.starlite"
+
+
+@pytest.mark.parametrize(
+    "is_send_default_pii",
+    [
+        True,
+        False,
+    ],
+    ids=[
+        "send_default_pii=True",
+        "send_default_pii=False",
+    ],
+)
+def test_starlite_scope_user_on_exception_event(
+    sentry_init, capture_exceptions, capture_events, is_send_default_pii
+):
+    class TestUserMiddleware(AbstractMiddleware):
+        async def __call__(self, scope, receive, send):
+            scope["user"] = {
+                "email": "lennon@thebeatles.com",
+                "username": "john",
+                "id": "1",
+            }
+            await self.app(scope, receive, send)
+
+    sentry_init(
+        integrations=[StarliteIntegration()], send_default_pii=is_send_default_pii
+    )
+    starlite_app = starlite_app_factory(middleware=[TestUserMiddleware])
+    exceptions = capture_exceptions()
+    events = capture_events()
+
+    # This request intentionally raises an exception
+    client = TestClient(starlite_app)
+    try:
+        client.get("/some_url")
+    except Exception:
+        pass
+
+    assert len(exceptions) == 1
+    assert len(events) == 1
+    (event,) = events
+
+    if is_send_default_pii:
+        assert "user" in event
+        assert event["user"] == {
+            "email": "lennon@thebeatles.com",
+            "username": "john",
+            "id": "1",
+        }
+    else:
+        assert "user" not in event