Skip to content

Commit

Permalink
feat(langchain): use callbacks (#1170)
Browse files Browse the repository at this point in the history
  • Loading branch information
tibor-reiss committed Jun 16, 2024
1 parent aaa303b commit 995d8b6
Show file tree
Hide file tree
Showing 9 changed files with 802 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,36 +30,36 @@

from opentelemetry.semconv.ai import TraceloopSpanKindValues

from opentelemetry.instrumentation.langchain.callback_wrapper import callback_wrapper

logger = logging.getLogger(__name__)

_instruments = ("langchain >= 0.0.346", "langchain-core > 0.1.0")

WRAPPED_METHODS = [
{
"package": "langchain.chains.base",
"object": "Chain",
"method": "__call__",
"wrapper": task_wrapper,
"class": "Chain",
"is_callback": True,
"kind": TraceloopSpanKindValues.TASK.value,
},
{
"package": "langchain.chains.base",
"object": "Chain",
"method": "acall",
"wrapper": atask_wrapper,
"package": "langchain.chains.llm",
"class": "LLMChain",
"is_callback": True,
"kind": TraceloopSpanKindValues.TASK.value,
},
{
"package": "langchain.chains",
"object": "SequentialChain",
"method": "__call__",
"span_name": "langchain.workflow",
"wrapper": workflow_wrapper,
"package": "langchain.chains.combine_documents.stuff",
"class": "StuffDocumentsChain",
"is_callback": True,
"kind": TraceloopSpanKindValues.TASK.value,
},
{
"package": "langchain.chains",
"object": "SequentialChain",
"method": "acall",
"span_name": "langchain.workflow",
"wrapper": aworkflow_wrapper,
"class": "SequentialChain",
"is_callback": True,
"kind": TraceloopSpanKindValues.WORKFLOW.value,
},
{
"package": "langchain.agents",
Expand Down Expand Up @@ -173,21 +173,33 @@ def _instrument(self, **kwargs):
tracer = get_tracer(__name__, __version__, tracer_provider)
for wrapped_method in WRAPPED_METHODS:
wrap_package = wrapped_method.get("package")
wrap_object = wrapped_method.get("object")
wrap_method = wrapped_method.get("method")
wrapper = wrapped_method.get("wrapper")
wrap_function_wrapper(
wrap_package,
f"{wrap_object}.{wrap_method}" if wrap_object else wrap_method,
wrapper(tracer, wrapped_method),
)
if wrapped_method.get("is_callback"):
wrap_class = wrapped_method.get("class")
wrap_function_wrapper(
wrap_package,
f"{wrap_class}.__init__",
callback_wrapper(tracer, wrapped_method),
)
else:
wrap_object = wrapped_method.get("object")
wrap_method = wrapped_method.get("method")
wrapper = wrapped_method.get("wrapper")
wrap_function_wrapper(
wrap_package,
f"{wrap_object}.{wrap_method}" if wrap_object else wrap_method,
wrapper(tracer, wrapped_method),
)

def _uninstrument(self, **kwargs):
    """Remove the wrappers described by ``WRAPPED_METHODS``.

    Entries flagged with ``is_callback`` had the ``__init__`` of their
    ``class`` wrapped; every other entry had ``object.method`` (or a bare
    module-level ``method``) wrapped.
    """
    for wrapped_method in WRAPPED_METHODS:
        wrap_package = wrapped_method.get("package")
        if wrapped_method.get("is_callback"):
            wrap_class = wrapped_method.get("class")
            # unwrap() resolves the dotted string to an object and then
            # performs a single getattr(obj, attr). Passing the module plus
            # "Class.__init__" as the attribute therefore finds nothing and
            # silently leaves the wrapper in place; address the class itself
            # and unwrap its "__init__" attribute instead.
            unwrap(f"{wrap_package}.{wrap_class}", "__init__")
        else:
            wrap_object = wrapped_method.get("object")
            wrap_method = wrapped_method.get("method")
            unwrap(
                f"{wrap_package}.{wrap_object}" if wrap_object else wrap_package,
                wrap_method,
            )
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import json
from typing import Any, Dict

from langchain_core.callbacks import BaseCallbackHandler
from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.semconv.ai import SpanAttributes
from opentelemetry.trace import set_span_in_context, Tracer
from opentelemetry.trace.span import Span

from opentelemetry import context as context_api
from opentelemetry.instrumentation.langchain.utils import (
_with_tracer_wrapper,
)


class CustomJsonEncode(json.JSONEncoder):
    """JSON encoder that degrades to ``str()`` for values it cannot encode."""

    def default(self, o: Any) -> str:
        """Return the parent encoder's result for *o*, or ``str(o)`` on TypeError."""
        try:
            encoded = super().default(o)
        except TypeError:
            # The value has no JSON representation: use its string form.
            encoded = str(o)
        return encoded


def get_name(to_wrap, instance) -> str:
    """Build the span name ``<instance name>.langchain.<kind>``.

    The first part comes from ``instance.get_name()`` and the last part from
    the ``"kind"`` entry of *to_wrap* (rendered as ``None`` when absent).
    """
    instance_name = instance.get_name()
    kind = to_wrap.get("kind")
    return f"{instance_name}.langchain.{kind}"


def get_kind(to_wrap) -> str:
    """Return the ``"kind"`` entry of *to_wrap* (``None`` when it is missing)."""
    kind = to_wrap.get("kind")
    return kind


@_with_tracer_wrapper
def callback_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
    """Inject a ``SyncSpanCallbackHandler`` into the ``callbacks`` keyword.

    Calls ``wrapped(*args, **kwargs)`` after making sure ``kwargs["callbacks"]``
    contains a span-creating handler named after *instance* and *to_wrap*.
    When instrumentation is suppressed in the current context, the call is
    forwarded untouched.

    Args:
        tracer: Tracer handed to the new handler.
        to_wrap: Wrapping description; its ``"kind"`` entry is read.
        wrapped: The callable being wrapped.
        instance: Object the call is bound to; must expose ``get_name()``.
        args: Positional arguments of the intercepted call.
        kwargs: Keyword arguments of the intercepted call.

    Returns:
        Whatever ``wrapped`` returns.
    """
    if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
        return wrapped(*args, **kwargs)

    kind = get_kind(to_wrap)
    name = get_name(to_wrap, instance)
    cb = SyncSpanCallbackHandler(tracer, name, kind)

    existing = kwargs.get("callbacks")
    if existing is None:
        # Covers both a missing key and an explicit ``callbacks=None``.
        kwargs["callbacks"] = [cb]
    elif isinstance(existing, (list, tuple)):
        # Avoid adding the same callback twice, e.g. SequentialChain is also a Chain
        if not any(isinstance(c, SyncSpanCallbackHandler) for c in existing):
            # Build a new list rather than appending in place: the caller may
            # share one list between several objects, and mutating it would
            # make them all reuse the handler (and span name) of the first.
            kwargs["callbacks"] = [*existing, cb]
    # NOTE(review): any other value (presumably a callback-manager object) is
    # forwarded unchanged so instrumentation never breaks the caller; such
    # objects are not traced by this wrapper.
    return wrapped(*args, **kwargs)


class SyncSpanCallbackHandler(BaseCallbackHandler):
    """Callback handler that records one span per chain or tool run.

    A span is started and made current in ``on_chain_start``/``on_tool_start``
    and ended in the matching ``*_end`` or ``*_error`` hook. The handler keeps
    a single span at a time, so it tracks one in-flight run.
    """

    def __init__(self, tracer: Tracer, name: str, kind: str) -> None:
        """Store the tracer plus the span name and span kind to report."""
        self.tracer = tracer
        self.name = name
        self.kind = kind
        # Set by _create_span(); None until the first run starts.
        self.span: "Span | None" = None
        # Token returned by context_api.attach(); needed to restore the
        # previous context once the run is over.
        self._token = None

    def _create_span(self) -> None:
        """Start a span, tag it with kind and name, and make it current."""
        self.span = self.tracer.start_span(self.name)
        self.span.set_attribute(SpanAttributes.TRACELOOP_SPAN_KIND, self.kind)
        self.span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, self.name)

        current_context = set_span_in_context(self.span)
        self._token = context_api.attach(current_context)

    def _end_span(self) -> None:
        """End the current span and restore the context that preceded it."""
        if self.span is not None:
            self.span.end()
        token, self._token = self._token, None
        if token is not None:
            # Without this the finished span would stay the current span
            # for everything that runs afterwards.
            context_api.detach(token)

    def _finish_with_output(self, payload: Dict[str, Any]) -> None:
        """Attach *payload* as the entity output and close the span."""
        if self.span is None:
            # An end hook arrived without a matching start; nothing to record.
            return
        self.span.set_attribute(
            SpanAttributes.TRACELOOP_ENTITY_OUTPUT,
            json.dumps(payload, cls=CustomJsonEncode),
        )
        self._end_span()

    def _finish_with_error(self, error: BaseException) -> None:
        """Record *error* on the span, mark it failed and close it."""
        from opentelemetry.trace import Status, StatusCode

        if self.span is None:
            return
        self.span.record_exception(error)
        self.span.set_status(Status(StatusCode.ERROR, str(error)))
        self._end_span()

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        self._create_span()
        self.span.set_attribute(
            SpanAttributes.TRACELOOP_ENTITY_INPUT,
            json.dumps({"inputs": inputs, "kwargs": kwargs}, cls=CustomJsonEncode),
        )

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        self._finish_with_output({"outputs": outputs, "kwargs": kwargs})

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors, so the span is closed on failure too."""
        self._finish_with_error(error)

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self._create_span()
        self.span.set_attribute(
            SpanAttributes.TRACELOOP_ENTITY_INPUT,
            json.dumps(
                {"input_str": input_str, "kwargs": kwargs}, cls=CustomJsonEncode
            ),
        )

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Run when tool ends running."""
        self._finish_with_output({"output": output, "kwargs": kwargs})

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors, so the span is closed on failure too."""
        self._finish_with_error(error)
Loading

0 comments on commit 995d8b6

Please sign in to comment.