diff --git a/src/prefect/flow_engine.py b/src/prefect/flow_engine.py index 45a18c35246f..8ac8b1200e56 100644 --- a/src/prefect/flow_engine.py +++ b/src/prefect/flow_engine.py @@ -664,15 +664,10 @@ def initialize_run(self): self._telemetry.start_span( name=self.flow.name, run=self.flow_run, + client=self.client, parameters=self.parameters, parent_labels=parent_labels, ) - carrier = self._telemetry.propagate_traceparent() - if carrier: - self.client.update_flow_run_labels( - flow_run_id=self.flow_run.id, - labels={LABELS_TRACEPARENT_KEY: carrier[TRACEPARENT_KEY]}, - ) try: yield self @@ -1233,18 +1228,13 @@ async def initialize_run(self): if parent_flow_run and parent_flow_run.flow_run: parent_labels = parent_flow_run.flow_run.labels - self._telemetry.start_span( + await self._telemetry.async_start_span( name=self.flow.name, run=self.flow_run, + client=self.client, parameters=self.parameters, parent_labels=parent_labels, ) - carrier = self._telemetry.propagate_traceparent() - if carrier: - await self.client.update_flow_run_labels( - flow_run_id=self.flow_run.id, - labels={LABELS_TRACEPARENT_KEY: carrier[TRACEPARENT_KEY]}, - ) try: yield self diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index c1ce948d2b40..46c8d12a9efd 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -705,6 +705,7 @@ def initialize_run( self._telemetry.start_span( run=self.task_run, name=self.task.name, + client=self.client, parameters=self.parameters, parent_labels=parent_labels, ) @@ -1243,9 +1244,10 @@ async def initialize_run( if parent_flow_run_context and parent_flow_run_context.flow_run: parent_labels = parent_flow_run_context.flow_run.labels - self._telemetry.start_span( + await self._telemetry.async_start_span( run=self.task_run, name=self.task.name, + client=self.client, parameters=self.parameters, parent_labels=parent_labels, ) diff --git a/src/prefect/telemetry/instrumentation.py b/src/prefect/telemetry/instrumentation.py index bb1ddbfcb425..f1f458b785c1 100644 --- a/src/prefect/telemetry/instrumentation.py +++ b/src/prefect/telemetry/instrumentation.py @@ -55,7 +55,7 @@ def _url_join(base_url: str, path: str) -> str: def setup_exporters( api_url: str, api_key: str -) -> tuple[TracerProvider, MeterProvider, "LoggerProvider"]: +) -> "tuple[TracerProvider, MeterProvider, LoggerProvider]": account_id, workspace_id = extract_account_and_workspace_id(api_url) telemetry_url = _url_join(api_url, "telemetry/") diff --git a/src/prefect/telemetry/processors.py b/src/prefect/telemetry/processors.py index f5f1dc663e9c..03a33ab0f2b6 100644 --- a/src/prefect/telemetry/processors.py +++ b/src/prefect/telemetry/processors.py @@ -1,14 +1,17 @@ import time from threading import Event, Lock, Thread -from typing import Dict, Optional +from typing import TYPE_CHECKING, Dict, Optional from opentelemetry.context import Context -from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor -from opentelemetry.sdk.trace.export import SpanExporter +from opentelemetry.sdk.trace import Span, SpanProcessor + +if TYPE_CHECKING: + from opentelemetry.sdk.trace import ReadableSpan, Span + from opentelemetry.sdk.trace.export import SpanExporter class InFlightSpanProcessor(SpanProcessor): - def __init__(self, span_exporter: SpanExporter): + def __init__(self, span_exporter: "SpanExporter"): self.span_exporter = span_exporter self._in_flight: Dict[int, Span] = {} self._lock = Lock() @@ -26,7 +29,7 @@ def _export_periodically(self) -> None: if to_export: self.span_exporter.export(to_export) - def _readable_span(self, span: Span) -> ReadableSpan: + def _readable_span(self, span: "Span") -> "ReadableSpan": readable = span._readable_span() readable._end_time = time.time_ns() readable._attributes = { @@ -35,13 +38,13 @@ def _readable_span(self, span: Span) -> ReadableSpan: } return readable - def on_start(self, span: Span, parent_context: Optional[Context] = None) -> None: + def on_start(self, span: "Span", parent_context: Optional[Context] = None) -> None: if not span.context or not span.context.trace_flags.sampled: return with self._lock: self._in_flight[span.context.span_id] = span - def on_end(self, span: ReadableSpan) -> None: + def on_end(self, span: "ReadableSpan") -> None: if not span.context or not span.context.trace_flags.sampled: return with self._lock: diff --git a/src/prefect/telemetry/run_telemetry.py b/src/prefect/telemetry/run_telemetry.py index bb7cc81de5f9..ab7ef7aee729 100644 --- a/src/prefect/telemetry/run_telemetry.py +++ b/src/prefect/telemetry/run_telemetry.py @@ -1,8 +1,9 @@ import time from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from opentelemetry import propagate, trace +from opentelemetry.context import Context from opentelemetry.propagators.textmap import Setter from opentelemetry.trace import ( Span, @@ -10,8 +11,10 @@ StatusCode, get_tracer, ) +from typing_extensions import TypeAlias import prefect +from prefect.client.orchestration import PrefectClient, SyncPrefectClient from prefect.client.schemas import FlowRun, TaskRun from prefect.client.schemas.objects import State from prefect.context import FlowRunContext @@ -23,6 +26,8 @@ LABELS_TRACEPARENT_KEY = "__OTEL_TRACEPARENT" TRACEPARENT_KEY = "traceparent" +FlowOrTaskRun: TypeAlias = Union[FlowRun, TaskRun] + class OTELSetter(Setter[KeyValueLabels]): """ @@ -44,13 +49,47 @@ class RunTelemetry: ) span: Optional[Span] = None + async def async_start_span( + self, + run: FlowOrTaskRun, + client: PrefectClient, + name: Optional[str] = None, + parameters: Optional[dict[str, Any]] = None, + parent_labels: Optional[dict[str, Any]] = None, + ): + should_set_traceparent = self._should_set_traceparent(run) + traceparent, span = self._start_span(run, name, parameters, parent_labels) + + if should_set_traceparent and traceparent: + await client.update_flow_run_labels( + run.id, {LABELS_TRACEPARENT_KEY: traceparent} + ) + + return span + def start_span( self, - run: Union[TaskRun, FlowRun], + run: FlowOrTaskRun, + client: SyncPrefectClient, name: Optional[str] = None, - parameters: Optional[Dict[str, Any]] = None, - parent_labels: Optional[Dict[str, Any]] = None, + parameters: Optional[dict[str, Any]] = None, + parent_labels: Optional[dict[str, Any]] = None, ): + should_set_traceparent = self._should_set_traceparent(run) + traceparent, span = self._start_span(run, name, parameters, parent_labels) + + if should_set_traceparent and traceparent: + client.update_flow_run_labels(run.id, {LABELS_TRACEPARENT_KEY: traceparent}) + + return span + + def _start_span( + self, + run: FlowOrTaskRun, + name: Optional[str] = None, + parameters: Optional[dict[str, Any]] = None, + parent_labels: Optional[dict[str, Any]] = None, + ) -> tuple[Optional[str], Span]: """ Start a span for a task run. """ @@ -62,10 +101,15 @@ def start_span( f"prefect.run.parameter.{k}": type(v).__name__ for k, v in parameters.items() } - run_type = "task" if isinstance(run, TaskRun) else "flow" + + traceparent, context = self._traceparent_and_context_from_labels( + {**parent_labels, **run.labels} + ) + run_type = self._run_type(run) self.span = self._tracer.start_span( name=name or run.name, + context=context, attributes={ f"prefect.{run_type}.name": name or run.name, "prefect.run.type": run_type, @@ -75,7 +119,41 @@ def start_span( **parent_labels, }, ) - return self.span + + if not traceparent: + traceparent = self._traceparent_from_span(self.span) + + if traceparent and LABELS_TRACEPARENT_KEY not in run.labels: + run.labels[LABELS_TRACEPARENT_KEY] = traceparent + + return traceparent, self.span + + def _run_type(self, run: FlowOrTaskRun) -> str: + return "task" if isinstance(run, TaskRun) else "flow" + + def _should_set_traceparent(self, run: FlowOrTaskRun) -> bool: + # If the run is a flow run and it doesn't already have a traceparent, + # we need to update its labels with the traceparent so that its + # propagated to child runs. Task runs are updated via events so we + # don't need to update them via the client in the same way. + return ( + LABELS_TRACEPARENT_KEY not in run.labels and self._run_type(run) == "flow" + ) + + def _traceparent_and_context_from_labels( + self, labels: Optional[KeyValueLabels] + ) -> tuple[Optional[str], Optional[Context]]: + """Get trace context from run labels if it exists.""" + if not labels or LABELS_TRACEPARENT_KEY not in labels: + return None, None + traceparent = labels[LABELS_TRACEPARENT_KEY] + carrier = {TRACEPARENT_KEY: traceparent} + return str(traceparent), propagate.extract(carrier) + + def _traceparent_from_span(self, span: Span) -> Optional[str]: + carrier = {} + propagate.inject(carrier, context=trace.set_span_in_context(span)) + return carrier.get(TRACEPARENT_KEY) def end_span_on_success(self) -> None: """ diff --git a/tests/telemetry/test_instrumentation.py b/tests/telemetry/test_instrumentation.py index ecb7377be899..b01c18a8ae4e 100644 --- a/tests/telemetry/test_instrumentation.py +++ b/tests/telemetry/test_instrumentation.py @@ -18,16 +18,15 @@ from prefect import flow, task from prefect.client.orchestration import SyncPrefectClient from prefect.context import FlowRunContext -from prefect.task_engine import ( - run_task_async, - run_task_sync, -) +from prefect.flow_engine import run_flow_async, run_flow_sync +from prefect.task_engine import run_task_async, run_task_sync from prefect.telemetry.bootstrap import setup_telemetry from prefect.telemetry.instrumentation import ( extract_account_and_workspace_id, ) from prefect.telemetry.logging import get_log_handler from prefect.telemetry.processors import InFlightSpanProcessor +from prefect.telemetry.run_telemetry import LABELS_TRACEPARENT_KEY def test_extract_account_and_workspace_id_valid_url( @@ -181,6 +180,67 @@ async def engine_type( ) -> Literal["async", "sync"]: return request.param + async def test_traceparent_propagates_from_server_side( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + sync_prefect_client: SyncPrefectClient, + ): + """Test that when no parent traceparent exists, the flow run stores its own span's traceparent""" + + @flow + async def my_async_flow(): + pass + + @flow + def my_sync_flow(): + pass + + if engine_type == "async": + the_flow = my_async_flow + else: + the_flow = my_sync_flow + + flow_run = sync_prefect_client.create_flow_run(the_flow) # type: ignore + + # Give the flow run a traceparent. This can occur when the server has + # already created a trace for the run, likely because it was Late. + # + # Trace ID: 314419354619557650326501540139523824930 + # Span ID: 5357380918965115138 + sync_prefect_client.update_flow_run_labels( + flow_run.id, + { + LABELS_TRACEPARENT_KEY: "00-ec8af70b445d54387035c27eb182dd22-4a593d8fa95f1902-01" + }, + ) + + flow_run = sync_prefect_client.read_flow_run(flow_run.id) + assert flow_run.labels[LABELS_TRACEPARENT_KEY] == ( + "00-ec8af70b445d54387035c27eb182dd22-4a593d8fa95f1902-01" + ) + + if engine_type == "async": + await run_flow_async(the_flow, flow_run=flow_run) # type: ignore + else: + run_flow_sync(the_flow, flow_run=flow_run) # type: ignore + + assert flow_run.labels[LABELS_TRACEPARENT_KEY] == ( + "00-ec8af70b445d54387035c27eb182dd22-4a593d8fa95f1902-01" + ) + + spans = instrumentation.get_finished_spans() + assert len(spans) == 1 + span = spans[0] + + span_context = span.get_span_context() + assert span_context is not None + assert span_context.trace_id == 314419354619557650326501540139523824930 + + assert span.parent is not None + assert span.parent.trace_id == 314419354619557650326501540139523824930 + assert span.parent.span_id == 5357380918965115138 + async def test_flow_run_creates_and_stores_otel_traceparent( self, engine_type: Literal["async", "sync"], @@ -249,20 +309,12 @@ def sync_child_flow() -> str: return "hello from child" @flow(name="parent-flow") - async def async_parent_flow() -> str: - # Set OTEL context in the parent flow's labels - flow_run = FlowRunContext.get().flow_run - mock_traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" - flow_run.labels["__OTEL_TRACEPARENT"] = mock_traceparent - return await async_child_flow() + async def async_parent_flow(): + await async_child_flow() @flow(name="parent-flow") - def sync_parent_flow() -> str: - # Set OTEL context in the parent flow's labels - flow_run = FlowRunContext.get().flow_run - mock_traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" - flow_run.labels["__OTEL_TRACEPARENT"] = mock_traceparent - return sync_child_flow() + def sync_parent_flow(): + sync_child_flow() parent_flow = async_parent_flow if engine_type == "async" else sync_parent_flow await parent_flow() if engine_type == "async" else parent_flow()