Merge pull request #32 from monocle2ai/kshitiz/fix_bugs
Fixed Llama index bugs
kshitiz-okahu authored Aug 27, 2024
2 parents 253396b + 5f242a5 commit 8556cc6
Showing 4 changed files with 43 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/monocle_apptrace/instrumentor.py
@@ -18,7 +18,7 @@
 
 logger = logging.getLogger(__name__)
 
-_instruments = ("langchain >= 0.0.346",)
+_instruments = ()
 
 class MonocleInstrumentor(BaseInstrumentor):
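The `_instruments` tuple feeds OpenTelemetry's `BaseInstrumentor` dependency check, so emptying it removes the hard `langchain >= 0.0.346` requirement and lets llama-index-only applications be instrumented. A minimal sketch of the usual wiring, assuming the standard `opentelemetry-instrumentation` pattern (the repo's actual method bodies are not shown in this hunk):

from typing import Collection

from opentelemetry.instrumentation.instrumentor import BaseInstrumentor

_instruments = ()  # empty: no package spec is enforced when instrument() runs


class SketchInstrumentor(BaseInstrumentor):
    # Illustrative subclass, not the repo's MonocleInstrumentor.

    def instrumentation_dependencies(self) -> Collection[str]:
        # BaseInstrumentor.instrument() checks these specs for conflicts before
        # instrumenting; an empty tuple skips the langchain requirement entirely.
        return _instruments

    def _instrument(self, **kwargs):
        pass  # wrapping logic would go here

    def _uninstrument(self, **kwargs):
        pass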
23 changes: 14 additions & 9 deletions src/monocle_apptrace/wrap_common.py
@@ -191,6 +191,7 @@ def get_input_from_args(chain_args):
     return ""
 
 def update_span_from_llm_response(response, span: Span):
+
     # extract token usage from langchain openai
     if (response is not None and hasattr(response, "response_metadata")):
         response_metadata = response.response_metadata
@@ -201,15 +202,19 @@ def update_span_from_llm_response(response, span: Span):
             span.set_attribute("total_tokens", token_usage.get("total_tokens"))
     # extract token usage from llamaindex openai
     if(response is not None and hasattr(response, "raw")):
-        if response.raw is not None:
-            token_usage = response.raw.get("usage")
-            if token_usage is not None:
-                if hasattr(token_usage, "completion_tokens"):
-                    span.set_attribute("completion_tokens", token_usage.completion_tokens)
-                if hasattr(token_usage, "prompt_tokens"):
-                    span.set_attribute("prompt_tokens", token_usage.prompt_tokens)
-                if hasattr(token_usage, "total_tokens"):
-                    span.set_attribute("total_tokens", token_usage.total_tokens)
+        try:
+            if response.raw is not None:
+                token_usage = response.raw.get("usage") if isinstance(response.raw, dict) else getattr(response.raw, "usage", None)
+                if token_usage is not None:
+                    if getattr(token_usage, "completion_tokens", None):
+                        span.set_attribute("completion_tokens", getattr(token_usage, "completion_tokens"))
+                    if getattr(token_usage, "prompt_tokens", None):
+                        span.set_attribute("prompt_tokens", getattr(token_usage, "prompt_tokens"))
+                    if getattr(token_usage, "total_tokens", None):
+                        span.set_attribute("total_tokens", getattr(token_usage, "total_tokens"))
+        except AttributeError:
+            token_usage = None
+
 
 def update_workflow_type(to_wrap, span: Span):
     package_name = to_wrap.get('package')
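The previous branch assumed `response.raw` was a dict and called `.get()` on it directly; with llama-index the raw payload can be an SDK object that exposes a `usage` attribute instead, which made the old code raise `AttributeError`. A small self-contained sketch of the dict-versus-object handling the patch introduces (the class names below are illustrative stand-ins, not types from the repo or the OpenAI SDK):

from dataclasses import dataclass


@dataclass
class FakeUsage:
    # Stand-in for an SDK usage object exposing token counts as attributes.
    completion_tokens: int
    prompt_tokens: int
    total_tokens: int


@dataclass
class FakeRaw:
    # Object-style raw payload: has a .usage attribute but no .get() method.
    usage: FakeUsage


def extract_usage(raw):
    # Mirrors the patched logic: accept dict-style and attribute-style payloads.
    if isinstance(raw, dict):
        return raw.get("usage")
    return getattr(raw, "usage", None)


print(extract_usage({"usage": FakeUsage(1, 2, 3)}))  # dict-style raw
print(extract_usage(FakeRaw(FakeUsage(1, 2, 3))))    # object-style raw, no AttributeError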
21 changes: 20 additions & 1 deletion tests/helpers.py
@@ -10,6 +10,17 @@
 )
 from llama_index.core.llms.callbacks import llm_completion_callback
 
+class TestChatCompletion:
+
+    def __init__(self, usage):
+        self.usage = usage
+
+class TestCompletionUsage:
+
+    def __init__(self, completion_tokens, prompt_tokens, total_tokens):
+        self.completion_tokens = completion_tokens
+        self.prompt_tokens = prompt_tokens
+        self.total_tokens = total_tokens
 
 class OurLLM(CustomLLM):
     context_window: int = 3900
@@ -28,7 +39,15 @@ def metadata(self) -> LLMMetadata:
 
     @llm_completion_callback()
     def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
-        return CompletionResponse(text=self.dummy_response)
+        return CompletionResponse(
+            text=self.dummy_response,
+            raw={
+                "usage": TestCompletionUsage(
+                    completion_tokens=1,
+                    prompt_tokens=2,
+                    total_tokens=3
+                )
+            })
 
     @llm_completion_callback()
     def stream_complete(
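The stub classes let the fake LLM attach an attribute-style usage object to `CompletionResponse.raw`, mirroring the shape the real OpenAI client returns. A hedged usage sketch, assuming llama-index is installed, the collapsed remainder of `OurLLM` keeps its defaults, and the helpers module is importable as `helpers` from the tests directory:

from helpers import OurLLM  # assumed import path for the test helpers above

llm = OurLLM()
resp = llm.complete("any prompt")   # CompletionResponse built as in the diff above
usage = resp.raw["usage"]           # dict-style raw holding a TestCompletionUsage
print(usage.completion_tokens, usage.prompt_tokens, usage.total_tokens)  # 1 2 3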
11 changes: 8 additions & 3 deletions tests/llama_index_test.py
@@ -28,7 +28,7 @@
     llm_wrapper,
 )
 from monocle_apptrace.wrapper import WrapperMethod
-from opentelemetry.sdk.trace.export import BatchSpanProcessor
+from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
 
 logger = logging.getLogger(__name__)
 
@@ -55,7 +55,8 @@ def test_llama_index(self, mock_post):
         setup_monocle_telemetry(
             workflow_name="llama_index_1",
             span_processors=[
-                BatchSpanProcessor(HttpSpanExporter(os.environ["HTTP_INGESTION_ENDPOINT"]))
+                BatchSpanProcessor(HttpSpanExporter(os.environ["HTTP_INGESTION_ENDPOINT"])),
+                BatchSpanProcessor(ConsoleSpanExporter())
             ],
             wrapper_methods=[
                 WrapperMethod(
@@ -119,9 +120,13 @@ def get_event_attributes(events, key):
         assert output_event_attributes[RESPONSE] == llm.dummy_response
 
         span_names: List[str] = [span["name"] for span in dataJson['batch']]
+        llm_span = [x for x in dataJson["batch"] if "llamaindex.OurLLM" in x["name"]][0]
         for name in ["llamaindex.retrieve", "llamaindex.query", "llamaindex.OurLLM"]:
             assert name in span_names
 
+        assert llm_span["attributes"]["completion_tokens"] == 1
+        assert llm_span["attributes"]["prompt_tokens"] == 2
+        assert llm_span["attributes"]["total_tokens"] == 3
+
         type_found = False
         model_name_found = False
 
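Registering a second processor with `ConsoleSpanExporter` makes every exported span also print to stdout, which is useful when debugging the new token-usage attributes locally. A minimal console-only setup, assuming `setup_monocle_telemetry` is imported from `monocle_apptrace.instrumentor` as in the collapsed imports above and that an empty `wrapper_methods` list is accepted:

from monocle_apptrace.instrumentor import setup_monocle_telemetry  # assumed import path
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

# Console-only tracing: each finished span is printed immediately, no HTTP endpoint needed.
setup_monocle_telemetry(
    workflow_name="llama_index_1",
    span_processors=[SimpleSpanProcessor(ConsoleSpanExporter())],
    wrapper_methods=[],  # keep only the package's built-in wrappers
)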
