Skip to content

Commit

Permalink
Add generator handling
Browse files Browse the repository at this point in the history
  • Loading branch information
tibor-reiss committed May 28, 2024
1 parent 3638ab6 commit 660f063
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from opentelemetry.instrumentation.langchain.workflow_wrapper import (
workflow_wrapper,
aworkflow_wrapper,
gworkflow_wrapper,
agworkflow_wrapper,
)
from opentelemetry.instrumentation.langchain.custom_llm_wrapper import (
Expand Down Expand Up @@ -147,7 +148,7 @@
"object": "RunnableSequence",
"method": "stream",
"span_name": "langchain.workflow",
"wrapper": workflow_wrapper,
"wrapper": gworkflow_wrapper,
},
{
"package": "langchain.schema.runnable",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def wrapper(wrapped, instance, args, kwargs):
return _with_tracer


def set_span_attribute(span, name, value):
if value is not None:
if value != "":
span.set_attribute(name, value)
return


def should_send_prompts():
return (
os.getenv("TRACELOOP_TRACE_CONTENT") or "true"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,26 @@

from opentelemetry.instrumentation.langchain.utils import (
_with_tracer_wrapper,
dont_throw,
process_request,
process_response,
set_span_attribute,
should_send_prompts,
)


@dont_throw
def build_from_streaming_response(span, return_value):
    """Re-yield every chunk from the wrapped stream, recording each one on *span*.

    Each chunk is yielded to the caller first, then (when prompt capture is
    enabled) stored as a ``langchain.completion.<idx>.chunk`` span attribute.
    """
    position = 0
    for chunk in return_value:
        yield chunk
        # Capture decision is re-checked per chunk, matching runtime config.
        if should_send_prompts():
            set_span_attribute(
                span,
                f"langchain.completion.{position}.chunk",
                str(chunk),
            )
        position += 1


@_with_tracer_wrapper
def workflow_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
"""Instruments and calls every function defined in TO_WRAP."""
Expand Down Expand Up @@ -62,6 +77,28 @@ async def aworkflow_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
return return_value


@_with_tracer_wrapper
def gworkflow_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
    """Instruments and calls every function defined in TO_WRAP.

    Sync-generator variant: opens a workflow span, delegates to the wrapped
    callable, and streams its chunks back through
    ``build_from_streaming_response`` so each chunk is recorded on the span.
    """
    # When instrumentation is suppressed, pass the stream through untouched.
    if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
        yield from wrapped(*args, **kwargs)
        return

    name, kind = _handle_request(instance, args, to_wrap)
    attach(set_value("workflow_name", name))

    with tracer.start_as_current_span(name) as span:
        span.set_attribute(SpanAttributes.TRACELOOP_SPAN_KIND, kind)
        span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, name)

        process_request(span, args, kwargs)
        # Span stays open for the lifetime of the consumed generator.
        yield from build_from_streaming_response(span, wrapped(*args, **kwargs))


@_with_tracer_wrapper
async def agworkflow_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
"""Instruments and calls every function defined in TO_WRAP."""
Expand All @@ -81,9 +118,16 @@ async def agworkflow_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, name)

process_request(span, args, kwargs)
async for i in wrapped(*args, *kwargs):
span.add_event(name="langchain.content.completion.chunk")
yield i
idx = 0
async for item in wrapped(*args, **kwargs):
yield item
if should_send_prompts():
set_span_attribute(
span,
f"langchain.completion.{idx}.chunk",
str(item),
)
idx += 1


def _handle_request(instance, args, to_wrap):
Expand Down
34 changes: 34 additions & 0 deletions packages/sample-app/sample_app/langchain_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import asyncio
import os

from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain_cohere import ChatCohere
from opentelemetry.sdk.trace.export import ConsoleSpanExporter
from traceloop.sdk import Traceloop


def streaming(runnable, input_prompt):
return ''.join(list(runnable.stream(input_prompt)))


async def astreaming(runnable, input_prompt):
return ''.join([i async for i in runnable.astream(input_prompt)])


if __name__ == "__main__":
Traceloop.init(
app_name="streaming_example",
exporter=None if os.getenv("TRACELOOP_API_KEY") else ConsoleSpanExporter(),
)

prompt = ChatPromptTemplate.from_messages(
[("system", "You are a helpful assistant"), ("user", "{input}")]
)
chat = ChatCohere(model="command-r")
parser = StrOutputParser()
runnable = prompt | chat | parser
input_prompt = {"input": "tell me a short joke"}

print(streaming(runnable, input_prompt))
print(asyncio.run(astreaming(runnable, input_prompt)))

0 comments on commit 660f063

Please sign in to comment.