Skip to content

Commit

Permalink
fix: get_current_span should return None when llama-index is not inst…
Browse files Browse the repository at this point in the history
…rumented (#1169)
  • Loading branch information
RogerHYang authored Dec 12, 2024
1 parent b46931c commit 12d64bc
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,17 @@ def _uninstrument(self, **kwargs: Any) -> None:

def get_current_span() -> Optional[Span]:
from llama_index.core.instrumentation.span import active_span_id
from openinference.instrumentation.llama_index._handler import _SpanHandler

if not isinstance(id_ := active_span_id.get(), str):
return None
if (span := LlamaIndexInstrumentor()._span_handler.open_spans.get(id_)) is None:
instrumentor = LlamaIndexInstrumentor()
try:
span_handler = instrumentor._span_handler
except AttributeError:
return None
if not isinstance(span_handler, _SpanHandler):
return None
if (span := span_handler.open_spans.get(id_)) is None:
return None
return span._otel_span
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from asyncio import create_task, gather, sleep
from random import random
from typing import Iterator

import pytest
from llama_index.core.instrumentation import get_dispatcher
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
Expand All @@ -14,7 +12,7 @@


@dispatcher.span # type: ignore[misc,unused-ignore]
async def foo(k: int) -> str:
async def foo(k: int = 1) -> str:
child = create_task(foo(k - 1)) if k > 1 else None
await sleep(random() / 100)
span = get_current_span()
Expand All @@ -24,10 +22,14 @@ async def foo(k: int) -> str:


async def test_get_current_span(
tracer_provider: TracerProvider,
in_memory_span_exporter: InMemorySpanExporter,
) -> None:
assert await foo() == ""
n, k = 10, 5
LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider)
await gather(*(foo(k) for _ in range(n)))
LlamaIndexInstrumentor().uninstrument()
spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == n * k
seen = set()
Expand All @@ -38,14 +40,4 @@ async def test_get_current_span(
assert span.attributes.get(OUTPUT_VALUE) == expected


@pytest.fixture(autouse=True)
def instrument(
tracer_provider: TracerProvider,
in_memory_span_exporter: InMemorySpanExporter,
) -> Iterator[None]:
LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider)
yield
LlamaIndexInstrumentor().uninstrument()


OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE

0 comments on commit 12d64bc

Please sign in to comment.