Skip to content

Commit

Permalink
Add more classes
Browse files Browse the repository at this point in the history
  • Loading branch information
tibor-reiss committed Jun 22, 2024
1 parent 3407b65 commit b22b0a1
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 185 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import unwrap

from opentelemetry.instrumentation.langchain.task_wrapper import (
task_wrapper,
atask_wrapper,
)
from opentelemetry.instrumentation.langchain.task_wrapper import task_wrapper
from opentelemetry.instrumentation.langchain.workflow_wrapper import (
workflow_wrapper,
aworkflow_wrapper,
Expand All @@ -22,10 +19,6 @@
llm_wrapper,
allm_wrapper,
)
from opentelemetry.instrumentation.langchain.custom_chat_wrapper import (
chat_wrapper,
achat_wrapper,
)
from opentelemetry.instrumentation.langchain.version import __version__

from opentelemetry.semconv.ai import TraceloopSpanKindValues
Expand All @@ -41,14 +34,32 @@
"package": "langchain.chains.base",
"class": "Chain",
"is_callback": True,
"kind": TraceloopSpanKindValues.TASK.value,
"kind": TraceloopSpanKindValues.WORKFLOW.value,
},
{
"package": "langchain.chains",
"class": "SequentialChain",
"package": "langchain.schema.runnable",
"class": "RunnableSequence",
"is_callback": True,
"kind": TraceloopSpanKindValues.WORKFLOW.value,
},
{
"package": "langchain.prompts.base",
"class": "BasePromptTemplate",
"is_callback": True,
"kind": TraceloopSpanKindValues.TASK.value,
},
{
"package": "langchain.chat_models.base",
"class": "BaseChatModel",
"is_callback": True,
"kind": TraceloopSpanKindValues.TASK.value,
},
{
"package": "langchain.schema",
"class": "BaseOutputParser",
"is_callback": True,
"kind": TraceloopSpanKindValues.TASK.value,
},
{
"package": "langchain.agents",
"object": "AgentExecutor",
Expand Down Expand Up @@ -79,56 +90,6 @@
"span_name": "retrieval_qa.workflow",
"wrapper": aworkflow_wrapper,
},
{
"package": "langchain.prompts.base",
"object": "BasePromptTemplate",
"method": "invoke",
"wrapper": task_wrapper,
},
{
"package": "langchain.prompts.base",
"object": "BasePromptTemplate",
"method": "ainvoke",
"wrapper": atask_wrapper,
},
{
"package": "langchain.chat_models.base",
"object": "BaseChatModel",
"method": "generate",
"wrapper": chat_wrapper,
},
{
"package": "langchain.chat_models.base",
"object": "BaseChatModel",
"method": "agenerate",
"wrapper": achat_wrapper,
},
{
"package": "langchain.schema",
"object": "BaseOutputParser",
"method": "invoke",
"wrapper": task_wrapper,
},
{
"package": "langchain.schema",
"object": "BaseOutputParser",
"method": "ainvoke",
"wrapper": atask_wrapper,
},
{
"package": "langchain.schema.runnable",
"object": "RunnableSequence",
"method": "invoke",
"span_name": "langchain.workflow",
"wrapper": workflow_wrapper,
},
{
"package": "langchain.schema.runnable",
"object": "RunnableSequence",
"method": "ainvoke",
"span_name": "langchain.workflow",
"wrapper": aworkflow_wrapper,
},
{
"package": "langchain_core.language_models.llms",
"object": "LLM",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import json
from dataclasses import dataclass
from typing import Any, Optional
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler, BaseCallbackManager
from langchain_core.messages import BaseMessage
from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.semconv.ai import SpanAttributes
from opentelemetry.semconv.ai import LLMRequestTypeValues, SpanAttributes
from opentelemetry.context.context import Context
from opentelemetry.trace import set_span_in_context, Tracer
from opentelemetry.trace.span import Span

Expand All @@ -23,10 +26,37 @@ def default(self, o: Any) -> str:
return str(o)


@dataclass
class Instance:
    """Identity of a wrapped langchain object plus its traceloop span kind."""

    kind: str
    class_name: str
    run_name: Optional[str]

    @classmethod
    def from_args(cls, kind: str, instance_name: str, args: Any) -> "Instance":
        """Build an Instance, pulling an optional ``run_name`` out of args[1]."""
        run_name = None
        # Callers may pass a config mapping as the second positional argument.
        if len(args) > 1 and "run_name" in args[1]:
            run_name = args[1]["run_name"]
        return cls(kind, instance_name, run_name)

    def get_kind(self) -> str:
        """Return the traceloop span kind (e.g. "task" or "workflow")."""
        return self.kind

    def get_name(self) -> str:
        """Span name: prefer the explicit run_name, else the class name."""
        base = self.class_name if self.run_name is None else self.run_name
        return f"{base}.langchain.{self.kind}"


@dataclass
class SpanHolder:
    """Per-run bookkeeping: a span plus the context it was attached under."""

    # The live span created for this run_id.
    span: Span
    # Token returned by context_api.attach(); presumably kept so the context
    # can be detached later — TODO confirm where detach happens.
    token: Any
    # Context with this span set as current; child runs start under it.
    context: Context
    # run_ids of child runs started while this span was active.
    children: list[UUID]


@dont_throw
def _add_callback(tracer, to_wrap, instance, args):
kind = to_wrap.get("kind")
class_name = instance.get_name()
cb = SyncSpanCallbackHandler(tracer)
if len(args) > 1:
if "callbacks" in args[1]:
Expand All @@ -47,7 +77,9 @@ def _add_callback(tracer, to_wrap, instance, args):
args[1]["callbacks"].append(cb)
else:
args[1].update({"callbacks": [cb, ]})
cb.add_kind(class_name, kind)
cb.add_instance_params(
Instance.from_args(to_wrap.get("kind"), instance.get_name(), args)
)
return cb


Expand All @@ -66,29 +98,54 @@ def callback_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
return wrapped(*args, {"callbacks": [cb, ]}, **kwargs)


def _set_chat_request(span: Span, serialized: Any) -> None:
    """Mark *span* as a chat request and record the model name, if discoverable.

    The model is looked up in ``serialized["kwargs"]`` under a few common tags;
    when no tag matches it is reported as "unknown".
    """
    span.set_attribute(SpanAttributes.LLM_REQUEST_TYPE, LLMRequestTypeValues.CHAT.value)
    try:
        request_kwargs = serialized["kwargs"]
    except KeyError:
        # Best effort: without the serialized kwargs we cannot name the model.
        return
    model = "unknown"
    for model_tag in ("model", "model_id", "model_name"):
        candidate = request_kwargs.get(model_tag)
        if candidate is not None:
            model = candidate
            break
    span.set_attribute(SpanAttributes.LLM_REQUEST_MODEL, model)


class SyncSpanCallbackHandler(BaseCallbackHandler):
def __init__(self, tracer: Tracer) -> None:
self.tracer = tracer
self.kinds = {}
self.spans = {}
self.instances = {}
self.spans: dict[UUID, SpanHolder] = {}

def add_kind(self, class_name: str, kind: str) -> None:
if class_name not in self.kinds:
self.kinds[class_name] = (f"{class_name}.langchain.{kind}", kind)
def add_instance_params(self, instance: Instance) -> None:
self.instances[instance.class_name] = instance

def _get_span(self, run_id: UUID):
return self.spans[run_id]
def _get_span(self, run_id: UUID) -> Span:
return self.spans[run_id].span

def _create_span(self, serialized: dict[str, Any], run_id: UUID) -> Span:
name, kind = self.kinds[serialized["name"]]
span = self.tracer.start_span(name)
span.set_attribute(SpanAttributes.TRACELOOP_SPAN_KIND, kind)
span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, name)
self.spans[run_id] = span
def _end_span(self, span: Span, run_id: UUID) -> None:
for child_id in self.spans[run_id].children:
child_span = self.spans[child_id].span
if child_span.end_time is None: # avoid warning on ended spans
child_span.end()
span.end()

def _create_span(self, run_id: UUID, parent_run_id: Optional[UUID], class_name: str) -> Span:
instance = self.instances[class_name]
kind = instance.get_kind()
name = instance.get_name()
if parent_run_id is not None:
span = self.tracer.start_span(name, context=self.spans[parent_run_id].context)
else:
span = self.tracer.start_span(name)
span.set_attribute(SpanAttributes.TRACELOOP_SPAN_KIND, kind)
span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, f"{name}.langchain.{kind}")
current_context = set_span_in_context(span)
context_api.attach(current_context)

token = context_api.attach(current_context)
self.spans[run_id] = SpanHolder(span, token, current_context, [])
if parent_run_id is not None:
self.spans[parent_run_id].children.append(run_id)
return span

def on_chain_start(
Expand All @@ -103,10 +160,19 @@ def on_chain_start(
**kwargs: Any,
) -> None:
"""Run when chain starts running."""
span = self._create_span(serialized, run_id)
class_name = serialized["id"][-1]
span = self._create_span(run_id, parent_run_id, class_name)
span.set_attribute(
SpanAttributes.TRACELOOP_ENTITY_INPUT,
json.dumps({"inputs": inputs, "kwargs": kwargs}, cls=CustomJsonEncode),
json.dumps(
{
"inputs": inputs,
"tags": tags,
"metadata": metadata,
"kwargs": kwargs,
},
cls=CustomJsonEncode,
),
)

def on_chain_end(
Expand All @@ -123,7 +189,36 @@ def on_chain_end(
SpanAttributes.TRACELOOP_ENTITY_OUTPUT,
json.dumps({"outputs": outputs, "kwargs": kwargs}, cls=CustomJsonEncode),
)
span.end()
self._end_span(span, run_id)

def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
*,
run_id: UUID,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run when Chat Model starts running."""
class_name = serialized["id"][-1]
span = self._create_span(run_id, parent_run_id, class_name)
span.set_attribute(
SpanAttributes.TRACELOOP_ENTITY_INPUT,
json.dumps(
{
"messages": messages,
"metadata": metadata,
"name": name,
"kwargs": kwargs,
},
cls=CustomJsonEncode,
),
)
_set_chat_request(span, serialized)

def on_tool_start(
self,
Expand All @@ -138,11 +233,19 @@ def on_tool_start(
**kwargs: Any,
) -> None:
"""Run when tool starts running."""
span = self._create_span(serialized, run_id)
class_name = serialized["id"][-1]
span = self._create_span(run_id, parent_run_id, class_name)
span.set_attribute(
SpanAttributes.TRACELOOP_ENTITY_INPUT,
json.dumps(
{"input_str": input_str, "kwargs": kwargs}, cls=CustomJsonEncode,
{
"input_str": input_str,
"tags": tags,
"metadata": metadata,
"inputs": inputs,
"kwargs": kwargs,
},
cls=CustomJsonEncode,
),
)

Expand All @@ -160,4 +263,4 @@ def on_tool_end(
SpanAttributes.TRACELOOP_ENTITY_OUTPUT,
json.dumps({"output": output, "kwargs": kwargs}, cls=CustomJsonEncode),
)
span.end()
self._end_span(span, run_id)
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ def test_sequential_chain(exporter):

assert [
"openai.completion",
"LLMChain.langchain.task",
"LLMChain.langchain.workflow",
"openai.completion",
"LLMChain.langchain.task",
"LLMChain.langchain.workflow",
"SequentialChain.langchain.workflow",
] == [span.name for span in spans]

synopsis_span, review_span = [span for span in spans if span.name == "LLMChain.langchain.task"]
synopsis_span, review_span = [span for span in spans if span.name == "LLMChain.langchain.workflow"]

data = json.loads(synopsis_span.attributes[SpanAttributes.TRACELOOP_ENTITY_INPUT])
assert data["inputs"] == {'title': 'Tragedy at sunset on the beach', 'era': 'Victorian England'}
Expand Down Expand Up @@ -112,13 +112,13 @@ async def test_asequential_chain(exporter):

assert [
"openai.completion",
"LLMChain.langchain.task",
"LLMChain.langchain.workflow",
"openai.completion",
"LLMChain.langchain.task",
"LLMChain.langchain.workflow",
"SequentialChain.langchain.workflow",
] == [span.name for span in spans]

synopsis_span, review_span = [span for span in spans if span.name == "LLMChain.langchain.task"]
synopsis_span, review_span = [span for span in spans if span.name == "LLMChain.langchain.workflow"]

data = json.loads(synopsis_span.attributes[SpanAttributes.TRACELOOP_ENTITY_INPUT])
assert data["inputs"] == {'title': 'Tragedy at sunset on the beach', 'era': 'Victorian England'}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,11 @@ def test_sequential_chain(exporter):
spans = exporter.get_finished_spans()

assert [
"ChatCohere.langchain.task",
"LLMChain.langchain.task",
"StuffDocumentsChain.langchain.task",
"LLMChain.langchain.workflow",
"StuffDocumentsChain.langchain.workflow",
] == [span.name for span in spans]

stuff_span = next(span for span in spans if span.name == "StuffDocumentsChain.langchain.task")
stuff_span = next(span for span in spans if span.name == "StuffDocumentsChain.langchain.workflow")

data = json.loads(stuff_span.attributes[SpanAttributes.TRACELOOP_ENTITY_INPUT])
assert data["inputs"].keys() == {"input_documents"}
Expand Down
Loading

0 comments on commit b22b0a1

Please sign in to comment.