Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(langchain): use callbacks #1170

Merged
merged 5 commits into from
Jun 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,36 +30,36 @@

from opentelemetry.semconv.ai import TraceloopSpanKindValues

from opentelemetry.instrumentation.langchain.callback_wrapper import callback_wrapper

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Package version specifiers associated with this instrumentation.
# NOTE(review): presumably consumed by the instrumentor's dependency check — confirm.
_instruments = ("langchain >= 0.0.346", "langchain-core > 0.1.0")

WRAPPED_METHODS = [
{
"package": "langchain.chains.base",
"object": "Chain",
"method": "__call__",
"wrapper": task_wrapper,
"class": "Chain",
"is_callback": True,
"kind": TraceloopSpanKindValues.TASK.value,
},
{
"package": "langchain.chains.base",
"object": "Chain",
"method": "acall",
"wrapper": atask_wrapper,
"package": "langchain.chains.llm",
"class": "LLMChain",
"is_callback": True,
"kind": TraceloopSpanKindValues.TASK.value,
},
{
"package": "langchain.chains",
"object": "SequentialChain",
"method": "__call__",
"span_name": "langchain.workflow",
"wrapper": workflow_wrapper,
"package": "langchain.chains.combine_documents.stuff",
"class": "StuffDocumentsChain",
"is_callback": True,
"kind": TraceloopSpanKindValues.TASK.value,
},
{
"package": "langchain.chains",
"object": "SequentialChain",
"method": "acall",
"span_name": "langchain.workflow",
"wrapper": aworkflow_wrapper,
"class": "SequentialChain",
"is_callback": True,
"kind": TraceloopSpanKindValues.WORKFLOW.value,
},
{
"package": "langchain.agents",
Expand Down Expand Up @@ -173,21 +173,33 @@ def _instrument(self, **kwargs):
tracer = get_tracer(__name__, __version__, tracer_provider)
for wrapped_method in WRAPPED_METHODS:
wrap_package = wrapped_method.get("package")
wrap_object = wrapped_method.get("object")
wrap_method = wrapped_method.get("method")
wrapper = wrapped_method.get("wrapper")
wrap_function_wrapper(
wrap_package,
f"{wrap_object}.{wrap_method}" if wrap_object else wrap_method,
wrapper(tracer, wrapped_method),
)
if wrapped_method.get("is_callback"):
wrap_class = wrapped_method.get("class")
wrap_function_wrapper(
wrap_package,
f"{wrap_class}.__init__",
callback_wrapper(tracer, wrapped_method),
)
else:
wrap_object = wrapped_method.get("object")
wrap_method = wrapped_method.get("method")
wrapper = wrapped_method.get("wrapper")
wrap_function_wrapper(
wrap_package,
f"{wrap_object}.{wrap_method}" if wrap_object else wrap_method,
wrapper(tracer, wrapped_method),
)

def _uninstrument(self, **kwargs):
    """Undo the wrapping described by every entry in ``WRAPPED_METHODS``.

    Entries flagged ``is_callback`` had their class ``__init__`` wrapped, so
    that is what gets restored; all other entries had ``object.method`` (or a
    bare module-level ``method``) wrapped.

    Args:
        **kwargs: Accepted for interface compatibility; not used here.
    """
    for wrapped_method in WRAPPED_METHODS:
        wrap_package = wrapped_method.get("package")
        if wrapped_method.get("is_callback"):
            wrap_class = wrapped_method.get("class")
            # ``unwrap`` looks the attribute up with a plain ``getattr``, which
            # cannot resolve a dotted name such as "Chain.__init__". Point it
            # at the class itself and name only the attribute, mirroring the
            # non-callback branch below.
            unwrap(f"{wrap_package}.{wrap_class}", "__init__")
        else:
            wrap_object = wrapped_method.get("object")
            wrap_method = wrapped_method.get("method")
            unwrap(
                f"{wrap_package}.{wrap_object}" if wrap_object else wrap_package,
                wrap_method,
            )
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import json
from typing import Any, Dict

from langchain_core.callbacks import BaseCallbackHandler
from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.semconv.ai import SpanAttributes
from opentelemetry.trace import set_span_in_context, Tracer
from opentelemetry.trace.span import Span

from opentelemetry import context as context_api
from opentelemetry.instrumentation.langchain.utils import (
_with_tracer_wrapper,
)


class CustomJsonEncode(json.JSONEncoder):
    """JSON encoder that stringifies values the base encoder cannot handle."""

    def default(self, o: Any) -> str:
        """Return a serializable stand-in for *o*.

        The base ``JSONEncoder.default`` is tried first; if it raises
        ``TypeError`` the object's ``str()`` form is used instead.
        """
        try:
            encoded = super().default(o)
        except TypeError:
            # Not natively serializable: fall back to the string form.
            encoded = str(o)
        return encoded


def get_name(to_wrap, instance) -> str:
    """Build the span name ``<instance name>.langchain.<kind>``.

    Args:
        to_wrap: Mapping describing the wrapped target; its ``"kind"`` entry
            (or ``None`` when absent) becomes the suffix.
        instance: Object exposing a ``get_name()`` method whose result becomes
            the prefix.
    """
    instance_name = instance.get_name()
    span_kind = to_wrap.get("kind")
    return f"{instance_name}.langchain.{span_kind}"


def get_kind(to_wrap) -> str:
    """Return the ``"kind"`` entry of *to_wrap*, or ``None`` when it is missing."""
    span_kind = to_wrap.get("kind")
    return span_kind


@_with_tracer_wrapper
def callback_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
    """Inject a ``SyncSpanCallbackHandler`` into the wrapped call's ``callbacks``.

    When instrumentation is suppressed in the current context the wrapped
    callable is invoked untouched. Otherwise a handler is created for this
    instance and added to the ``callbacks`` keyword argument, unless a
    ``SyncSpanCallbackHandler`` is already present there.

    Args:
        tracer: Tracer handed to the created handler.
        to_wrap: Mapping describing the wrapped target (supplies ``"kind"``).
        wrapped: The original callable being wrapped.
        instance: The object the call is bound to; used to derive the span name.
        args: Positional arguments of the intercepted call.
        kwargs: Keyword arguments of the intercepted call.

    Returns:
        Whatever ``wrapped(*args, **kwargs)`` returns.
    """
    if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
        return wrapped(*args, **kwargs)

    kind = get_kind(to_wrap)
    name = get_name(to_wrap, instance)
    cb = SyncSpanCallbackHandler(tracer, name, kind)

    callbacks = kwargs.get("callbacks")
    if callbacks is None:
        # Covers both a missing key and an explicit ``callbacks=None``.
        kwargs["callbacks"] = [cb]
    elif isinstance(callbacks, (list, tuple)):
        # Avoid adding the same callback twice, e.g. SequentialChain is also a Chain
        if not any(isinstance(c, SyncSpanCallbackHandler) for c in callbacks):
            # Build a new list instead of appending in place, so the caller's
            # own list object is left unmodified.
            kwargs["callbacks"] = [*callbacks, cb]
    else:
        # NOTE(review): assumed to be a callback-manager style object exposing
        # ``handlers`` and ``add_handler`` — confirm against the LangChain
        # callbacks API. Objects without ``add_handler`` are passed through
        # unchanged rather than crashing the wrapped call.
        existing = getattr(callbacks, "handlers", None) or ()
        add_handler = getattr(callbacks, "add_handler", None)
        already_present = any(
            isinstance(c, SyncSpanCallbackHandler) for c in existing
        )
        if callable(add_handler) and not already_present:
            add_handler(cb)

    return wrapped(*args, **kwargs)


class SyncSpanCallbackHandler(BaseCallbackHandler):
    """Callback handler that opens a span when a chain/tool starts and closes it when it ends.

    Each started run gets its own span, made current via ``context_api.attach``.
    The span and the attach token are remembered per run so that the matching
    end/error callback can end that span and detach the context again.
    """

    def __init__(self, tracer: Tracer, name: str, kind: str) -> None:
        """Store the tracer plus the name and kind stamped on every span."""
        self.tracer = tracer
        self.name = name
        self.kind = kind
        # Most recently started span (set by ``_create_span``).
        self.span: Span
        # Open runs: run id -> (span, context token returned by ``attach``).
        self._open_runs: Dict[Any, Any] = {}

    def _create_span(self, run_id: Any = None) -> None:
        """Start a span, make it current, and remember it under *run_id*."""
        span = self.tracer.start_span(self.name)
        span.set_attribute(SpanAttributes.TRACELOOP_SPAN_KIND, self.kind)
        span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, self.name)

        # Keep the token: without it the attached context can never be
        # detached and would stay current after the span has ended.
        token = context_api.attach(set_span_in_context(span))
        self.span = span
        self._open_runs[run_id] = (span, token)

    def _finish_span(
        self, run_id: Any, payload: Any = None, error: Any = None
    ) -> None:
        """End the span opened for *run_id* and detach its context.

        Does nothing when no span is open for *run_id* (for example an end
        callback arriving without a matching start).

        Args:
            run_id: Key the span was stored under by ``_create_span``.
            payload: Optional object serialized into the entity-output attribute.
            error: Optional exception recorded on the span.
        """
        entry = self._open_runs.pop(run_id, None)
        if entry is None:
            return
        span, token = entry
        try:
            if payload is not None:
                span.set_attribute(
                    SpanAttributes.TRACELOOP_ENTITY_OUTPUT,
                    json.dumps(payload, cls=CustomJsonEncode),
                )
            if error is not None:
                span.record_exception(error)
        finally:
            # Always close the span and restore the previous context, even if
            # serializing the payload failed.
            span.end()
            context_api.detach(token)

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        # NOTE(review): ``run_id`` is expected among the callback kwargs;
        # when absent all runs share the ``None`` key.
        self._create_span(kwargs.get("run_id"))
        self.span.set_attribute(
            SpanAttributes.TRACELOOP_ENTITY_INPUT,
            json.dumps({"inputs": inputs, "kwargs": kwargs}, cls=CustomJsonEncode),
        )

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        self._finish_span(
            kwargs.get("run_id"), payload={"outputs": outputs, "kwargs": kwargs}
        )

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors; records the error and closes the span."""
        self._finish_span(kwargs.get("run_id"), error=error)

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self._create_span(kwargs.get("run_id"))
        self.span.set_attribute(
            SpanAttributes.TRACELOOP_ENTITY_INPUT,
            json.dumps(
                {"input_str": input_str, "kwargs": kwargs}, cls=CustomJsonEncode
            ),
        )

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Run when tool ends running."""
        self._finish_span(
            kwargs.get("run_id"), payload={"output": output, "kwargs": kwargs}
        )

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors; records the error and closes the span."""
        self._finish_span(kwargs.get("run_id"), error=error)
Loading