Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add get_chain_root_span utility for langchain instrumentation #1054

Merged
merged 31 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
c200b4e
Use ContextVars to track langchain root spans
anticorrelator Oct 7, 2024
a6d6702
Do not explicitly manage a stack
anticorrelator Oct 7, 2024
1f43388
Ruff 🐶
anticorrelator Oct 7, 2024
9734103
Use separate reset token store
anticorrelator Oct 7, 2024
298318a
Use better contextvar type annotation
anticorrelator Oct 7, 2024
29263e1
Flesh out type annotation for Token object
anticorrelator Oct 7, 2024
1af67a4
Add langchain instrumentor tests
anticorrelator Oct 7, 2024
89a0d45
Use better type annotation
anticorrelator Oct 7, 2024
a0b14b4
Refactor tests
anticorrelator Oct 8, 2024
3ddc798
Refactor root chain span propagation
anticorrelator Oct 8, 2024
50ffb5a
Ruff 🐶
anticorrelator Oct 8, 2024
45d1a0e
Update tests
anticorrelator Oct 9, 2024
39bd069
Fix type annotations
anticorrelator Oct 9, 2024
141ee9f
Remove unused type: ignore
anticorrelator Oct 9, 2024
6c02b33
Simplify test to not use sequences
anticorrelator Oct 9, 2024
965ce0f
Explicitly define type annotations for RunnableLambda
anticorrelator Oct 9, 2024
e1609ac
Use a RunnableSequence for more robust testing
anticorrelator Oct 10, 2024
aea1ae7
Remove root span tracking from Span attributes
anticorrelator Oct 10, 2024
71a77b7
Track span tree manually
anticorrelator Oct 10, 2024
dd41d7c
Remove references to extra span attribute
anticorrelator Oct 10, 2024
73a40b8
Remove redundant logic
anticorrelator Oct 15, 2024
5880442
Properly test concurrency
anticorrelator Oct 15, 2024
ca6d3ba
Remove unused variable
anticorrelator Oct 15, 2024
30ddc8f
Test root chain tree walking
anticorrelator Oct 16, 2024
5bcc238
Get all ancestors
anticorrelator Oct 16, 2024
611ed88
Only use run map
anticorrelator Oct 16, 2024
7a35bec
Remove old bookkeeping
anticorrelator Oct 16, 2024
6b82c94
Fix type annotations
anticorrelator Oct 17, 2024
a3e45e3
Add docstring
anticorrelator Oct 17, 2024
26dbe41
Ignore unused ignores
anticorrelator Oct 17, 2024
4a9e7b2
Update return types for consistency
anticorrelator Oct 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from opentelemetry import trace as trace_api
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor # type: ignore
from opentelemetry.trace import Span
from opentelemetry.sdk.trace import Span
from wrapt import wrap_function_wrapper # type: ignore

from openinference.instrumentation import OITracer, TraceConfig
Expand Down Expand Up @@ -104,3 +104,18 @@ def get_current_span() -> Optional[Span]:
if not run_id:
return None
return LangChainInstrumentor().get_span(run_id)


def get_current_chain_root_span() -> Optional[Span]:
from openinference.instrumentation.langchain._tracer import IS_CHAIN_SPAN, _spans_by_span_id

span = get_current_span()
while span and span.get_span_context().is_valid: # type: ignore[no-untyped-call]
if span.attributes and span.attributes.get(IS_CHAIN_SPAN):
return span
# Get parent span ID
parent_span = span.parent
if not parent_span or not parent_span.span_id:
break
span = _spans_by_span_id.get(parent_span.span_id)
return None
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from opentelemetry import context as context_api
from opentelemetry import trace as trace_api
from opentelemetry.context import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.sdk.trace import Span
anticorrelator marked this conversation as resolved.
Show resolved Hide resolved
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
from opentelemetry.util.types import AttributeValue
from wrapt import ObjectProxy
Expand All @@ -56,6 +57,7 @@
logger.addHandler(logging.NullHandler())

_AUDIT_TIMING = False
IS_CHAIN_SPAN = "is_chain_span"
anticorrelator marked this conversation as resolved.
Show resolved Hide resolved


@wrapt.decorator # type: ignore
Expand Down Expand Up @@ -104,6 +106,9 @@ def __delitem__(self, key: K) -> None:
super().__delitem__(key)


_spans_by_span_id: Dict[int, Span] = _DictWithLock[int, Span]()
anticorrelator marked this conversation as resolved.
Show resolved Hide resolved


class OpenInferenceTracer(BaseTracer):
__slots__ = ("_tracer", "_spans_by_run")

Expand All @@ -114,10 +119,10 @@ def __init__(self, tracer: trace_api.Tracer, *args: Any, **kwargs: Any) -> None:
assert self.run_map
self.run_map = _DictWithLock[str, Run](self.run_map)
self._tracer = tracer
self._spans_by_run: Dict[UUID, trace_api.Span] = _DictWithLock[UUID, trace_api.Span]()
self._spans_by_run: Dict[UUID, Span] = _DictWithLock[UUID, Span]()
self._lock = RLock() # handlers may be run in a thread by langchain

def get_span(self, run_id: UUID) -> Optional[trace_api.Span]:
def get_span(self, run_id: UUID) -> Optional[Span]:
return self._spans_by_run.get(run_id)

@audit_timing # type: ignore
Expand All @@ -140,6 +145,10 @@ def _start_trace(self, run: Run) -> None:
context=parent_context,
start_time=start_time_utc_nano,
)

if run.run_type.lower() == "chain" and span:
span.set_attribute(IS_CHAIN_SPAN, True)
anticorrelator marked this conversation as resolved.
Show resolved Hide resolved

# The following line of code is commented out to serve as a reminder that in a system
# of callbacks, attaching the context can be hazardous because there is no guarantee
# that the context will be detached. An error could happen between callbacks leaving
Expand All @@ -148,7 +157,9 @@ def _start_trace(self, run: Run) -> None:
# leaving all future spans as orphans. That is a very bad scenario.
# token = context_api.attach(context)
with self._lock:
span = cast(Span, span)
self._spans_by_run[run.id] = span
_spans_by_span_id[span.get_span_context().span_id] = span # type: ignore[no-untyped-call]

@audit_timing # type: ignore
def _end_trace(self, run: Run) -> None:
Expand All @@ -157,6 +168,7 @@ def _end_trace(self, run: Run) -> None:
return
span = self._spans_by_run.pop(run.id, None)
if span:
_spans_by_span_id.pop(span.get_span_context().span_id, None) # type: ignore[no-untyped-call]
try:
_update_span(span, run)
except Exception:
Expand Down Expand Up @@ -207,7 +219,7 @@ def on_chat_model_start(self, *args: Any, **kwargs: Any) -> Run:


@audit_timing # type: ignore
def _record_exception(span: trace_api.Span, error: BaseException) -> None:
def _record_exception(span: Span, error: BaseException) -> None:
if isinstance(error, Exception):
span.record_exception(error)
return
Expand All @@ -229,7 +241,7 @@ def _record_exception(span: trace_api.Span, error: BaseException) -> None:


@audit_timing # type: ignore
def _update_span(span: trace_api.Span, run: Run) -> None:
def _update_span(span: Span, run: Run) -> None:
if run.error is None:
span.set_status(trace_api.StatusCode.OK)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@
from respx import MockRouter

from openinference.instrumentation import using_attributes
from openinference.instrumentation.langchain import get_current_span
from openinference.instrumentation.langchain import (
get_current_chain_root_span,
get_current_span,
)
from openinference.instrumentation.langchain._tracer import IS_CHAIN_SPAN
from openinference.semconv.trace import (
DocumentAttributes,
EmbeddingAttributes,
Expand Down Expand Up @@ -86,12 +90,68 @@ async def f(_: Any) -> Optional[Span]:
)
spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == n
assert {id(span.get_span_context()) for span in results if isinstance(span, Span)} == {
assert {id(span.get_span_context()) for span in results if isinstance(span, Span)} == { # type: ignore[no-untyped-call]
id(span.get_span_context()) # type: ignore[no-untyped-call]
for span in spans
}


async def test_get_current_chain_root_span(
in_memory_span_exporter: InMemorySpanExporter,
) -> None:
"""Test retrieving the current chain root span during RunnableLambda execution."""
n = 10 # Number of concurrent runs
loop = asyncio.get_running_loop()

root_spans_during_execution = []

def f(x: int) -> int:
root_span = get_current_chain_root_span()
assert root_span is not None, "Root span should not be None during execution (sync)"
root_spans_during_execution.append(root_span)
return x + 1

with ThreadPoolExecutor() as executor:
tasks = [loop.run_in_executor(executor, RunnableLambda(f).invoke, 1) for _ in range(n)]
await asyncio.gather(*tasks)

root_span_after_execution = get_current_chain_root_span()
assert root_span_after_execution is None, "Root span should be None after execution"

assert len(root_spans_during_execution) == n, "Did not capture all root spans during execution"

spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == n, f"Expected {n} spans, but found {len(spans)}"
anticorrelator marked this conversation as resolved.
Show resolved Hide resolved


async def test_get_current_chain_root_span_async(
in_memory_span_exporter: InMemorySpanExporter,
) -> None:
"""Test retrieving the current chain root span during RunnableLambda execution."""
if sys.version_info < (3, 11):
pytest.xfail("Async test may fail on Python versions below 3.11")
n = 10 # Number of concurrent runs

root_spans_during_execution = []

async def f(x: int) -> int:
root_span = get_current_chain_root_span()
assert root_span is not None, "Root span should not be None during execution (async)"
root_spans_during_execution.append(root_span)
return x + 1

for _ in range(n):
await RunnableLambda[int, int](f).ainvoke(1)

root_span_after_execution = get_current_chain_root_span()
assert root_span_after_execution is None, "Root span should be None after execution"

assert len(root_spans_during_execution) == n, "Did not capture all root spans during execution"

spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == n, f"Expected {n} spans, but found {len(spans)}"


@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.parametrize("is_stream", [False, True])
@pytest.mark.parametrize("status_code", [200, 400])
Expand Down Expand Up @@ -201,6 +261,8 @@ def main() -> None:

# Ignore metadata since LC adds a bunch of unstable metadata
rqa_attributes.pop(METADATA, None)
# Ignore IS_CHAIN_SPAN attribute since it's set internally by the instrumentation
rqa_attributes.pop(IS_CHAIN_SPAN, None)
assert rqa_attributes == {}

assert (sd_span := spans_by_name.pop("StuffDocumentsChain")) is not None
Expand All @@ -225,6 +287,8 @@ def main() -> None:

# Ignore metadata since LC adds a bunch of unstable metadata
sd_attributes.pop(METADATA, None)
# Ignore IS_CHAIN_SPAN attribute since it's set internally by the instrumentation
sd_attributes.pop(IS_CHAIN_SPAN, None)
assert sd_attributes == {}

if LANGCHAIN_VERSION >= (0, 3, 0):
Expand All @@ -251,6 +315,8 @@ def main() -> None:

# Ignore metadata since LC adds a bunch of unstable metadata
retriever_attributes.pop(METADATA, None)
# Ignore IS_CHAIN_SPAN attribute since it's set internally by the instrumentation
retriever_attributes.pop(IS_CHAIN_SPAN, None)
assert retriever_attributes == {}

assert (llm_span := spans_by_name.pop("LLMChain", None)) is not None
Expand Down Expand Up @@ -293,6 +359,8 @@ def main() -> None:

# Ignore metadata since LC adds a bunch of unstable metadata
llm_attributes.pop(METADATA, None)
# Ignore IS_CHAIN_SPAN attribute since it's set internally by the instrumentation
llm_attributes.pop(IS_CHAIN_SPAN, None)
assert llm_attributes == {}

assert (oai_span := spans_by_name.pop("ChatOpenAI", None)) is not None
Expand Down Expand Up @@ -354,6 +422,8 @@ def main() -> None:
}
# Ignore metadata since LC adds a bunch of unstable metadata
oai_attributes.pop(METADATA, None)
# Ignore IS_CHAIN_SPAN attribute since it's set internally by the instrumentation
oai_attributes.pop(IS_CHAIN_SPAN, None)
assert oai_attributes == {}

assert spans_by_name == {}
Expand Down Expand Up @@ -487,6 +557,8 @@ def test_chain_metadata(

# Ignore metadata since LC adds a bunch of unstable metadata
llm_attributes.pop(METADATA, None)
# Ignore IS_CHAIN_SPAN attribute since it's set internally by the instrumentation
llm_attributes.pop(IS_CHAIN_SPAN, None)
assert llm_attributes == {}


Expand Down Expand Up @@ -609,6 +681,8 @@ def test_read_session_from_metadata(
)
assert llm_attributes.pop(INPUT_VALUE, None) == langchain_prompt_variables["adjective"]
assert llm_attributes.pop(OUTPUT_VALUE, None) == output_val
# Ignore IS_CHAIN_SPAN attribute since it's set internally by the instrumentation
llm_attributes.pop(IS_CHAIN_SPAN, None)
assert llm_attributes == {}


Expand Down
Loading