Skip to content

Commit

Permalink
Trials with run_id
Browse files Browse the repository at this point in the history
  • Loading branch information
tibor-reiss committed Jun 19, 2024
1 parent e6626f7 commit 4fd020a
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ def _instrument(self, **kwargs):
f"{wrap_class}.invoke",
callback_wrapper(tracer, wrapped_method),
)
wrap_function_wrapper(
wrap_package,
f"{wrap_class}.ainvoke",
callback_wrapper(tracer, wrapped_method),
)
else:
wrap_object = wrapped_method.get("object")
wrap_method = wrapped_method.get("method")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from typing import Any, Dict
from typing import Any, Optional
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler, BaseCallbackManager
from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY
Expand All @@ -13,9 +14,6 @@
)


TAG_PREFIX = "tag_openllmetry"


class CustomJsonEncode(json.JSONEncoder):
def default(self, o: Any) -> str:
try:
Expand All @@ -24,121 +22,134 @@ def default(self, o: Any) -> str:
return str(o)


def get_name(to_wrap, instance) -> str:
    """Build the span name for a wrapped LangChain object.

    Args:
        to_wrap: Mapping describing the wrapped method; only its ``"kind"``
            entry is read (via ``.get``, so a missing key renders as ``None``).
        instance: Object exposing a ``get_name()`` method.

    Returns:
        The string ``"<instance.get_name()>.langchain.<kind>"``.
    """
    return f"{instance.get_name()}.langchain.{to_wrap.get('kind')}"


def get_kind(to_wrap) -> str:
    """Return the ``"kind"`` entry of a wrapped-method description.

    Args:
        to_wrap: Mapping describing the wrapped method.

    Returns:
        The value stored under ``"kind"``. Because the lookup uses ``.get``,
        ``None`` is returned when the key is absent.
    """
    return to_wrap.get("kind")


@_with_tracer_wrapper
def callback_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
    """Hook into ``invoke``/``ainvoke`` and attach a span-producing callback.

    The run config is taken from the second positional argument when present,
    otherwise from the ``config`` keyword argument. A
    ``SyncSpanCallbackHandler`` already registered in the config's callbacks
    is reused; otherwise a new one is created and registered.

    sources:
    https://python.langchain.com/v0.2/docs/how_to/callbacks_attach/
    https://python.langchain.com/v0.2/docs/how_to/callbacks_runtime/

    Args:
        tracer: Tracer handed to a newly created ``SyncSpanCallbackHandler``.
        to_wrap: Mapping describing the wrapped method; its ``"kind"`` is read.
        wrapped: The original bound method being wrapped.
        instance: The object the method is bound to; must expose ``get_name()``.
        args: Positional arguments of the intercepted call.
        kwargs: Keyword arguments of the intercepted call.

    Returns:
        Whatever ``wrapped`` returns.

    Note:
        A config dict supplied by the caller is modified in place (the
        handler is added to its callbacks), so later calls with the same
        config find and reuse the handler.
    """
    if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
        return wrapped(*args, **kwargs)

    kind = to_wrap.get("kind")
    class_name = instance.get_name()

    # Locate the config wherever the caller put it.
    config_is_positional = len(args) > 1
    config = args[1] if config_is_positional else kwargs.get("config")
    if config is None:
        config = {}
        if config_is_positional:
            # Replace the explicit ``None`` with a real config dict.
            args = (args[0], config, *args[2:])
    if not config_is_positional:
        # Pass by keyword so a caller-supplied ``config=`` is never duplicated.
        kwargs = {**kwargs, "config": config}

    callbacks = config.get("callbacks")
    if isinstance(callbacks, BaseCallbackManager):
        candidates = callbacks.handlers
    elif isinstance(callbacks, list):
        candidates = callbacks
    else:
        candidates = ()

    # Reuse an already-registered span handler if there is one.
    cb = next(
        (c for c in candidates if isinstance(c, SyncSpanCallbackHandler)),
        None,
    )
    if cb is None:
        cb = SyncSpanCallbackHandler(tracer)
        if isinstance(callbacks, BaseCallbackManager):
            callbacks.add_handler(cb)
        elif isinstance(callbacks, list):
            callbacks.append(cb)
        elif callbacks is None:
            config["callbacks"] = [cb]
        # Any other container type is left untouched.

    cb.add_kind(class_name, kind)
    return wrapped(*args, **kwargs)


class SyncSpanCallbackHandler(BaseCallbackHandler):
def __init__(self, tracer: Tracer) -> None:
    """Store the tracer and create the empty lookup tables.

    Args:
        tracer: Tracer later used to start spans.
    """
    self.tracer = tracer
    # handler id -> (span name, kind); filled by ``add_handler``.
    self.handlers = {}
    # class name -> (span name, kind); filled by ``add_kind``.
    self.kinds = {}
    # key -> span currently tracked by this handler.
    self.spans = {}

@staticmethod
def _get_handler_id(tags: list[str]) -> int:
    """Extract the integer handler id encoded in a ``TAG_PREFIX`` tag.

    Args:
        tags: Tags attached to the run; ``None`` is treated as empty.

    Returns:
        The integer that follows ``TAG_PREFIX`` in the first matching tag.

    Raises:
        RuntimeError: If no tag starts with ``TAG_PREFIX``.
    """
    for tag in tags or ():
        if tag.startswith(TAG_PREFIX):
            return int(tag.removeprefix(TAG_PREFIX))
    raise RuntimeError(f"no tag starting with {TAG_PREFIX!r} found in {tags!r}")
def add_kind(self, class_name: str, kind: str) -> None:
    """Register the span name and kind for ``class_name`` once.

    The first registration wins; later calls for the same ``class_name``
    leave the stored entry unchanged.

    Args:
        class_name: Key under which the entry is stored.
        kind: Kind stored alongside the derived span name.
    """
    self.kinds.setdefault(class_name, (f"{class_name}.langchain.{kind}", kind))

def _get_span(self, run_id: UUID) -> Span:
    """Return the span stored for ``run_id``.

    Args:
        run_id: Key under which the span was stored by ``_create_span``.

    Raises:
        KeyError: If no span is stored for ``run_id``.
    """
    return self.spans[run_id]

def _create_span(self, serialized: dict[str, Any], run_id: UUID) -> Span:
    """Start a span for a run and remember it under ``run_id``.

    Args:
        serialized: Serialized description of the runnable; its ``"name"``
            entry selects the (span name, kind) pair registered via
            ``add_kind``.
        run_id: Key under which the new span is stored in ``self.spans``.

    Returns:
        The newly started span.

    Raises:
        KeyError: If ``serialized["name"]`` was never registered with
            ``add_kind``.
    """
    name, kind = self.kinds[serialized["name"]]
    span = self.tracer.start_span(name)
    span.set_attribute(SpanAttributes.TRACELOOP_SPAN_KIND, kind)
    span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, name)
    self.spans[run_id] = span

    # Make the new span the current one in the active context.
    # NOTE(review): the value returned by attach() is discarded, so nothing
    # here can detach this context later - confirm that is intended.
    context_api.attach(set_span_in_context(span))

    return span

def add_handler(self, handler_id: int, name: str, kind: str) -> None:
    """Store ``(name, kind)`` under ``handler_id``, replacing any earlier entry.

    Args:
        handler_id: Key in ``self.handlers``.
        name: Span name to associate with the handler id.
        kind: Kind to associate with the handler id.
    """
    self.handlers[handler_id] = (name, kind)

def on_chain_start(
    self,
    serialized: dict[str, Any],
    inputs: dict[str, Any],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[list[str]] = None,
    metadata: Optional[dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Run when chain starts running.

    Starts a span for ``run_id`` and records the chain inputs on it as JSON.

    Args:
        serialized: Serialized description of the chain; forwarded to
            ``_create_span``.
        inputs: Inputs of the chain, recorded on the span.
        run_id: Identifier of this run; used as the span key.
        parent_run_id: Accepted but not used here.
        tags: Accepted but not used here.
        metadata: Accepted but not used here.
        **kwargs: Any remaining keyword arguments; recorded on the span
            together with ``inputs``.
    """
    span = self._create_span(serialized, run_id)
    span.set_attribute(
        SpanAttributes.TRACELOOP_ENTITY_INPUT,
        json.dumps({"inputs": inputs, "kwargs": kwargs}, cls=CustomJsonEncode),
    )

def on_chain_end(
    self,
    outputs: dict[str, Any],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> None:
    """Run when chain ends running.

    Records the chain outputs on the span stored for ``run_id``, ends the
    span and stops tracking it.

    Args:
        outputs: Outputs of the chain, recorded on the span.
        run_id: Identifier of the run whose span is finished.
        parent_run_id: Accepted but not used here.
        **kwargs: Any remaining keyword arguments; recorded on the span
            together with ``outputs``.

    Raises:
        KeyError: If no span is stored for ``run_id``.
    """
    span = self._get_span(run_id)
    span.set_attribute(
        SpanAttributes.TRACELOOP_ENTITY_OUTPUT,
        json.dumps({"outputs": outputs, "kwargs": kwargs}, cls=CustomJsonEncode),
    )
    span.end()
    # Drop the finished span so ``self.spans`` does not grow with every run.
    self.spans.pop(run_id, None)

def on_tool_start(
    self,
    serialized: dict[str, Any],
    input_str: str,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[list[str]] = None,
    metadata: Optional[dict[str, Any]] = None,
    inputs: Optional[dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Run when tool starts running.

    Starts a span for ``run_id`` and records the tool input on it as JSON.

    Args:
        serialized: Serialized description of the tool; forwarded to
            ``_create_span``.
        input_str: String input of the tool, recorded on the span.
        run_id: Identifier of this run; used as the span key.
        parent_run_id: Accepted but not used here.
        tags: Accepted but not used here.
        metadata: Accepted but not used here.
        inputs: Accepted but not used here.
        **kwargs: Any remaining keyword arguments; recorded on the span
            together with ``input_str``.
    """
    span = self._create_span(serialized, run_id)
    span.set_attribute(
        SpanAttributes.TRACELOOP_ENTITY_INPUT,
        json.dumps(
            {"input_str": input_str, "kwargs": kwargs}, cls=CustomJsonEncode,
        ),
    )

def on_tool_end(self, output: Any, **kwargs: Any) -> None:
def on_tool_end(
self,
output: Any,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run when tool ends running."""
handler_id = self._get_handler_id(kwargs["tags"])
span = self._get_span(handler_id)
span = self._get_span(run_id)
span.set_attribute(
SpanAttributes.TRACELOOP_ENTITY_OUTPUT,
json.dumps({"output": output, "kwargs": kwargs}, cls=CustomJsonEncode),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_sequential_chain(exporter):
output_variables=["synopsis", "review"],
verbose=True,
)
overall_chain(
overall_chain.invoke(
{"title": "Tragedy at sunset on the beach", "era": "Victorian England"}
)

Expand Down

0 comments on commit 4fd020a

Please sign in to comment.