
fix(llmobs): avoid raising errors during llmobs integration span processing (#10713)

This PR does 2 things:


### User facing changes
- Catches any error raised by an integration-specific `_llmobs_set_tags()` method
and logs it instead of potentially crashing the user application.

### Non-user facing changes
Refactors the `BaseLLMIntegration` class and its child classes to share a
single `llmobs_set_tags()` method, which internally wraps an abstract
`_llmobs_set_tags()` method (implemented by each integration) in a
try/except. We also no longer need to check
`integration.is_pc_sampled_llmobs(span)` at each call site, since we don't
currently do any sampling; if we ever do, it can be handled inside
`llmobs_set_tags()`.
tl;dr: `_llmobs_set_tags()` is now an abstract method that must be
implemented by every LLM integration, and its signature now takes the
following arguments/keyword arguments (same as `llmobs_set_tags()`; see the
sketch after this list):
- span: span to annotate
- args: list of args passed to the traced method
- kwargs: dict of keyword args passed to the traced method. If an
integration needs additional data not contained in args/kwargs
(such as the model instance in Gemini or the tool_input dictionary in
langchain), we pass it in through this kwargs dict.
- response: returned response from llm provider (streamed or
non-streamed)
- operation: string denoting which LLM operation it is (e.g.
"completion", "chat", "embedding", "chain", "retrieval")

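For reviewers, the shared flow looks roughly like the following minimal sketch (the method names and signature come from this PR's description; the exact guard checks and log message are illustrative, not the real implementation):

```python
import abc
import logging

log = logging.getLogger(__name__)


class BaseLLMIntegration(abc.ABC):
    def llmobs_set_tags(self, span, args, kwargs, response=None, operation=""):
        # Shared entrypoint: delegate to the integration-specific hook, but
        # never let a tagging error propagate into the user's application.
        try:
            self._llmobs_set_tags(span, args, kwargs, response, operation)
        except Exception:
            log.error("Error extracting LLMObs fields for span %s", span, exc_info=True)

    @abc.abstractmethod
    def _llmobs_set_tags(self, span, args, kwargs, response=None, operation=""):
        # Implemented by each LLM integration (OpenAI, Anthropic, LangChain, ...).
        raise NotImplementedError
```
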
I refactored each integration to follow this new signature, which included
merging the logic for handling streamed responses and passing additional
required data (i.e. the model instance and tool inputs) through kwargs.

Previously each integration did its own thing for `llmobs_set_tags()`
with arbitrary args/kwargs, which made it difficult to maintain. Now that
we have a strict function signature, future integrations should be
simpler to create and existing integrations easier to maintain, as the
example below illustrates.
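
Here is a hypothetical sketch of what a new integration has to provide under the strict signature. Only the `_llmobs_set_tags()` contract comes from this PR; the provider name, tag keys, and response shape are made up for illustration, and the import path is assumed from the repo layout:

```python
from ddtrace.llmobs._integrations.base import BaseLLMIntegration  # path assumed


class MyProviderIntegration(BaseLLMIntegration):
    _integration_name = "myprovider"  # hypothetical provider

    def _llmobs_set_tags(self, span, args, kwargs, response=None, operation=""):
        # Inputs come from the traced call's args/kwargs; outputs come from the
        # provider response. Any extra data (model instance, tool inputs) would
        # arrive through the kwargs dict, as described above.
        prompt = kwargs.get("prompt", "")
        output = getattr(response, "text", "") if response is not None else ""
        span.set_tag_str("myprovider.request.prompt", str(prompt))  # illustrative tag keys
        span.set_tag_str("myprovider.response.text", str(output))
```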

## Checklist
- [x] PR author has checked that all the criteria below are met
- The PR description includes an overview of the change
- The PR description articulates the motivation for the change
- The change includes tests OR the PR description describes a testing
strategy
- The PR description notes risks associated with the change, if any
- Newly-added code is easy to change
- The change follows the [library release note
guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html)
- The change includes or references documentation updates if necessary
- Backport labels are set (if
[applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting))

## Reviewer Checklist
- [x] Reviewer has checked that all the criteria below are met 
- Title is accurate
- All changes are related to the pull request's stated goal
- Avoids breaking
[API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces)
changes
- Testing strategy adequately addresses listed risks
- Newly-added code is easy to change
- Release note makes sense to a user of the library
- If necessary, author has acknowledged and discussed the performance
implications of this PR as reported in the benchmarks PR comment
- Backport labels are set in a manner that is consistent with the
[release branch maintenance
policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)
Yun-Kim committed Sep 23, 2024
1 parent 1e8ba1c commit cf6f007
Showing 16 changed files with 286 additions and 333 deletions.
8 changes: 3 additions & 5 deletions ddtrace/_trace/trace_handlers.py
@@ -687,11 +687,10 @@ def _on_botocore_patched_bedrock_api_call_started(ctx, request_params):
 def _on_botocore_patched_bedrock_api_call_exception(ctx, exc_info):
     span = ctx[ctx["call_key"]]
     span.set_exc_info(*exc_info)
-    prompt = ctx["prompt"]
     model_name = ctx["model_name"]
     integration = ctx["bedrock_integration"]
-    if integration.is_pc_sampled_llmobs(span) and "embed" not in model_name:
-        integration.llmobs_set_tags(span, formatted_response=None, prompt=prompt, err=True)
+    if "embed" not in model_name:
+        integration.llmobs_set_tags(span, args=[], kwargs={"prompt": ctx["prompt"]})
     span.finish()


@@ -733,8 +732,7 @@ def _on_botocore_bedrock_process_response(
         span.set_tag_str(
             "bedrock.response.choices.{}.finish_reason".format(i), str(formatted_response["finish_reason"][i])
         )
-    if integration.is_pc_sampled_llmobs(span):
-        integration.llmobs_set_tags(span, formatted_response=formatted_response, prompt=ctx["prompt"])
+    integration.llmobs_set_tags(span, args=[], kwargs={"prompt": ctx["prompt"]}, response=formatted_response)
     span.finish()


9 changes: 1 addition & 8 deletions ddtrace/contrib/internal/anthropic/_streaming.py
@@ -154,16 +154,9 @@ def _process_finished_stream(integration, span, args, kwargs, streamed_chunks):
     # builds the response message given streamed chunks and sets according span tags
     try:
         resp_message = _construct_message(streamed_chunks)
-
         if integration.is_pc_sampled_span(span):
             _tag_streamed_chat_completion_response(integration, span, resp_message)
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                span=span,
-                resp=resp_message,
-                args=args,
-                kwargs=kwargs,
-            )
+        integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=resp_message)
     except Exception:
         log.warning("Error processing streamed completion/chat response.", exc_info=True)

6 changes: 2 additions & 4 deletions ddtrace/contrib/internal/anthropic/patch.py
@@ -105,8 +105,7 @@ def traced_chat_model_generate(anthropic, pin, func, instance, args, kwargs):
     finally:
         # we don't want to finish the span if it is a stream as it will get finished once the iterator is exhausted
         if span.error or not stream:
-            if integration.is_pc_sampled_llmobs(span):
-                integration.llmobs_set_tags(span=span, resp=chat_completions, args=args, kwargs=kwargs)
+            integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=chat_completions)
         span.finish()
     return chat_completions

@@ -178,8 +177,7 @@ async def traced_async_chat_model_generate(anthropic, pin, func, instance, args,
     finally:
         # we don't want to finish the span if it is a stream as it will get finished once the iterator is exhausted
         if span.error or not stream:
-            if integration.is_pc_sampled_llmobs(span):
-                integration.llmobs_set_tags(span=span, resp=chat_completions, args=args, kwargs=kwargs)
+            integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=chat_completions)
         span.finish()
     return chat_completions

22 changes: 14 additions & 8 deletions ddtrace/contrib/internal/google_generativeai/_utils.py
@@ -30,10 +30,13 @@ def __iter__(self):
         else:
             tag_response(self._dd_span, self.__wrapped__, self._dd_integration, self._model_instance)
         finally:
-            if self._dd_integration.is_pc_sampled_llmobs(self._dd_span):
-                self._dd_integration.llmobs_set_tags(
-                    self._dd_span, self._args, self._kwargs, self._model_instance, self.__wrapped__
-                )
+            self._kwargs["instance"] = self._model_instance
+            self._dd_integration.llmobs_set_tags(
+                self._dd_span,
+                args=self._args,
+                kwargs=self._kwargs,
+                response=self.__wrapped__,
+            )
             self._dd_span.finish()


@@ -48,10 +51,13 @@ async def __aiter__(self):
         else:
             tag_response(self._dd_span, self.__wrapped__, self._dd_integration, self._model_instance)
         finally:
-            if self._dd_integration.is_pc_sampled_llmobs(self._dd_span):
-                self._dd_integration.llmobs_set_tags(
-                    self._dd_span, self._args, self._kwargs, self._model_instance, self.__wrapped__
-                )
+            self._kwargs["instance"] = self._model_instance
+            self._dd_integration.llmobs_set_tags(
+                self._dd_span,
+                args=self._args,
+                kwargs=self._kwargs,
+                response=self.__wrapped__,
+            )
             self._dd_span.finish()


8 changes: 4 additions & 4 deletions ddtrace/contrib/internal/google_generativeai/patch.py
@@ -60,8 +60,8 @@ def traced_generate(genai, pin, func, instance, args, kwargs):
     finally:
         # streamed spans will be finished separately once the stream generator is exhausted
         if span.error or not stream:
-            if integration.is_pc_sampled_llmobs(span):
-                integration.llmobs_set_tags(span, args, kwargs, instance, generations)
+            kwargs["instance"] = instance
+            integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=generations)
         span.finish()
     return generations

@@ -90,8 +90,8 @@ async def traced_agenerate(genai, pin, func, instance, args, kwargs):
     finally:
         # streamed spans will be finished separately once the stream generator is exhausted
         if span.error or not stream:
-            if integration.is_pc_sampled_llmobs(span):
-                integration.llmobs_set_tags(span, args, kwargs, instance, generations)
+            kwargs["instance"] = instance
+            integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=generations)
         span.finish()
     return generations

94 changes: 14 additions & 80 deletions ddtrace/contrib/internal/langchain/patch.py
@@ -244,14 +244,7 @@ def traced_llm_generate(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "llm",
-                span,
-                prompts,
-                completions,
-                error=bool(span.error),
-            )
+        integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=completions, operation="llm")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
         if integration.is_pc_sampled_log(span):
@@ -322,14 +315,7 @@ async def traced_llm_agenerate(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "llm",
-                span,
-                prompts,
-                completions,
-                error=bool(span.error),
-            )
+        integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=completions, operation="llm")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
         if integration.is_pc_sampled_log(span):
@@ -438,14 +424,7 @@ def traced_chat_model_generate(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "chat",
-                span,
-                chat_messages,
-                chat_completions,
-                error=bool(span.error),
-            )
+        integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=chat_completions, operation="chat")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
         if integration.is_pc_sampled_log(span):
@@ -570,14 +549,7 @@ async def traced_chat_model_agenerate(langchain, pin, func, instance, args, kwar
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "chat",
-                span,
-                chat_messages,
-                chat_completions,
-                error=bool(span.error),
-            )
+        integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=chat_completions, operation="chat")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
         if integration.is_pc_sampled_log(span):
@@ -662,14 +634,7 @@ def traced_embedding(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "embedding",
-                span,
-                input_texts,
-                embeddings,
-                error=bool(span.error),
-            )
+        integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=embeddings, operation="embedding")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
         if integration.is_pc_sampled_log(span):
@@ -717,8 +682,7 @@ def traced_chain_call(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags("chain", span, inputs, final_outputs, error=bool(span.error))
+        integration.llmobs_set_tags(span, args=[], kwargs=inputs, response=final_outputs, operation="chain")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
         if integration.is_pc_sampled_log(span):
@@ -774,8 +738,7 @@ async def traced_chain_acall(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags("chain", span, inputs, final_outputs, error=bool(span.error))
+        integration.llmobs_set_tags(span, args=[], kwargs=inputs, response=final_outputs, operation="chain")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
         if integration.is_pc_sampled_log(span):
@@ -847,8 +810,7 @@ def traced_lcel_runnable_sequence(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags("chain", span, inputs, final_output, error=bool(span.error))
+        integration.llmobs_set_tags(span, args=[], kwargs=inputs, response=final_output, operation="chain")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
     return final_output
@@ -894,8 +856,7 @@ async def traced_lcel_runnable_sequence_async(langchain, pin, func, instance, ar
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags("chain", span, inputs, final_output, error=bool(span.error))
+        integration.llmobs_set_tags(span, args=[], kwargs=inputs, response=final_output, operation="chain")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
     return final_output
@@ -953,14 +914,7 @@ def traced_similarity_search(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "retrieval",
-                span,
-                query,
-                documents,
-                error=bool(span.error),
-            )
+        integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=documents, operation="retrieval")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
         if integration.is_pc_sampled_log(span):
@@ -1024,18 +978,8 @@ def traced_base_tool_invoke(langchain, pin, func, instance, args, kwargs):
         span.set_exc_info(*sys.exc_info())
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "tool",
-                span,
-                {
-                    "input": tool_input,
-                    "config": config if config else {},
-                    "info": tool_info if tool_info else {},
-                },
-                tool_output,
-                error=bool(span.error),
-            )
+        tool_inputs = {"input": tool_input, "config": config or {}, "info": tool_info or {}}
+        integration.llmobs_set_tags(span, args=[], kwargs=tool_inputs, response=tool_output, operation="tool")
         span.finish()
     return tool_output

@@ -1085,18 +1029,8 @@ async def traced_base_tool_ainvoke(langchain, pin, func, instance, args, kwargs)
         span.set_exc_info(*sys.exc_info())
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "tool",
-                span,
-                {
-                    "input": tool_input,
-                    "config": config if config else {},
-                    "info": tool_info if tool_info else {},
-                },
-                tool_output,
-                error=bool(span.error),
-            )
+        tool_inputs = {"input": tool_input, "config": config or {}, "info": tool_info or {}}
+        integration.llmobs_set_tags(span, args=[], kwargs=tool_inputs, response=tool_output, operation="tool")
         span.finish()
     return tool_output

14 changes: 3 additions & 11 deletions ddtrace/contrib/internal/openai/_endpoint_hooks.py
@@ -9,7 +9,6 @@
 from ddtrace.contrib.internal.openai.utils import _process_finished_stream
 from ddtrace.contrib.internal.openai.utils import _tag_tool_calls
 from ddtrace.internal.utils.version import parse_version
-from ddtrace.llmobs._constants import SPAN_KIND


 API_VERSION = "v1"
@@ -189,8 +188,6 @@ class _CompletionHook(_BaseCompletionHook):

     def _record_request(self, pin, integration, span, args, kwargs):
         super()._record_request(pin, integration, span, args, kwargs)
-        if integration.is_pc_sampled_llmobs(span):
-            span.set_tag_str(SPAN_KIND, "llm")
         if integration.is_pc_sampled_span(span):
             prompt = kwargs.get("prompt", "")
             if isinstance(prompt, str):
@@ -212,8 +209,7 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error):
             integration.log(
                 span, "info" if error is None else "error", "sampled %s" % self.OPERATION_ID, attrs=attrs_dict
             )
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags("completion", resp, span, kwargs, err=error)
+        integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=resp, operation="completion")
         if not resp:
             return
         for choice in resp.choices:
@@ -247,8 +243,6 @@ class _ChatCompletionHook(_BaseCompletionHook):

     def _record_request(self, pin, integration, span, args, kwargs):
         super()._record_request(pin, integration, span, args, kwargs)
-        if integration.is_pc_sampled_llmobs(span):
-            span.set_tag_str(SPAN_KIND, "llm")
         for idx, m in enumerate(kwargs.get("messages", [])):
             role = getattr(m, "role", "")
             name = getattr(m, "name", "")
@@ -274,8 +268,7 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error):
             integration.log(
                 span, "info" if error is None else "error", "sampled %s" % self.OPERATION_ID, attrs=attrs_dict
             )
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags("chat", resp, span, kwargs, err=error)
+        integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=resp, operation="chat")
         if not resp:
             return
         for choice in resp.choices:
@@ -319,8 +312,7 @@ def _record_request(self, pin, integration, span, args, kwargs):

     def _record_response(self, pin, integration, span, args, kwargs, resp, error):
         resp = super()._record_response(pin, integration, span, args, kwargs, resp, error)
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags("embedding", resp, span, kwargs, err=error)
+        integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=resp, operation="embedding")
         if not resp:
             return
         span.set_metric("openai.response.embeddings_count", len(resp.data))
6 changes: 2 additions & 4 deletions ddtrace/contrib/internal/openai/utils.py
@@ -208,10 +208,8 @@ def _process_finished_stream(integration, span, kwargs, streamed_chunks, is_comp
         if integration.is_pc_sampled_span(span):
             _tag_streamed_response(integration, span, formatted_completions)
         _set_token_metrics(span, integration, formatted_completions, prompts, request_messages, kwargs)
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "completion" if is_completion else "chat", None, span, kwargs, formatted_completions, None
-            )
+        operation = "completion" if is_completion else "chat"
+        integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=formatted_completions, operation=operation)
     except Exception:
         log.warning("Error processing streamed completion/chat response.", exc_info=True)


