diff --git a/src/monocle_apptrace/instrumentor.py b/src/monocle_apptrace/instrumentor.py
index 84b07e3..1ba6b70 100644
--- a/src/monocle_apptrace/instrumentor.py
+++ b/src/monocle_apptrace/instrumentor.py
@@ -18,7 +18,7 @@
 logger = logging.getLogger(__name__)
 
-_instruments = ("langchain >= 0.0.346",)
+_instruments = ()
 
 class MonocleInstrumentor(BaseInstrumentor):
diff --git a/src/monocle_apptrace/wrap_common.py b/src/monocle_apptrace/wrap_common.py
index bcc785a..c7b6fea 100644
--- a/src/monocle_apptrace/wrap_common.py
+++ b/src/monocle_apptrace/wrap_common.py
@@ -191,6 +191,7 @@ def get_input_from_args(chain_args):
     return ""
 
 def update_span_from_llm_response(response, span: Span):
+    # extract token usage from langchain openai
     if (response is not None and hasattr(response, "response_metadata")):
         response_metadata = response.response_metadata
@@ -201,15 +202,19 @@ def update_span_from_llm_response(response, span: Span):
             span.set_attribute("total_tokens", token_usage.get("total_tokens"))
     # extract token usage from llamaindex openai
     if(response is not None and hasattr(response, "raw")):
-        if response.raw is not None:
-            token_usage = response.raw.get("usage")
-            if token_usage is not None:
-                if hasattr(token_usage, "completion_tokens"):
-                    span.set_attribute("completion_tokens", token_usage.completion_tokens)
-                if hasattr(token_usage, "prompt_tokens"):
-                    span.set_attribute("prompt_tokens", token_usage.prompt_tokens)
-                if hasattr(token_usage, "total_tokens"):
-                    span.set_attribute("total_tokens", token_usage.total_tokens)
+        try:
+            if response.raw is not None:
+                token_usage = response.raw.get("usage") if isinstance(response.raw, dict) else getattr(response.raw, "usage", None)
+                if token_usage is not None:
+                    if getattr(token_usage, "completion_tokens", None):
+                        span.set_attribute("completion_tokens", getattr(token_usage, "completion_tokens"))
+                    if getattr(token_usage, "prompt_tokens", None):
+                        span.set_attribute("prompt_tokens", getattr(token_usage, "prompt_tokens"))
+                    if getattr(token_usage, "total_tokens", None):
+                        span.set_attribute("total_tokens", getattr(token_usage, "total_tokens"))
+        except AttributeError:
+            token_usage = None
+
 
 def update_workflow_type(to_wrap, span: Span):
     package_name = to_wrap.get('package')
diff --git a/tests/helpers.py b/tests/helpers.py
index c2d6c92..27f366f 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -10,6 +10,17 @@
 )
 from llama_index.core.llms.callbacks import llm_completion_callback
 
+class TestChatCompletion:
+
+    def __init__(self, usage):
+        self.usage = usage
+
+class TestCompletionUsage:
+
+    def __init__(self, completion_tokens, prompt_tokens, total_tokens):
+        self.completion_tokens = completion_tokens
+        self.prompt_tokens = prompt_tokens
+        self.total_tokens = total_tokens
 
 class OurLLM(CustomLLM):
     context_window: int = 3900
@@ -28,7 +39,15 @@ def metadata(self) -> LLMMetadata:
 
     @llm_completion_callback()
     def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
-        return CompletionResponse(text=self.dummy_response)
+        return CompletionResponse(
+            text=self.dummy_response,
+            raw= {
+                "usage": TestCompletionUsage(
+                    completion_tokens=1,
+                    prompt_tokens = 2,
+                    total_tokens=3
+                )
+            })
 
     @llm_completion_callback()
     def stream_complete(
diff --git a/tests/llama_index_test.py b/tests/llama_index_test.py
index 4279e58..95151a7 100644
--- a/tests/llama_index_test.py
+++ b/tests/llama_index_test.py
@@ -28,7 +28,7 @@
     llm_wrapper,
 )
 from monocle_apptrace.wrapper import WrapperMethod
-from opentelemetry.sdk.trace.export import BatchSpanProcessor
+from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
 
 logger = logging.getLogger(__name__)
 
@@ -55,7 +55,8 @@ def test_llama_index(self, mock_post):
         setup_monocle_telemetry(
             workflow_name="llama_index_1",
             span_processors=[
-                BatchSpanProcessor(HttpSpanExporter(os.environ["HTTP_INGESTION_ENDPOINT"]))
+                BatchSpanProcessor(HttpSpanExporter(os.environ["HTTP_INGESTION_ENDPOINT"])),
+                BatchSpanProcessor(ConsoleSpanExporter())
             ],
             wrapper_methods=[
                 WrapperMethod(
@@ -119,9 +120,13 @@ def get_event_attributes(events, key):
         assert output_event_attributes[RESPONSE] == llm.dummy_response
 
         span_names: List[str] = [span["name"] for span in dataJson['batch']]
+        llm_span = [x for x in dataJson["batch"] if "llamaindex.OurLLM" in x["name"]][0]
        for name in ["llamaindex.retrieve", "llamaindex.query", "llamaindex.OurLLM"]:
             assert name in span_names
-
+        assert llm_span["attributes"]["completion_tokens"] == 1
+        assert llm_span["attributes"]["prompt_tokens"] == 2
+        assert llm_span["attributes"]["total_tokens"] == 3
+
         type_found = False
         model_name_found = False
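
The functional core of the wrap_common.py change is the dict-vs-object fallback used when reading response.raw. Below is a minimal, self-contained sketch of that pattern; the extract_token_usage helper and the SimpleNamespace stand-ins are illustrative assumptions for this note only, not part of monocle_apptrace, langchain, or llama_index.

# Hypothetical sketch of the token-usage fallback introduced in
# update_span_from_llm_response; names here are illustrative only.
from types import SimpleNamespace


def extract_token_usage(raw):
    # llama_index may hand back either a plain dict (as in tests/helpers.py)
    # or an OpenAI client object that exposes a `.usage` attribute.
    if raw is None:
        return None
    return raw.get("usage") if isinstance(raw, dict) else getattr(raw, "usage", None)


usage = SimpleNamespace(completion_tokens=1, prompt_tokens=2, total_tokens=3)

# dict-shaped raw, as produced by the patched OurLLM.complete test helper
assert extract_token_usage({"usage": usage}).total_tokens == 3

# object-shaped raw, as a ChatCompletion-style response object would be
assert extract_token_usage(SimpleNamespace(usage=usage)).prompt_tokens == 2

# either shape degrades to None when usage is absent, instead of raising
assert extract_token_usage({}) is None and extract_token_usage(object()) is None

As in the patch itself, the truthiness checks such as getattr(token_usage, "completion_tokens", None) skip attributes that are absent or None before setting span attributes.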