Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add get_chain_root_span utility for langchain instrumentation #1054

Merged
merged 31 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
c200b4e
Use ContextVars to track langchain root spans
anticorrelator Oct 7, 2024
a6d6702
Do not explicitly manage a stack
anticorrelator Oct 7, 2024
1f43388
Ruff 🐶
anticorrelator Oct 7, 2024
9734103
Use separate reset token store
anticorrelator Oct 7, 2024
298318a
Use better contextvar type annotation
anticorrelator Oct 7, 2024
29263e1
Flesh out type annotation for Token object
anticorrelator Oct 7, 2024
1af67a4
Add langchain instrumentor tests
anticorrelator Oct 7, 2024
89a0d45
Use better type annotation
anticorrelator Oct 7, 2024
a0b14b4
Refactor tests
anticorrelator Oct 8, 2024
3ddc798
Refactor root chain span propagation
anticorrelator Oct 8, 2024
50ffb5a
Ruff 🐶
anticorrelator Oct 8, 2024
45d1a0e
Update tests
anticorrelator Oct 9, 2024
39bd069
Fix type annotations
anticorrelator Oct 9, 2024
141ee9f
Remove unused type: ignore
anticorrelator Oct 9, 2024
6c02b33
Simplify test to not use sequences
anticorrelator Oct 9, 2024
965ce0f
Explicitly define type annotations for RunnableLambda
anticorrelator Oct 9, 2024
e1609ac
Use a RunnableSequence for more robust testing
anticorrelator Oct 10, 2024
aea1ae7
Remove root span tracking from Span attributes
anticorrelator Oct 10, 2024
71a77b7
Track span tree manually
anticorrelator Oct 10, 2024
dd41d7c
Remove references to extra span attribute
anticorrelator Oct 10, 2024
73a40b8
Remove redundant logic
anticorrelator Oct 15, 2024
5880442
Properly test concurrency
anticorrelator Oct 15, 2024
ca6d3ba
Remove unused variable
anticorrelator Oct 15, 2024
30ddc8f
Test root chain tree walking
anticorrelator Oct 16, 2024
5bcc238
Get all ancestors
anticorrelator Oct 16, 2024
611ed88
Only use run map
anticorrelator Oct 16, 2024
7a35bec
Remove old bookkeeping
anticorrelator Oct 16, 2024
6b82c94
Fix type annotations
anticorrelator Oct 17, 2024
a3e45e3
Add docstring
anticorrelator Oct 17, 2024
26dbe41
Ignore unused ignores
anticorrelator Oct 17, 2024
4a9e7b2
Update return types for consistency
anticorrelator Oct 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import TYPE_CHECKING, Any, Callable, Collection, Optional
from typing import TYPE_CHECKING, Any, Callable, Collection, List, Optional
from uuid import UUID

from opentelemetry import trace as trace_api
Expand Down Expand Up @@ -64,6 +64,26 @@ def _uninstrument(self, **kwargs: Any) -> None:
def get_span(self, run_id: UUID) -> Optional[Span]:
return self._tracer.get_span(run_id) if self._tracer else None

def get_ancestors(self, run_id: UUID) -> Optional[List[Span]]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def get_ancestors(self, run_id: UUID) -> Optional[List[Span]]:
def get_ancestor_spans(self, run_id: UUID) -> Optional[List[Span]]:

just to be explicit about the return type

ancestors = []
tracer = self._tracer
assert tracer

run = tracer.run_map.get(str(run_id))
if not run:
return None

run_id = run.parent_run_id # start with the first ancestor

while run_id:
span = self.get_span(run_id)
if span:
ancestors.append(span)

run = tracer.run_map.get(str(run_id))
run_id = run.parent_run_id
return ancestors if ancestors else None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for the simplicity of the typing (and maybe just my FP brain) - I think an empty array might be preferable to represent "none"?



class _BaseCallbackManagerInit:
__slots__ = ("_tracer",)
Expand Down Expand Up @@ -104,3 +124,20 @@ def get_current_span() -> Optional[Span]:
if not run_id:
return None
return LangChainInstrumentor().get_span(run_id)


def get_ancestor_spans() -> Optional[List[Span]]:
import langchain_core

run_id: Optional[UUID] = None
config = langchain_core.runnables.config.var_child_runnable_config.get()
if not isinstance(config, dict):
return None
for v in config.values():
if not isinstance(v, langchain_core.callbacks.BaseCallbackManager):
continue
if run_id := v.parent_run_id:
break
if not run_id:
return None
mikeldking marked this conversation as resolved.
Show resolved Hide resolved
return LangChainInstrumentor().get_ancestors(run_id)
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from opentelemetry import trace as trace_api
from opentelemetry.context import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
from opentelemetry.trace import Span
from opentelemetry.util.types import AttributeValue
from wrapt import ObjectProxy

Expand Down Expand Up @@ -114,10 +115,10 @@ def __init__(self, tracer: trace_api.Tracer, *args: Any, **kwargs: Any) -> None:
assert self.run_map
self.run_map = _DictWithLock[str, Run](self.run_map)
self._tracer = tracer
self._spans_by_run: Dict[UUID, trace_api.Span] = _DictWithLock[UUID, trace_api.Span]()
self._spans_by_run: Dict[UUID, Span] = _DictWithLock[UUID, Span]()
self._lock = RLock() # handlers may be run in a thread by langchain

def get_span(self, run_id: UUID) -> Optional[trace_api.Span]:
def get_span(self, run_id: UUID) -> Optional[Span]:
return self._spans_by_run.get(run_id)

@audit_timing # type: ignore
Expand All @@ -140,6 +141,7 @@ def _start_trace(self, run: Run) -> None:
context=parent_context,
start_time=start_time_utc_nano,
)

# The following line of code is commented out to serve as a reminder that in a system
# of callbacks, attaching the context can be hazardous because there is no guarantee
# that the context will be detached. An error could happen between callbacks leaving
Expand Down Expand Up @@ -207,7 +209,7 @@ def on_chat_model_start(self, *args: Any, **kwargs: Any) -> Run:


@audit_timing # type: ignore
def _record_exception(span: trace_api.Span, error: BaseException) -> None:
def _record_exception(span: Span, error: BaseException) -> None:
if isinstance(error, Exception):
span.record_exception(error)
return
Expand All @@ -229,7 +231,7 @@ def _record_exception(span: trace_api.Span, error: BaseException) -> None:


@audit_timing # type: ignore
def _update_span(span: trace_api.Span, run: Run) -> None:
def _update_span(span: Span, run: Run) -> None:
if run.error is None:
span.set_status(trace_api.StatusCode.OK)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.retrievers import KNNRetriever
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables import RunnableLambda, RunnableSerializable
from langchain_openai import ChatOpenAI
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import ReadableSpan
Expand All @@ -39,7 +39,10 @@
from respx import MockRouter

from openinference.instrumentation import using_attributes
from openinference.instrumentation.langchain import get_current_span
from openinference.instrumentation.langchain import (
get_ancestor_spans,
get_current_span,
)
from openinference.semconv.trace import (
DocumentAttributes,
EmbeddingAttributes,
Expand Down Expand Up @@ -92,6 +95,94 @@ async def f(_: Any) -> Optional[Span]:
}


async def test_get_ancestor_spans(
in_memory_span_exporter: InMemorySpanExporter,
) -> None:
"""Test retrieving the current chain root span during RunnableLambda execution."""
n = 10 # Number of concurrent runs
loop = asyncio.get_running_loop()

root_spans_during_execution = []

def f(x: int) -> int:
current_span = get_current_span()
root_spans = get_ancestor_spans()
assert root_spans is not None, "Ancestor should not be None during execution (async)"
assert len(root_spans) == 1, "Only get ancestor spans"
assert current_span is not root_spans[0], "Ancestor is distinct from the current span"
root_spans_during_execution.append(root_spans[0])
assert (
root_spans[0].name == "RunnableSequence"
), "RunnableSequence should be the outermost ancestor"
return x + 1

sequence: RunnableSerializable[int, int] = RunnableLambda[int, int](f) | RunnableLambda[
int, int
](f)

with ThreadPoolExecutor() as executor:
tasks = [loop.run_in_executor(executor, sequence.invoke, 1) for _ in range(n)]
await asyncio.gather(*tasks)

root_span_after_execution = get_ancestor_spans()
assert root_span_after_execution is None, "Ancestor should be None after execution"

assert (
len(root_spans_during_execution) == 2 * n
), "Did not capture all ancestors during execution"

assert (
len(set(id(span) for span in root_spans_during_execution)) == n
), "Both Lambdas share the same ancestor"

spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == 3 * n, f"Expected {3 * n} spans, but found {len(spans)}"


async def test_get_ancestor_spans_async(
in_memory_span_exporter: InMemorySpanExporter,
) -> None:
"""Test retrieving the current chain root span during RunnableLambda execution."""
if sys.version_info < (3, 11):
pytest.xfail("Async test may fail on Python versions below 3.11")
n = 10 # Number of concurrent runs

root_spans_during_execution = []

async def f(x: int) -> int:
current_span = get_current_span()
root_spans = get_ancestor_spans()
assert root_spans is not None, "Ancestor should not be None during execution (async)"
assert len(root_spans) == 1, "Only get ancestor spans"
assert current_span is not root_spans[0], "Ancestor is distinct from the current span"
root_spans_during_execution.append(root_spans[0])
assert (
root_spans[0].name == "RunnableSequence"
), "RunnableSequence should be the outermost ancestor"
await asyncio.sleep(0.01)
return x + 1

sequence: RunnableSerializable[int, int] = RunnableLambda[int, int](f) | RunnableLambda[
int, int
](f)

await asyncio.gather(*(sequence.ainvoke(1) for _ in range(n)))

root_span_after_execution = get_ancestor_spans()
assert root_span_after_execution is None, "Ancestor should be None after execution"

assert (
len(root_spans_during_execution) == 2 * n
), "Did not capture all ancestors during execution"

assert (
len(set(id(span) for span in root_spans_during_execution)) == n
), "Both Lambdas share the same ancestor"

spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == 3 * n, f"Expected {3 * n} spans, but found {len(spans)}"


@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.parametrize("is_stream", [False, True])
@pytest.mark.parametrize("status_code", [200, 400])
Expand Down
Loading