From 660f0637db8bd6d568354bcc0f30a649c0e9cf9a Mon Sep 17 00:00:00 2001 From: Tibor Reiss Date: Tue, 28 May 2024 21:58:15 +0200 Subject: [PATCH] Add generator handling --- .../instrumentation/langchain/__init__.py | 3 +- .../instrumentation/langchain/utils.py | 7 +++ .../langchain/workflow_wrapper.py | 50 +++++++++++++++++-- .../sample_app/langchain_streaming.py | 34 +++++++++++++ 4 files changed, 90 insertions(+), 4 deletions(-) create mode 100644 packages/sample-app/sample_app/langchain_streaming.py diff --git a/packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/__init__.py b/packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/__init__.py index 755a4630e7..a0aead93c2 100644 --- a/packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/__init__.py +++ b/packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/__init__.py @@ -17,6 +17,7 @@ from opentelemetry.instrumentation.langchain.workflow_wrapper import ( workflow_wrapper, aworkflow_wrapper, + gworkflow_wrapper, agworkflow_wrapper, ) from opentelemetry.instrumentation.langchain.custom_llm_wrapper import ( @@ -147,7 +148,7 @@ "object": "RunnableSequence", "method": "stream", "span_name": "langchain.workflow", - "wrapper": workflow_wrapper, + "wrapper": gworkflow_wrapper, }, { "package": "langchain.schema.runnable", diff --git a/packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/utils.py b/packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/utils.py index d9797b6202..b405f2af9a 100644 --- a/packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/utils.py +++ b/packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/utils.py @@ -34,6 +34,13 @@ def wrapper(wrapped, instance, args, kwargs): return _with_tracer +def set_span_attribute(span, name, value): + if value is not None: + if value != "": + span.set_attribute(name, value) + return + + def should_send_prompts(): return ( os.getenv("TRACELOOP_TRACE_CONTENT") or "true" diff --git a/packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/workflow_wrapper.py b/packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/workflow_wrapper.py index 89f3e246e4..9f7826cf27 100644 --- a/packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/workflow_wrapper.py +++ b/packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/workflow_wrapper.py @@ -9,11 +9,26 @@ from opentelemetry.instrumentation.langchain.utils import ( _with_tracer_wrapper, + dont_throw, process_request, process_response, + set_span_attribute, + should_send_prompts, ) +@dont_throw +def build_from_streaming_response(span, return_value): + for idx, item in enumerate(return_value): + yield item + if should_send_prompts(): + set_span_attribute( + span, + f"langchain.completion.{idx}.chunk", + str(item), + ) + + @_with_tracer_wrapper def workflow_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs): """Instruments and calls every function defined in TO_WRAP.""" @@ -62,6 +77,28 @@ async def aworkflow_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs): return return_value +@_with_tracer_wrapper +def gworkflow_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs): + """Instruments and calls every function defined in TO_WRAP.""" + if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY): + yield from wrapped(*args, **kwargs) + else: + name, kind = _handle_request(instance, args, to_wrap) + + attach(set_value("workflow_name", name)) + + with tracer.start_as_current_span(name) as span: + span.set_attribute( + SpanAttributes.TRACELOOP_SPAN_KIND, + kind, + ) + span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, name) + + process_request(span, args, kwargs) + return_value = wrapped(*args, **kwargs) + yield from build_from_streaming_response(span, return_value) + + @_with_tracer_wrapper async def agworkflow_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs): """Instruments and calls every function defined in TO_WRAP.""" @@ -81,9 +118,16 @@ async def agworkflow_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs): span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, name) process_request(span, args, kwargs) - async for i in wrapped(*args, *kwargs): - span.add_event(name="langchain.content.completion.chunk") - yield i + idx = 0 + async for item in wrapped(*args, **kwargs): + yield item + if should_send_prompts(): + set_span_attribute( + span, + f"langchain.completion.{idx}.chunk", + str(item), + ) + idx += 1 def _handle_request(instance, args, to_wrap): diff --git a/packages/sample-app/sample_app/langchain_streaming.py b/packages/sample-app/sample_app/langchain_streaming.py new file mode 100644 index 0000000000..d7da157d6c --- /dev/null +++ b/packages/sample-app/sample_app/langchain_streaming.py @@ -0,0 +1,34 @@ +import asyncio +import os + +from langchain.prompts import ChatPromptTemplate +from langchain.schema import StrOutputParser +from langchain_cohere import ChatCohere +from opentelemetry.sdk.trace.export import ConsoleSpanExporter +from traceloop.sdk import Traceloop + + +def streaming(runnable, input_prompt): + return ''.join(list(runnable.stream(input_prompt))) + + +async def astreaming(runnable, input_prompt): + return ''.join([i async for i in runnable.astream(input_prompt)]) + + +if __name__ == "__main__": + Traceloop.init( + app_name="streaming_example", + exporter=None if os.getenv("TRACELOOP_API_KEY") else ConsoleSpanExporter(), + ) + + prompt = ChatPromptTemplate.from_messages( + [("system", "You are a helpful assistant"), ("user", "{input}")] + ) + chat = ChatCohere(model="command-r") + parser = StrOutputParser() + runnable = prompt | chat | parser + input_prompt = {"input": "tell me a short joke"} + + print(streaming(runnable, input_prompt)) + print(asyncio.run(astreaming(runnable, input_prompt)))