Skip to content

Commit

Permalink
Exclude background task execution from root server span in ASGI middl…
Browse files Browse the repository at this point in the history
…eware (#1952)
  • Loading branch information
siminn-arnorgj authored Nov 8, 2023
1 parent 3b9d626 commit 46fc3ce
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#1824](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1824))
- Fix sqlalchemy instrumentation wrap methods to accept sqlcommenter options
([#1873](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1873))
- Exclude background task execution from root server span in ASGI middleware
([#1952](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1952))

### Added

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ async def __call__(self, scope, receive, send):
if scope["type"] == "http":
self.active_requests_counter.add(1, active_requests_count_attrs)
try:
with trace.use_span(span, end_on_exit=True) as current_span:
with trace.use_span(span, end_on_exit=False) as current_span:
if current_span.is_recording():
for key, value in attributes.items():
current_span.set_attribute(key, value)
Expand Down Expand Up @@ -630,6 +630,8 @@ async def __call__(self, scope, receive, send):
)
if token:
context.detach(token)
if span.is_recording():
span.end()

# pylint: enable=too-many-branches

Expand All @@ -653,8 +655,11 @@ async def otel_receive():
def _get_otel_send(
self, server_span, server_span_name, scope, send, duration_attrs
):
expecting_trailers = False

@wraps(send)
async def otel_send(message):
nonlocal expecting_trailers
with self.tracer.start_as_current_span(
" ".join((server_span_name, scope["type"], "send"))
) as send_span:
Expand All @@ -668,6 +673,8 @@ async def otel_send(message):
] = status_code
set_status_code(server_span, status_code)
set_status_code(send_span, status_code)

expecting_trailers = message.get("trailers", False)
elif message["type"] == "websocket.send":
set_status_code(server_span, 200)
set_status_code(send_span, 200)
Expand Down Expand Up @@ -703,5 +710,15 @@ async def otel_send(message):
pass

await send(message)
if (
not expecting_trailers
and message["type"] == "http.response.body"
and not message.get("more_body", False)
) or (
expecting_trailers
and message["type"] == "http.response.trailers"
and not message.get("more_trailers", False)
):
server_span.end()

return otel_send
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import asyncio
import sys
import time
import unittest
from timeit import default_timer
from unittest import mock
Expand Down Expand Up @@ -57,6 +58,8 @@
"http.server.request.size": _duration_attrs,
}

_SIMULATED_BACKGROUND_TASK_EXECUTION_TIME_S = 0.01


async def http_app(scope, receive, send):
message = await receive()
Expand Down Expand Up @@ -99,6 +102,108 @@ async def simple_asgi(scope, receive, send):
await websocket_app(scope, receive, send)


async def long_response_asgi(scope, receive, send):
assert isinstance(scope, dict)
assert scope["type"] == "http"
message = await receive()
scope["headers"] = [(b"content-length", b"128")]
assert scope["type"] == "http"
if message.get("type") == "http.request":
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
[b"Content-Type", b"text/plain"],
[b"content-length", b"1024"],
],
}
)
await send(
{"type": "http.response.body", "body": b"*", "more_body": True}
)
await send(
{"type": "http.response.body", "body": b"*", "more_body": True}
)
await send(
{"type": "http.response.body", "body": b"*", "more_body": True}
)
await send(
{"type": "http.response.body", "body": b"*", "more_body": False}
)


async def background_execution_asgi(scope, receive, send):
assert isinstance(scope, dict)
assert scope["type"] == "http"
message = await receive()
scope["headers"] = [(b"content-length", b"128")]
assert scope["type"] == "http"
if message.get("type") == "http.request":
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
[b"Content-Type", b"text/plain"],
[b"content-length", b"1024"],
],
}
)
await send(
{
"type": "http.response.body",
"body": b"*",
}
)
time.sleep(_SIMULATED_BACKGROUND_TASK_EXECUTION_TIME_S)


async def background_execution_trailers_asgi(scope, receive, send):
assert isinstance(scope, dict)
assert scope["type"] == "http"
message = await receive()
scope["headers"] = [(b"content-length", b"128")]
assert scope["type"] == "http"
if message.get("type") == "http.request":
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
[b"Content-Type", b"text/plain"],
[b"content-length", b"1024"],
],
"trailers": True,
}
)
await send(
{"type": "http.response.body", "body": b"*", "more_body": True}
)
await send(
{"type": "http.response.body", "body": b"*", "more_body": False}
)
await send(
{
"type": "http.response.trailers",
"headers": [
[b"trailer", b"test-trailer"],
],
"more_trailers": True,
}
)
await send(
{
"type": "http.response.trailers",
"headers": [
[b"trailer", b"second-test-trailer"],
],
"more_trailers": False,
}
)
time.sleep(_SIMULATED_BACKGROUND_TASK_EXECUTION_TIME_S)


async def error_asgi(scope, receive, send):
assert isinstance(scope, dict)
assert scope["type"] == "http"
Expand Down Expand Up @@ -127,14 +232,19 @@ def validate_outputs(self, outputs, error=None, modifiers=None):
# Ensure modifiers is a list
modifiers = modifiers or []
# Check for expected outputs
self.assertEqual(len(outputs), 2)
response_start = outputs[0]
response_body = outputs[1]
response_final_body = [
output
for output in outputs
if output["type"] == "http.response.body"
][-1]

self.assertEqual(response_start["type"], "http.response.start")
self.assertEqual(response_body["type"], "http.response.body")
self.assertEqual(response_final_body["type"], "http.response.body")
self.assertEqual(response_final_body.get("more_body", False), False)

# Check http response body
self.assertEqual(response_body["body"], b"*")
self.assertEqual(response_final_body["body"], b"*")

# Check http response start
self.assertEqual(response_start["status"], 200)
Expand All @@ -153,7 +263,6 @@ def validate_outputs(self, outputs, error=None, modifiers=None):

# Check spans
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 4)
expected = [
{
"name": "GET / http receive",
Expand Down Expand Up @@ -194,6 +303,7 @@ def validate_outputs(self, outputs, error=None, modifiers=None):
for modifier in modifiers:
expected = modifier(expected)
# Check that output matches
self.assertEqual(len(span_list), len(expected))
for span, expected in zip(span_list, expected):
self.assertEqual(span.name, expected["name"])
self.assertEqual(span.kind, expected["kind"])
Expand Down Expand Up @@ -232,6 +342,80 @@ def test_asgi_exc_info(self):
outputs = self.get_all_output()
self.validate_outputs(outputs, error=ValueError)

def test_long_response(self):
"""Test that the server span is ended on the final response body message.
If the server span is ended early then this test will fail due
to discrepancies in the expected list of spans and the emitted list of spans.
"""
app = otel_asgi.OpenTelemetryMiddleware(long_response_asgi)
self.seed_app(app)
self.send_default_request()
outputs = self.get_all_output()

def add_more_body_spans(expected: list):
more_body_span = {
"name": "GET / http send",
"kind": trace_api.SpanKind.INTERNAL,
"attributes": {"type": "http.response.body"},
}
extra_spans = [more_body_span] * 3
expected[2:2] = extra_spans
return expected

self.validate_outputs(outputs, modifiers=[add_more_body_spans])

def test_background_execution(self):
"""Test that the server span is ended BEFORE the background task is finished."""
app = otel_asgi.OpenTelemetryMiddleware(background_execution_asgi)
self.seed_app(app)
self.send_default_request()
outputs = self.get_all_output()
self.validate_outputs(outputs)
span_list = self.memory_exporter.get_finished_spans()
server_span = span_list[-1]
assert server_span.kind == SpanKind.SERVER
span_duration_nanos = server_span.end_time - server_span.start_time
self.assertLessEqual(
span_duration_nanos,
_SIMULATED_BACKGROUND_TASK_EXECUTION_TIME_S * 10**9,
)

def test_trailers(self):
"""Test that trailers are emitted as expected and that the server span is ended
BEFORE the background task is finished."""
app = otel_asgi.OpenTelemetryMiddleware(
background_execution_trailers_asgi
)
self.seed_app(app)
self.send_default_request()
outputs = self.get_all_output()

def add_body_and_trailer_span(expected: list):
body_span = {
"name": "GET / http send",
"kind": trace_api.SpanKind.INTERNAL,
"attributes": {"type": "http.response.body"},
}
trailer_span = {
"name": "GET / http send",
"kind": trace_api.SpanKind.INTERNAL,
"attributes": {"type": "http.response.trailers"},
}
expected[2:2] = [body_span]
expected[4:4] = [trailer_span] * 2
return expected

self.validate_outputs(outputs, modifiers=[add_body_and_trailer_span])
span_list = self.memory_exporter.get_finished_spans()
server_span = span_list[-1]
assert server_span.kind == SpanKind.SERVER
span_duration_nanos = server_span.end_time - server_span.start_time
self.assertLessEqual(
span_duration_nanos,
_SIMULATED_BACKGROUND_TASK_EXECUTION_TIME_S * 10**9,
)

def test_override_span_name(self):
"""Test that default span_names can be overwritten by our callback function."""
span_name = "Dymaxion"
Expand Down

0 comments on commit 46fc3ce

Please sign in to comment.