Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
tibor-reiss committed Jun 11, 2024
1 parent 81fafd9 commit 6f6e2ee
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,11 @@
_instruments = ("langchain >= 0.0.346", "langchain-core > 0.1.0")

WRAPPED_METHODS = [
{
"package": "langchain.schema.runnable",
"class": "RunnableSequence",
"is_callback": True,
"kind": TraceloopSpanKindValues.TOOL.value,
},
{
"package": "langchain.chains.llm",
"class": "LLMChain",
"is_callback": True,
"kind": TraceloopSpanKindValues.TOOL.value,
"kind": TraceloopSpanKindValues.TASK.value,
},
{
"package": "langchain.chains.base",
Expand All @@ -59,7 +53,7 @@
"package": "langchain.chains",
"class": "SequentialChain",
"is_callback": True,
"kind": TraceloopSpanKindValues.TOOL.value,
"kind": TraceloopSpanKindValues.WORKFLOW.value,
},
{
"package": "langchain.agents",
Expand Down Expand Up @@ -127,6 +121,20 @@
"method": "ainvoke",
"wrapper": atask_wrapper,
},
{
"package": "langchain.schema.runnable",
"object": "RunnableSequence",
"method": "invoke",
"span_name": "langchain.workflow",
"wrapper": workflow_wrapper,
},
{
"package": "langchain.schema.runnable",
"object": "RunnableSequence",
"method": "ainvoke",
"span_name": "langchain.workflow",
"wrapper": aworkflow_wrapper,
},
{
"package": "langchain_core.language_models.llms",
"object": "LLM",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
import json
from typing import Any, Dict
from uuid import UUID

from langchain.schema.runnable import RunnableSequence
from langchain_core.callbacks import BaseCallbackHandler
from opentelemetry import context as context_api
from opentelemetry.trace import set_span_in_context, Tracer
from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.semconv.ai import SpanAttributes
from opentelemetry.trace import set_span_in_context, Tracer
from opentelemetry.trace.span import Span

from opentelemetry import context as context_api
from opentelemetry.instrumentation.langchain.utils import (
_with_tracer_wrapper,
)


class CustomJsonEncode(json.JSONEncoder):
def default(self, o):
if isinstance(o, UUID):
def default(self, o: Any) -> str:
try:
return super().default(o)
except TypeError:
return str(o)
return super().default(o)


def get_name(to_wrap, instance) -> str:
Expand All @@ -35,18 +36,14 @@ def callback_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
kind = get_kind(to_wrap)
name = get_name(to_wrap, instance)
cb = SyncSpanCallbackHandler(tracer, name, kind)
if isinstance(instance, RunnableSequence):
# This does not work
instance = instance.with_config(callbacks=[cb, ])
if "callbacks" in kwargs:
if not any(isinstance(c, SyncSpanCallbackHandler) for c in kwargs["callbacks"]):
# Avoid adding the same callback twice, e.g. SequentialChain is also a Chain
kwargs["callbacks"].append(cb)
else:
if "callbacks" in kwargs:
if not any(isinstance(c, SyncSpanCallbackHandler) for c in kwargs["callbacks"]):
# Avoid adding the same callback twice, e.g. SequentialChain is also a Chain
kwargs["callbacks"].append(cb)
else:
kwargs["callbacks"] = [
cb,
]
kwargs["callbacks"] = [
cb,
]
return wrapped(*args, **kwargs)


Expand All @@ -55,7 +52,7 @@ def __init__(self, tracer: Tracer, name: str, kind: str) -> None:
self.tracer = tracer
self.name = name
self.kind = kind
self.span = None
self.span: Span

def _create_span(self) -> None:
self.span = self.tracer.start_span(self.name)
Expand All @@ -68,13 +65,15 @@ def _create_span(self) -> None:
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self._create_span()
self.span.set_attribute(
SpanAttributes.TRACELOOP_ENTITY_INPUT,
json.dumps({"inputs": inputs, "kwargs": kwargs}, cls=CustomJsonEncode),
)

def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.span.set_attribute(
SpanAttributes.TRACELOOP_ENTITY_OUTPUT,
json.dumps({"outputs": outputs, "kwargs": kwargs}, cls=CustomJsonEncode),
Expand All @@ -84,6 +83,7 @@ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self._create_span()
self.span.set_attribute(
SpanAttributes.TRACELOOP_ENTITY_INPUT,
Expand All @@ -93,6 +93,7 @@ def on_tool_start(
)

def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""
self.span.set_attribute(
SpanAttributes.TRACELOOP_ENTITY_OUTPUT,
json.dumps({"output": output, "kwargs": kwargs}, cls=CustomJsonEncode),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,5 @@ def test_sequential_chain(exporter):
"LLMChain.langchain.task",
"openai.completion",
"LLMChain.langchain.task",
"SequentialChain.langchain.task",
"SequentialChain.langchain.workflow",
] == [span.name for span in spans]

0 comments on commit 6f6e2ee

Please sign in to comment.