Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(integrations): Update StarliteIntegration to be more in line with new LitestarIntegration #3384

Merged
merged 9 commits into from
Aug 6, 2024
113 changes: 62 additions & 51 deletions sentry_sdk/integrations/starlite.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import TYPE_CHECKING

import sentry_sdk
from sentry_sdk._types import TYPE_CHECKING
from sentry_sdk.consts import OP
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
Expand All @@ -20,26 +19,26 @@
from starlite.routes.http import HTTPRoute # type: ignore
from starlite.utils import ConnectionDataExtractor, is_async_callable, Ref # type: ignore
from pydantic import BaseModel # type: ignore

if TYPE_CHECKING:
from typing import Any, Dict, List, Optional, Union
from starlite.types import ( # type: ignore
ASGIApp,
Hint,
HTTPReceiveMessage,
HTTPScope,
Message,
Middleware,
Receive,
Scope as StarliteScope,
Send,
WebSocketReceiveMessage,
)
from starlite import MiddlewareProtocol
from sentry_sdk._types import Event
except ImportError:
raise DidNotEnable("Starlite is not installed")

if TYPE_CHECKING:
from typing import Any, Optional, Union
from starlite.types import ( # type: ignore
ASGIApp,
Hint,
HTTPReceiveMessage,
HTTPScope,
Message,
Middleware,
Receive,
Scope as StarliteScope,
Send,
WebSocketReceiveMessage,
)
from starlite import MiddlewareProtocol
from sentry_sdk._types import Event


_DEFAULT_TRANSACTION_NAME = "generic Starlite request"

Expand All @@ -49,14 +48,16 @@ class StarliteIntegration(Integration):
origin = f"auto.http.{identifier}"

@staticmethod
def setup_once() -> None:
def setup_once():
# type: () -> None
patch_app_init()
patch_middlewares()
patch_http_route_handle()


class SentryStarliteASGIMiddleware(SentryAsgiMiddleware):
def __init__(self, app: "ASGIApp", span_origin: str = StarliteIntegration.origin):
def __init__(self, app, span_origin=StarliteIntegration.origin):
# type: (ASGIApp, str) -> None
super().__init__(
app=app,
unsafe_context_data=False,
Expand All @@ -66,7 +67,8 @@ def __init__(self, app: "ASGIApp", span_origin: str = StarliteIntegration.origin
)


def patch_app_init() -> None:
def patch_app_init():
# type: () -> None
"""
Replaces the Starlite class's `__init__` function in order to inject `after_exception` handlers and set the
`SentryStarliteASGIMiddleware` as the outmost middleware in the stack.
Expand All @@ -76,7 +78,9 @@ def patch_app_init() -> None:
"""
old__init__ = Starlite.__init__

def injection_wrapper(self: "Starlite", *args: "Any", **kwargs: "Any") -> None:
@ensure_integration_enabled(StarliteIntegration, old__init__)
def injection_wrapper(self, *args, **kwargs):
# type: (Starlite, *Any, **Any) -> None
after_exception = kwargs.pop("after_exception", [])
kwargs.update(
after_exception=[
Expand All @@ -90,43 +94,43 @@ def injection_wrapper(self: "Starlite", *args: "Any", **kwargs: "Any") -> None:
)

SentryStarliteASGIMiddleware.__call__ = SentryStarliteASGIMiddleware._run_asgi3 # type: ignore
middleware = kwargs.pop("middleware", None) or []
middleware = kwargs.get("middleware") or []
kwargs["middleware"] = [SentryStarliteASGIMiddleware, *middleware]
antonpirker marked this conversation as resolved.
Show resolved Hide resolved
old__init__(self, *args, **kwargs)

Starlite.__init__ = injection_wrapper


def patch_middlewares() -> None:
old__resolve_middleware_stack = BaseRouteHandler.resolve_middleware
def patch_middlewares():
# type: () -> None
old_resolve_middleware_stack = BaseRouteHandler.resolve_middleware

def resolve_middleware_wrapper(self: "Any") -> "List[Middleware]":
@ensure_integration_enabled(StarliteIntegration, old_resolve_middleware_stack)
def resolve_middleware_wrapper(self):
# type: (BaseRouteHandler) -> list[Middleware]
return [
enable_span_for_middleware(middleware)
for middleware in old__resolve_middleware_stack(self)
for middleware in old_resolve_middleware_stack(self)
]

BaseRouteHandler.resolve_middleware = resolve_middleware_wrapper


def enable_span_for_middleware(middleware: "Middleware") -> "Middleware":
def enable_span_for_middleware(middleware):
# type: (Middleware) -> Middleware
if (
not hasattr(middleware, "__call__") # noqa: B004
or middleware is SentryStarliteASGIMiddleware
):
return middleware

if isinstance(middleware, DefineMiddleware):
old_call: "ASGIApp" = middleware.middleware.__call__
old_call = middleware.middleware.__call__ # type: ASGIApp
else:
old_call = middleware.__call__

async def _create_span_call(
self: "MiddlewareProtocol",
scope: "StarliteScope",
receive: "Receive",
send: "Send",
) -> None:
async def _create_span_call(self, scope, receive, send):
# type: (MiddlewareProtocol, StarliteScope, Receive, Send) -> None
if sentry_sdk.get_client().get_integration(StarliteIntegration) is None:
return await old_call(self, scope, receive, send)

Expand All @@ -139,9 +143,10 @@ async def _create_span_call(
middleware_span.set_tag("starlite.middleware_name", middleware_name)

# Creating spans for the "receive" callback
async def _sentry_receive(
*args: "Any", **kwargs: "Any"
) -> "Union[HTTPReceiveMessage, WebSocketReceiveMessage]":
async def _sentry_receive(*args, **kwargs):
# type: (*Any, **Any) -> Union[HTTPReceiveMessage, WebSocketReceiveMessage]
if sentry_sdk.get_client().get_integration(StarliteIntegration) is None:
return await receive(*args, **kwargs)
with sentry_sdk.start_span(
op=OP.MIDDLEWARE_STARLITE_RECEIVE,
description=getattr(receive, "__qualname__", str(receive)),
Expand All @@ -155,7 +160,10 @@ async def _sentry_receive(
new_receive = _sentry_receive if not receive_patched else receive

# Creating spans for the "send" callback
async def _sentry_send(message: "Message") -> None:
async def _sentry_send(message):
# type: (Message) -> None
if sentry_sdk.get_client().get_integration(StarliteIntegration) is None:
return await send(message)
with sentry_sdk.start_span(
op=OP.MIDDLEWARE_STARLITE_SEND,
description=getattr(send, "__qualname__", str(send)),
Expand All @@ -181,27 +189,28 @@ async def _sentry_send(message: "Message") -> None:
return middleware


def patch_http_route_handle() -> None:
def patch_http_route_handle():
# type: () -> None
old_handle = HTTPRoute.handle

async def handle_wrapper(
self: "HTTPRoute", scope: "HTTPScope", receive: "Receive", send: "Send"
) -> None:
async def handle_wrapper(self, scope, receive, send):
# type: (HTTPRoute, HTTPScope, Receive, Send) -> None
if sentry_sdk.get_client().get_integration(StarliteIntegration) is None:
return await old_handle(self, scope, receive, send)

sentry_scope = sentry_sdk.get_isolation_scope()
request: "Request[Any, Any]" = scope["app"].request_class(
request = scope["app"].request_class(
scope=scope, receive=receive, send=send
)
) # type: Request[Any, Any]
extracted_request_data = ConnectionDataExtractor(
parse_body=True, parse_query=True
)(request)
body = extracted_request_data.pop("body")

request_data = await body

def event_processor(event: "Event", _: "Hint") -> "Event":
def event_processor(event, _):
# type: (Event, Hint) -> Event
route_handler = scope.get("route_handler")

request_info = event.get("request", {})
Expand Down Expand Up @@ -244,8 +253,9 @@ def event_processor(event: "Event", _: "Hint") -> "Event":
HTTPRoute.handle = handle_wrapper


def retrieve_user_from_scope(scope: "StarliteScope") -> "Optional[Dict[str, Any]]":
scope_user = scope.get("user", {})
def retrieve_user_from_scope(scope):
# type: (StarliteScope) -> Optional[dict[str, Any]]
scope_user = scope.get("user")
if not scope_user:
return None
if isinstance(scope_user, dict):
Expand All @@ -263,8 +273,9 @@ def retrieve_user_from_scope(scope: "StarliteScope") -> "Optional[Dict[str, Any]


@ensure_integration_enabled(StarliteIntegration)
def exception_handler(exc: Exception, scope: "StarliteScope", _: "State") -> None:
user_info: "Optional[Dict[str, Any]]" = None
def exception_handler(exc, scope, _):
# type: (Exception, StarliteScope, State) -> None
user_info = None # type: Optional[dict[str, Any]]
if should_send_default_pii():
user_info = retrieve_user_from_scope(scope)
if user_info and isinstance(user_info, dict):
Expand Down
Loading
Loading