feat(langchain): add support for streamed calls #10672

Open · wants to merge 12 commits into main
186 changes: 186 additions & 0 deletions ddtrace/contrib/internal/langchain/patch.py
@@ -50,6 +50,7 @@
from ddtrace.contrib.internal.langchain.constants import agent_output_parser_classes
from ddtrace.contrib.internal.langchain.constants import text_embedding_models
from ddtrace.contrib.internal.langchain.constants import vectorstore_classes
from ddtrace.contrib.internal.langchain.utils import shared_stream
from ddtrace.contrib.trace_utils import unwrap
from ddtrace.contrib.trace_utils import with_traced_module
from ddtrace.contrib.trace_utils import wrap
@@ -975,6 +976,165 @@ def traced_similarity_search(langchain, pin, func, instance, args, kwargs):
return documents


# TODO refactor some of these on_span_started/on_span_finished functions
# that are used in other patched methods in this file into the utils module
@with_traced_module
def traced_chain_stream(langchain, pin, func, instance, args, kwargs):
integration: LangChainIntegration = langchain._datadog_integration

def _on_span_started(span: Span):
inputs = get_argument_value(args, kwargs, 0, "input")
if integration.is_pc_sampled_span(span):
if not isinstance(inputs, list):
inputs = [inputs]
for idx, inp in enumerate(inputs):
if not isinstance(inp, dict):
span.set_tag_str("langchain.request.inputs.%d" % idx, integration.trunc(str(inp)))
else:
for k, v in inp.items():
span.set_tag_str("langchain.request.inputs.%d.%s" % (idx, k), integration.trunc(str(v)))

def _on_span_finished(span: Span, streamed_chunks, error: Optional[bool] = None):
if not error and integration.is_pc_sampled_span(span):
if langchain_core and isinstance(instance.steps[-1], langchain_core.output_parsers.JsonOutputParser):
# it's possible that the chain has a json output parser
# this will have already concatenated the chunks into a json object

# it's also possible the json output parser isn't the last step,
# but one of the last steps, in which case we won't act on it here
# TODO (sam.brenner) make this more robust
content = json.dumps(streamed_chunks[-1])
else:
# best effort to join chunks together
content = "".join([str(chunk) for chunk in streamed_chunks])
span.set_tag_str("langchain.response.content", integration.trunc(content))

return shared_stream(
integration=integration,
pin=pin,
func=func,
instance=instance,
args=args,
kwargs=kwargs,
interface_type="chain",
on_span_started=_on_span_started,
on_span_finished=_on_span_finished,
)


@with_traced_module
def traced_chat_stream(langchain, pin, func, instance, args, kwargs):
integration: LangChainIntegration = langchain._datadog_integration
llm_provider = instance._llm_type

def _on_span_started(span: Span):
chat_messages = get_argument_value(args, kwargs, 0, "input")
if not isinstance(chat_messages, list):
chat_messages = [chat_messages]
for message_idx, message in enumerate(chat_messages):
@sabrenner (Contributor Author) commented on Sep 16, 2024:
I added all of this logic because, unlike model.generate, which just takes in a list of BaseMessage types (i.e., HumanMessage, SystemMessage, etc.), model.stream can take in:

  1. a single string
  2. a single dict
  3. a list of strings
  4. a list of dicts
  5. a list of BaseMessage types
  6. a PromptValue type, which has a messages property of BaseMessage types

ref

Do we care about all of these different types? Or do we just want to listify and str each element (this logic would also carry over to LLMObs spans in a future PR)? That would make the code a lot simpler, though the resulting tags might not read as nicely. A sketch of these input shapes follows below.
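
For reference, a minimal sketch of those input shapes; the model instance and prompt text are hypothetical placeholders, not part of this PR:

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI  # assumes langchain-openai is installed

model = ChatOpenAI()  # hypothetical model; any BaseChatModel streams the same way

model.stream("tell me a joke")                                 # 1. a single string
model.stream({"role": "user", "content": "tell me a joke"})    # 2. a single dict
model.stream(["tell me a joke", "and another"])                # 3. a list of strings
model.stream([{"role": "user", "content": "tell me a joke"}])  # 4. a list of dicts
model.stream([SystemMessage(content="be brief"),               # 5. a list of BaseMessage types
              HumanMessage(content="tell me a joke")])
prompt_value = ChatPromptTemplate.from_messages([("user", "{q}")]).invoke({"q": "tell me a joke"})
model.stream(prompt_value)                                     # 6. a PromptValue, whose .messages holds BaseMessage types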

if integration.is_pc_sampled_span(span):
if isinstance(message, dict):
span.set_tag_str(
"langchain.request.messages.%d.content" % (message_idx),
integration.trunc(str(message.get("content", ""))),
)
span.set_tag_str(
"langchain.request.messages.%d.role" % (message_idx),
str(message.get("role", "")),
)
elif isinstance(message, langchain_core.prompt_values.PromptValue):
for langchain_message_idx, langchain_message in enumerate(message.messages):
span.set_tag_str(
"langchain.request.messages.%d.%d.content" % (message_idx, langchain_message_idx),
integration.trunc(str(langchain_message.content)),
)
span.set_tag_str(
"langchain.request.messages.%d.%d.role" % (message_idx, langchain_message_idx),
str(langchain_message.__class__.__name__),
)
elif isinstance(message, langchain_core.messages.BaseMessage):
span.set_tag_str(
"langchain.request.messages.%d.content" % (message_idx), integration.trunc(str(message.content))
)
span.set_tag_str(
"langchain.request.messages.%d.role" % (message_idx), str(message.__class__.__name__)
)
else:
span.set_tag_str(
"langchain.request.messages.%d.content" % (message_idx), integration.trunc(message)
)

for param, val in getattr(instance, "_identifying_params", {}).items():
if isinstance(val, dict):
for k, v in val.items():
span.set_tag_str("langchain.request.%s.parameters.%s.%s" % (llm_provider, param, k), str(v))
else:
span.set_tag_str("langchain.request.%s.parameters.%s" % (llm_provider, param), str(val))

def _on_span_finished(span: Span, streamed_chunks, error: Optional[bool] = None):
if not error and integration.is_pc_sampled_span(span):
content = "".join([str(chunk.content) for chunk in streamed_chunks])
span.set_tag_str("langchain.response.content", integration.trunc(content))

usage = getattr(streamed_chunks[-1], "usage_metadata", None)
if usage:
for k, v in usage.items():
span.set_tag_str("langchain.response.usage_metadata.%s" % k, str(v))

return shared_stream(
integration=integration,
pin=pin,
func=func,
instance=instance,
args=args,
kwargs=kwargs,
interface_type="chat_model",
on_span_started=_on_span_started,
on_span_finished=_on_span_finished,
api_key=_extract_api_key(instance),
provider=llm_provider,
)


@with_traced_module
def traced_llm_stream(langchain, pin, func, instance, args, kwargs):
integration: LangChainIntegration = langchain._datadog_integration
llm_provider = instance._llm_type

def _on_span_start(span: Span):
inp = get_argument_value(args, kwargs, 0, "input")
if not isinstance(inp, list):
inp = [inp]
if integration.is_pc_sampled_span(span):
for idx, prompt in enumerate(inp):
span.set_tag_str("langchain.request.prompts.%d" % idx, integration.trunc(str(prompt)))
for param, val in getattr(instance, "_identifying_params", {}).items():
if isinstance(val, dict):
for k, v in val.items():
span.set_tag_str("langchain.request.%s.parameters.%s.%s" % (llm_provider, param, k), str(v))
else:
span.set_tag_str("langchain.request.%s.parameters.%s" % (llm_provider, param), str(val))

def _on_span_finished(span: Span, streamed_chunks, error: Optional[bool] = None):
if not error and integration.is_pc_sampled_span(span):
content = "".join([str(chunk) for chunk in streamed_chunks])
span.set_tag_str("langchain.response.content", integration.trunc(content))

return shared_stream(
integration=integration,
pin=pin,
func=func,
instance=instance,
args=args,
kwargs=kwargs,
interface_type="llm",
on_span_started=_on_span_start,
on_span_finished=_on_span_finished,
api_key=_extract_api_key(instance),
provider=llm_provider,
)


@with_traced_module
def traced_base_tool_invoke(langchain, pin, func, instance, args, kwargs):
integration = langchain._datadog_integration
@@ -1203,6 +1363,9 @@ def patch():
wrap("langchain", "embeddings.OpenAIEmbeddings.embed_documents", traced_embedding(langchain))
else:
from langchain.chains.base import Chain # noqa:F401
from langchain_core import messages # noqa: F401
from langchain_core import output_parsers # noqa: F401
from langchain_core import prompt_values # noqa: F401
from langchain_core.tools import BaseTool # noqa:F401

wrap("langchain_core", "language_models.llms.BaseLLM.generate", traced_llm_generate(langchain))
@@ -1225,6 +1388,23 @@
)
wrap("langchain_core", "runnables.base.RunnableSequence.batch", traced_lcel_runnable_sequence(langchain))
wrap("langchain_core", "runnables.base.RunnableSequence.abatch", traced_lcel_runnable_sequence_async(langchain))

# streaming
wrap("langchain_core", "runnables.base.RunnableSequence.stream", traced_chain_stream(langchain))
wrap("langchain_core", "runnables.base.RunnableSequence.astream", traced_chain_stream(langchain))
wrap(
"langchain_core",
"language_models.chat_models.BaseChatModel.stream",
traced_chat_stream(langchain),
)
wrap(
"langchain_core",
"language_models.chat_models.BaseChatModel.astream",
traced_chat_stream(langchain),
)
wrap("langchain_core", "language_models.llms.BaseLLM.stream", traced_llm_stream(langchain))
wrap("langchain_core", "language_models.llms.BaseLLM.astream", traced_llm_stream(langchain))

wrap("langchain_core", "tools.BaseTool.invoke", traced_base_tool_invoke(langchain))
wrap("langchain_core", "tools.BaseTool.ainvoke", traced_base_tool_ainvoke(langchain))
if langchain_openai:
@@ -1275,6 +1455,12 @@ def unpatch():
unwrap(langchain_core.runnables.base.RunnableSequence, "ainvoke")
unwrap(langchain_core.runnables.base.RunnableSequence, "batch")
unwrap(langchain_core.runnables.base.RunnableSequence, "abatch")
unwrap(langchain_core.runnables.base.RunnableSequence, "stream")
unwrap(langchain_core.runnables.base.RunnableSequence, "astream")
unwrap(langchain_core.language_models.chat_models.BaseChatModel, "stream")
unwrap(langchain_core.language_models.chat_models.BaseChatModel, "astream")
unwrap(langchain_core.language_models.llms.BaseLLM, "stream")
unwrap(langchain_core.language_models.llms.BaseLLM, "astream")
unwrap(langchain_core.tools.BaseTool, "invoke")
unwrap(langchain_core.tools.BaseTool, "ainvoke")
if langchain_openai:
79 changes: 79 additions & 0 deletions ddtrace/contrib/internal/langchain/utils.py
@@ -0,0 +1,79 @@
import inspect
import sys


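# Base wrapper that accumulates streamed chunks so the traced span can be
# finished with the full response once the stream completes or errors.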
class BaseTracedLangChainStreamResponse:
def __init__(self, generator, integration, span, on_span_finish):
self._generator = generator
self._dd_integration = integration
self._dd_span = span
self._on_span_finish = on_span_finish
self._chunks = []


class TracedLangchainStreamResponse(BaseTracedLangChainStreamResponse):
def __iter__(self):
try:
for chunk in self._generator.__iter__():
self._chunks.append(chunk)
yield chunk
except Exception:
self._dd_span.set_exc_info(*sys.exc_info())
self._dd_integration.metric(self._dd_span, "incr", "request.error", 1)
raise
finally:
self._on_span_finish(self._dd_span, self._chunks, error=bool(self._dd_span.error))
self._dd_span.finish()


class TracedLangchainAsyncStreamResponse(BaseTracedLangChainStreamResponse):
async def __aiter__(self):
try:
async for chunk in self._generator.__aiter__():
self._chunks.append(chunk)
yield chunk
except Exception:
self._dd_span.set_exc_info(*sys.exc_info())
self._dd_integration.metric(self._dd_span, "incr", "request.error", 1)
raise
finally:
self._on_span_finish(self._dd_span, self._chunks, error=bool(self._dd_span.error))
self._dd_span.finish()


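# Shared helper for all traced stream patches: starts the span, tags the
# request as streamed, and wraps the returned generator so the span finishes
# when the stream is exhausted.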
def shared_stream(
integration,
pin,
func,
instance,
args,
kwargs,
interface_type,
on_span_started,
on_span_finished,
**extra_options,
):
options = {
"pin": pin,
"operation_id": f"{instance.__module__}.{instance.__class__.__name__}",
"interface_type": interface_type,
}

options.update(extra_options)

span = integration.trace(**options)
span.set_tag("langchain.request.stream", True)
on_span_started(span)

try:
resp = func(*args, **kwargs)
cls = TracedLangchainAsyncStreamResponse if inspect.isasyncgen(resp) else TracedLangchainStreamResponse

return cls(resp, integration, span, on_span_finished)
except Exception:
# error with the method call itself
span.set_exc_info(*sys.exc_info())
span.finish()
integration.metric(span, "incr", "request.error", 1)
integration.metric(span, "dist", "request.duration", span.duration_ns)
raise
@@ -0,0 +1,4 @@
---
features:
- |
langchain: Adds support for tracing ``stream`` calls on LCEL chains, chat completion models, or completion models.
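
For illustration, a minimal usage sketch of a call this instrumentation now traces; the model, prompt, and patching setup are assumptions for the example, not prescribed by the PR:

from ddtrace import patch
patch(langchain=True)  # enable the langchain integration

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI  # assumes langchain-openai and OpenAI credentials

# Each stream() call on this LCEL chain is traced as a streamed request.
chain = ChatPromptTemplate.from_messages([("user", "{topic}")]) | ChatOpenAI() | StrOutputParser()
for chunk in chain.stream({"topic": "streaming in LangChain"}):
    print(chunk, end="")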
@@ -0,0 +1,18 @@
data: {"id":"chatcmpl-A75i9IDlLLFDI6n75COTQV8IsQ91c","object":"chat.completion.chunk","created":1726253613,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A75i9IDlLLFDI6n75COTQV8IsQ91c","object":"chat.completion.chunk","created":1726253613,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"Python"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A75i9IDlLLFDI6n75COTQV8IsQ91c","object":"chat.completion.chunk","created":1726253613,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" is"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A75i9IDlLLFDI6n75COTQV8IsQ91c","object":"chat.completion.chunk","created":1726253613,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"\n\n"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A75i9IDlLLFDI6n75COTQV8IsQ91c","object":"chat.completion.chunk","created":1726253613,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"the"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A75i9IDlLLFDI6n75COTQV8IsQ91c","object":"chat.completion.chunk","created":1726253613,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" be"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A75i9IDlLLFDI6n75COTQV8IsQ91c","object":"chat.completion.chunk","created":1726253613,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"st!"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A75i9IDlLLFDI6n75COTQV8IsQ91c","object":"chat.completion.chunk","created":1726253613,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"length"}]}

data: [DONE]

@@ -0,0 +1,36 @@
data: {"id":"chatcmpl-A87AG6I7sUCYtwCx9oIF0UK8vJLkl","object":"chat.completion.chunk","created":1726497528,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_992d1ea92d","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A87AG6I7sUCYtwCx9oIF0UK8vJLkl","object":"chat.completion.chunk","created":1726497528,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_992d1ea92d","choices":[{"index":0,"delta":{"content":"Here"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A87AG6I7sUCYtwCx9oIF0UK8vJLkl","object":"chat.completion.chunk","created":1726497528,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_992d1ea92d","choices":[{"index":0,"delta":{"content":":\n\n"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A87AG6I7sUCYtwCx9oIF0UK8vJLkl","object":"chat.completion.chunk","created":1726497528,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_992d1ea92d","choices":[{"index":0,"delta":{"content":"```"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A87AG6I7sUCYtwCx9oIF0UK8vJLkl","object":"chat.completion.chunk","created":1726497528,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_992d1ea92d","choices":[{"index":0,"delta":{"content":"json"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A87AG6I7sUCYtwCx9oIF0UK8vJLkl","object":"chat.completion.chunk","created":1726497528,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_992d1ea92d","choices":[{"index":0,"delta":{"content":"\n"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A87AG6I7sUCYtwCx9oIF0UK8vJLkl","object":"chat.completion.chunk","created":1726497528,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_992d1ea92d","choices":[{"index":0,"delta":{"content":"{\n"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A87AG6I7sUCYtwCx9oIF0UK8vJLkl","object":"chat.completion.chunk","created":1726497528,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_992d1ea92d","choices":[{"index":0,"delta":{"content":" "},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A87AG6I7sUCYtwCx9oIF0UK8vJLkl","object":"chat.completion.chunk","created":1726497528,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_992d1ea92d","choices":[{"index":0,"delta":{"content":" \""},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A87AG6I7sUCYtwCx9oIF0UK8vJLkl","object":"chat.completion.chunk","created":1726497528,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_992d1ea92d","choices":[{"index":0,"delta":{"content":"countries"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A87AG6I7sUCYtwCx9oIF0UK8vJLkl","object":"chat.completion.chunk","created":1726497528,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_992d1ea92d","choices":[{"index":0,"delta":{"content":"\":"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A87AG6I7sUCYtwCx9oIF0UK8vJLkl","object":"chat.completion.chunk","created":1726497528,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_992d1ea92d","choices":[{"index":0,"delta":{"content":"\"France"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A87AG6I7sUCYtwCx9oIF0UK8vJLkl","object":"chat.completion.chunk","created":1726497528,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_992d1ea92d","choices":[{"index":0,"delta":{"content":" is"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A87AG6I7sUCYtwCx9oIF0UK8vJLkl","object":"chat.completion.chunk","created":1726497528,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_992d1ea92d","choices":[{"index":0,"delta":{"content":" a"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A87AG6I7sUCYtwCx9oIF0UK8vJLkl","object":"chat.completion.chunk","created":1726497528,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_992d1ea92d","choices":[{"index":0,"delta":{"content":" country!\""},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A87AG6I7sUCYtwCx9oIF0UK8vJLkl","object":"chat.completion.chunk","created":1726497528,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_992d1ea92d","choices":[{"index":0,"delta":{"content":"\n}"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-A87AG6I7sUCYtwCx9oIF0UK8vJLkl","object":"chat.completion.chunk","created":1726497528,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_992d1ea92d","choices":[{"index":0,"delta":{"content":"\n```"},"logprobs":null,"finish_reason":null}]}

data: [DONE]
