Skip to content

Commit

Permalink
feat(integrations): Update StarliteIntegration to be more in line wit…
Browse files Browse the repository at this point in the history
…h new LitestarIntegration (#3384)

The new LitestarIntegration was initially ported from the StarliteIntegration, but then had a thorough code review that resulted in use of type comments instead of type hints (the convention used throughout the repo), more concise code in several places, and additional/updated tests. This PR backports those improvements to the StarliteIntegration. See #3358.

---------

Co-authored-by: Anton Pirker <anton.pirker@sentry.io>
  • Loading branch information
2 people authored and sentrivana committed Aug 12, 2024
1 parent cbd8956 commit e40d140
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 148 deletions.
113 changes: 62 additions & 51 deletions sentry_sdk/integrations/starlite.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import TYPE_CHECKING

import sentry_sdk
from sentry_sdk._types import TYPE_CHECKING
from sentry_sdk.consts import OP
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
Expand All @@ -20,26 +19,26 @@
from starlite.routes.http import HTTPRoute # type: ignore
from starlite.utils import ConnectionDataExtractor, is_async_callable, Ref # type: ignore
from pydantic import BaseModel # type: ignore

if TYPE_CHECKING:
from typing import Any, Dict, List, Optional, Union
from starlite.types import ( # type: ignore
ASGIApp,
Hint,
HTTPReceiveMessage,
HTTPScope,
Message,
Middleware,
Receive,
Scope as StarliteScope,
Send,
WebSocketReceiveMessage,
)
from starlite import MiddlewareProtocol
from sentry_sdk._types import Event
except ImportError:
raise DidNotEnable("Starlite is not installed")

if TYPE_CHECKING:
from typing import Any, Optional, Union
from starlite.types import ( # type: ignore
ASGIApp,
Hint,
HTTPReceiveMessage,
HTTPScope,
Message,
Middleware,
Receive,
Scope as StarliteScope,
Send,
WebSocketReceiveMessage,
)
from starlite import MiddlewareProtocol
from sentry_sdk._types import Event


_DEFAULT_TRANSACTION_NAME = "generic Starlite request"

Expand All @@ -49,14 +48,16 @@ class StarliteIntegration(Integration):
origin = f"auto.http.{identifier}"

@staticmethod
def setup_once() -> None:
def setup_once():
# type: () -> None
patch_app_init()
patch_middlewares()
patch_http_route_handle()


class SentryStarliteASGIMiddleware(SentryAsgiMiddleware):
def __init__(self, app: "ASGIApp", span_origin: str = StarliteIntegration.origin):
def __init__(self, app, span_origin=StarliteIntegration.origin):
# type: (ASGIApp, str) -> None
super().__init__(
app=app,
unsafe_context_data=False,
Expand All @@ -66,7 +67,8 @@ def __init__(self, app: "ASGIApp", span_origin: str = StarliteIntegration.origin
)


def patch_app_init() -> None:
def patch_app_init():
# type: () -> None
"""
Replaces the Starlite class's `__init__` function in order to inject `after_exception` handlers and set the
`SentryStarliteASGIMiddleware` as the outmost middleware in the stack.
Expand All @@ -76,7 +78,9 @@ def patch_app_init() -> None:
"""
old__init__ = Starlite.__init__

def injection_wrapper(self: "Starlite", *args: "Any", **kwargs: "Any") -> None:
@ensure_integration_enabled(StarliteIntegration, old__init__)
def injection_wrapper(self, *args, **kwargs):
# type: (Starlite, *Any, **Any) -> None
after_exception = kwargs.pop("after_exception", [])
kwargs.update(
after_exception=[
Expand All @@ -90,43 +94,43 @@ def injection_wrapper(self: "Starlite", *args: "Any", **kwargs: "Any") -> None:
)

SentryStarliteASGIMiddleware.__call__ = SentryStarliteASGIMiddleware._run_asgi3 # type: ignore
middleware = kwargs.pop("middleware", None) or []
middleware = kwargs.get("middleware") or []
kwargs["middleware"] = [SentryStarliteASGIMiddleware, *middleware]
old__init__(self, *args, **kwargs)

Starlite.__init__ = injection_wrapper


def patch_middlewares() -> None:
old__resolve_middleware_stack = BaseRouteHandler.resolve_middleware
def patch_middlewares():
# type: () -> None
old_resolve_middleware_stack = BaseRouteHandler.resolve_middleware

def resolve_middleware_wrapper(self: "Any") -> "List[Middleware]":
@ensure_integration_enabled(StarliteIntegration, old_resolve_middleware_stack)
def resolve_middleware_wrapper(self):
# type: (BaseRouteHandler) -> list[Middleware]
return [
enable_span_for_middleware(middleware)
for middleware in old__resolve_middleware_stack(self)
for middleware in old_resolve_middleware_stack(self)
]

BaseRouteHandler.resolve_middleware = resolve_middleware_wrapper


def enable_span_for_middleware(middleware: "Middleware") -> "Middleware":
def enable_span_for_middleware(middleware):
# type: (Middleware) -> Middleware
if (
not hasattr(middleware, "__call__") # noqa: B004
or middleware is SentryStarliteASGIMiddleware
):
return middleware

if isinstance(middleware, DefineMiddleware):
old_call: "ASGIApp" = middleware.middleware.__call__
old_call = middleware.middleware.__call__ # type: ASGIApp
else:
old_call = middleware.__call__

async def _create_span_call(
self: "MiddlewareProtocol",
scope: "StarliteScope",
receive: "Receive",
send: "Send",
) -> None:
async def _create_span_call(self, scope, receive, send):
# type: (MiddlewareProtocol, StarliteScope, Receive, Send) -> None
if sentry_sdk.get_client().get_integration(StarliteIntegration) is None:
return await old_call(self, scope, receive, send)

Expand All @@ -139,9 +143,10 @@ async def _create_span_call(
middleware_span.set_tag("starlite.middleware_name", middleware_name)

# Creating spans for the "receive" callback
async def _sentry_receive(
*args: "Any", **kwargs: "Any"
) -> "Union[HTTPReceiveMessage, WebSocketReceiveMessage]":
async def _sentry_receive(*args, **kwargs):
# type: (*Any, **Any) -> Union[HTTPReceiveMessage, WebSocketReceiveMessage]
if sentry_sdk.get_client().get_integration(StarliteIntegration) is None:
return await receive(*args, **kwargs)
with sentry_sdk.start_span(
op=OP.MIDDLEWARE_STARLITE_RECEIVE,
description=getattr(receive, "__qualname__", str(receive)),
Expand All @@ -155,7 +160,10 @@ async def _sentry_receive(
new_receive = _sentry_receive if not receive_patched else receive

# Creating spans for the "send" callback
async def _sentry_send(message: "Message") -> None:
async def _sentry_send(message):
# type: (Message) -> None
if sentry_sdk.get_client().get_integration(StarliteIntegration) is None:
return await send(message)
with sentry_sdk.start_span(
op=OP.MIDDLEWARE_STARLITE_SEND,
description=getattr(send, "__qualname__", str(send)),
Expand All @@ -181,27 +189,28 @@ async def _sentry_send(message: "Message") -> None:
return middleware


def patch_http_route_handle() -> None:
def patch_http_route_handle():
# type: () -> None
old_handle = HTTPRoute.handle

async def handle_wrapper(
self: "HTTPRoute", scope: "HTTPScope", receive: "Receive", send: "Send"
) -> None:
async def handle_wrapper(self, scope, receive, send):
# type: (HTTPRoute, HTTPScope, Receive, Send) -> None
if sentry_sdk.get_client().get_integration(StarliteIntegration) is None:
return await old_handle(self, scope, receive, send)

sentry_scope = sentry_sdk.get_isolation_scope()
request: "Request[Any, Any]" = scope["app"].request_class(
request = scope["app"].request_class(
scope=scope, receive=receive, send=send
)
) # type: Request[Any, Any]
extracted_request_data = ConnectionDataExtractor(
parse_body=True, parse_query=True
)(request)
body = extracted_request_data.pop("body")

request_data = await body

def event_processor(event: "Event", _: "Hint") -> "Event":
def event_processor(event, _):
# type: (Event, Hint) -> Event
route_handler = scope.get("route_handler")

request_info = event.get("request", {})
Expand Down Expand Up @@ -244,8 +253,9 @@ def event_processor(event: "Event", _: "Hint") -> "Event":
HTTPRoute.handle = handle_wrapper


def retrieve_user_from_scope(scope: "StarliteScope") -> "Optional[Dict[str, Any]]":
scope_user = scope.get("user", {})
def retrieve_user_from_scope(scope):
# type: (StarliteScope) -> Optional[dict[str, Any]]
scope_user = scope.get("user")
if not scope_user:
return None
if isinstance(scope_user, dict):
Expand All @@ -263,8 +273,9 @@ def retrieve_user_from_scope(scope: "StarliteScope") -> "Optional[Dict[str, Any]


@ensure_integration_enabled(StarliteIntegration)
def exception_handler(exc: Exception, scope: "StarliteScope", _: "State") -> None:
user_info: "Optional[Dict[str, Any]]" = None
def exception_handler(exc, scope, _):
# type: (Exception, StarliteScope, State) -> None
user_info = None # type: Optional[dict[str, Any]]
if should_send_default_pii():
user_info = retrieve_user_from_scope(scope)
if user_info and isinstance(user_info, dict):
Expand Down
Loading

0 comments on commit e40d140

Please sign in to comment.