From 85c697a2322f67e41c2fbf80b91a7ce0ffef86b4 Mon Sep 17 00:00:00 2001 From: Simon Gurcke Date: Mon, 4 Nov 2024 15:06:50 +1000 Subject: [PATCH] Fix handling of errors in background tasks in Starlette middleware --- apitally/starlette.py | 20 +++++++++----------- tests/test_starlette.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/apitally/starlette.py b/apitally/starlette.py index b7aa8bc..54057fd 100644 --- a/apitally/starlette.py +++ b/apitally/starlette.py @@ -59,11 +59,12 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "http" and scope["method"] != "OPTIONS": request = Request(scope) response_status = 0 - response_time = 0.0 + response_time: Optional[float] = None response_headers = Headers() response_body = b"" response_size = 0 response_chunked = False + exception: Optional[BaseException] = None start_time = time.perf_counter() async def send_wrapper(message: Message) -> None: @@ -92,17 +93,11 @@ async def send_wrapper(message: Message) -> None: try: await self.app(scope, receive, send_wrapper) except BaseException as e: - self.add_request( - request=request, - response_status=500, - response_time=time.perf_counter() - start_time, - response_headers=response_headers, - response_body=response_body, - response_size=response_size, - exception=e, - ) + exception = e raise e from None - else: + finally: + if response_time is None: + response_time = time.perf_counter() - start_time self.add_request( request=request, response_status=response_status, @@ -110,6 +105,7 @@ async def send_wrapper(message: Message) -> None: response_headers=response_headers, response_body=response_body, response_size=response_size, + exception=exception, ) else: await self.app(scope, receive, send) # pragma: no cover @@ -129,6 +125,8 @@ def add_request( consumer = self.get_consumer(request) consumer_identifier = consumer.identifier if consumer else None self.client.consumer_registry.add_or_update_consumer(consumer) + if response_status == 0 and exception is not None: + response_status = 500 self.client.request_counter.add_request( consumer=consumer_identifier, method=request.method, diff --git a/tests/test_starlette.py b/tests/test_starlette.py index 72e8033..34692d3 100644 --- a/tests/test_starlette.py +++ b/tests/test_starlette.py @@ -12,6 +12,9 @@ if find_spec("starlette") is None: pytest.skip("starlette is not available", allow_module_level=True) +else: + # Need to import BackgroundTasks at package level to avoid NameError in FastAPI + from starlette.background import BackgroundTasks if TYPE_CHECKING: from starlette.applications import Starlette @@ -64,6 +67,14 @@ def stream_response(): return StreamingResponse(stream_response()) + def task(request: Request): + def task_func_with_error(): + raise ValueError("task") + + tasks = BackgroundTasks() + tasks.add_task(task_func_with_error) + return PlainTextResponse("ok", background=tasks) + routes = [ Route("/foo/", foo), Route("/foo/{bar}/", foo_bar), @@ -71,6 +82,7 @@ def stream_response(): Route("/baz/", baz, methods=["POST"]), Route("/val/", val), Route("/stream/", stream), + Route("/task/", task, methods=["POST"]), ] app = Starlette(routes=routes) app.add_middleware(ApitallyMiddleware, client_id=CLIENT_ID, env=ENV) @@ -117,6 +129,14 @@ def stream_response(): return StreamingResponse(stream_response()) + @app.post("/task/") + def task(background_tasks: BackgroundTasks): + def task_func_with_error(): + raise ValueError("task") + + background_tasks.add_task(task_func_with_error) + return "ok" + return app @@ -176,6 +196,14 @@ def test_middleware_requests_error(app: Starlette, mocker: MockerFixture): exception = mock2.call_args.kwargs["exception"] assert isinstance(exception, ValueError) + # Throws a ValueError in a background task, but returns 200 + response = client.post("/task/") + assert response.status_code == 200 + assert mock1.call_count == 2 + assert mock1.call_args is not None + assert mock1.call_args.kwargs["status_code"] == 200 + mock2.assert_called_once() # Not called again + def test_middleware_requests_unhandled(app: Starlette, mocker: MockerFixture): from starlette.testclient import TestClient @@ -216,7 +244,7 @@ def test_get_startup_data(app: Starlette, mocker: MockerFixture): app.middleware_stack = app.build_middleware_stack() data = _get_startup_data(app=app.middleware_stack, app_version="1.2.3", openapi_url=None) - assert len(data["paths"]) == 6 + assert len(data["paths"]) == 7 assert data["versions"]["starlette"] assert data["versions"]["app"] == "1.2.3" assert data["client"] == "python:starlette"