fix(llama-index): capture tool calls from anthropic chat response #1177

Merged 2 commits on Dec 20, 2024
pyproject.toml
@@ -43,8 +43,10 @@ test = [
   "llama-index == 0.11.0",
   "llama-index-core >= 0.11.0",
   "llama-index-llms-openai",
+  "llama-index-llms-anthropic",
   "llama-index-llms-groq",
   "pytest-vcr",
+  "anthropic<0.41",
   "llama-index-multi-modal-llms-openai>=0.1.7",
   "openinference-instrumentation-openai",
   "opentelemetry-sdk",
@@ -86,6 +88,8 @@ exclude = [
 ignore_missing_imports = true
 module = [
     "wrapt",
+    "llama_index.llms.anthropic",
+    "llama_index.llms.openai",
 ]

 [tool.ruff]
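Note: the two `ignore_missing_imports` overrides added above are what allow the `# type: ignore` comments on the `from llama_index.llms.openai import OpenAI` lines to be dropped later in this diff — with the override in place, mypy would flag those inline ignores as unused.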
(instrumentation handler)
@@ -825,7 +825,19 @@ def handle(self, event: BaseEvent, **kwargs: Any) -> Any:


 def _get_tool_call(tool_call: object) -> Iterator[Tuple[str, Any]]:
-    if function := getattr(tool_call, "function", None):
+    if isinstance(tool_call, dict):
+        if tool_call_id := tool_call.get("id"):
+            yield TOOL_CALL_ID, tool_call_id
+        if name := tool_call.get("name"):
+            yield TOOL_CALL_FUNCTION_NAME, name
+        if arguments := tool_call.get("input"):
+            if isinstance(arguments, str):
+                yield TOOL_CALL_FUNCTION_ARGUMENTS_JSON, arguments
+            elif isinstance(arguments, dict):
+                yield TOOL_CALL_FUNCTION_ARGUMENTS_JSON, safe_json_dumps(arguments)
+    elif function := getattr(tool_call, "function", None):
+        if tool_call_id := getattr(tool_call, "id", None):
+            yield TOOL_CALL_ID, tool_call_id
         if name := getattr(function, "name", None):
             yield TOOL_CALL_FUNCTION_NAME, name
         if arguments := getattr(function, "arguments", None):
@@ -1032,6 +1044,7 @@ def is_base64_url(url: str) -> bool:
 RERANKER_QUERY = RerankerAttributes.RERANKER_QUERY
 RERANKER_TOP_K = RerankerAttributes.RERANKER_TOP_K
 RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS
+TOOL_CALL_ID = ToolCallAttributes.TOOL_CALL_ID
 TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
 TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
 TOOL_DESCRIPTION = SpanAttributes.TOOL_DESCRIPTION
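The change above teaches `_get_tool_call` to recognize two tool-call shapes: the plain dict carried by the Anthropic integration (mirroring Anthropic's `tool_use` content block, with `name` and `input` keys) and the attribute-style object used by the OpenAI integration (an `id` plus a nested `function` with `name` and `arguments`). Below is a minimal standalone sketch of the same dispatch; the attribute-key strings and the `safe_json_dumps` stand-in are illustrative, not the library's actual values. The two new VCR cassettes that follow record one exchange of each shape.

import json
from types import SimpleNamespace
from typing import Any, Iterator, Tuple

# Illustrative stand-ins; the real values come from openinference.semconv,
# and the real serializer is openinference's safe_json_dumps.
TOOL_CALL_ID = "tool_call.id"
TOOL_CALL_FUNCTION_NAME = "tool_call.function.name"
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = "tool_call.function.arguments"
safe_json_dumps = json.dumps


def get_tool_call(tool_call: object) -> Iterator[Tuple[str, Any]]:
    if isinstance(tool_call, dict):
        # Anthropic-style: a dict shaped like a tool_use content block.
        if tool_call_id := tool_call.get("id"):
            yield TOOL_CALL_ID, tool_call_id
        if name := tool_call.get("name"):
            yield TOOL_CALL_FUNCTION_NAME, name
        if arguments := tool_call.get("input"):
            if isinstance(arguments, str):  # already serialized
                yield TOOL_CALL_FUNCTION_ARGUMENTS_JSON, arguments
            elif isinstance(arguments, dict):  # serialize for the span attribute
                yield TOOL_CALL_FUNCTION_ARGUMENTS_JSON, safe_json_dumps(arguments)
    elif function := getattr(tool_call, "function", None):
        # OpenAI-style: an object with .id and a nested .function.
        if tool_call_id := getattr(tool_call, "id", None):
            yield TOOL_CALL_ID, tool_call_id
        if name := getattr(function, "name", None):
            yield TOOL_CALL_FUNCTION_NAME, name
        if arguments := getattr(function, "arguments", None):
            yield TOOL_CALL_FUNCTION_ARGUMENTS_JSON, arguments


# Hypothetical inputs demonstrating both branches:
anthropic_style = {
    "type": "tool_use",
    "id": "toolu_123",
    "name": "get_weather",
    "input": {"location": "San Francisco"},
}
openai_style = SimpleNamespace(
    id="call_123",
    function=SimpleNamespace(name="get_weather", arguments='{"location": "San Francisco"}'),
)
assert dict(get_tool_call(anthropic_style))[TOOL_CALL_FUNCTION_NAME] == "get_weather"
assert dict(get_tool_call(openai_style))[TOOL_CALL_FUNCTION_NAME] == "get_weather"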
(new file: VCR cassette for the Anthropic test)
@@ -0,0 +1,18 @@
interactions:
- request:
    body: '{"max_tokens":512,"messages":[{"role":"user","content":[{"text":"what''s
      the weather in San Francisco?","type":"text"}]}],"model":"claude-3-5-haiku-20241022","stream":false,"system":"","temperature":0.1,"tools":[{"name":"get_weather","description":"get_weather(location:
      str) -> str\nUseful for getting the weather for a given location.","input_schema":{"properties":{"location":{"title":"Location","type":"string"}},"required":["location"],"type":"object"}}]}'
    headers: {}
    method: POST
    uri: https://api.anthropic.com/v1/messages
  response:
    body:
      string: '{"id":"msg_011UbtsepYnQzWFNg8fDmFZ2","type":"message","role":"assistant","model":"claude-3-5-haiku-20241022","content":[{"type":"text","text":"I''ll
        help you check the weather in San Francisco right away."},{"type":"tool_use","id":"toolu_01P7dMjNQjMNZK8BB8sKP25k","name":"get_weather","input":{"location":"San
        Francisco"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":355,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":68}}'
    headers: {}
    status:
      code: 200
      message: OK
version: 1
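For orientation, the `tool_use` block in the recorded response above is exactly what the new dict branch flattens. On the resulting LLM span, the test below expects attributes along these lines (an illustrative subset — the key strings are assumed semconv values and may differ from the actual constants):

# Assumed flattened span attributes for the cassette above.
expected_subset = {
    "llm.output_messages.0.message.tool_calls.0.tool_call.id": "toolu_01P7dMjNQjMNZK8BB8sKP25k",
    "llm.output_messages.0.message.tool_calls.0.tool_call.function.name": "get_weather",
    "llm.output_messages.0.message.tool_calls.0.tool_call.function.arguments": '{"location": "San Francisco"}',
}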
(new file: VCR cassette for the OpenAI test)
@@ -0,0 +1,29 @@
interactions:
- request:
    body: '{"messages":[{"role":"user","content":"what''s the weather in San Francisco?"}],"model":"gpt-4o-mini","stream":false,"temperature":0.1,"tool_choice":"auto","tools":[{"type":"function","function":{"name":"get_weather","description":"get_weather(location:
      str) -> str\nUseful for getting the weather for a given location.","parameters":{"properties":{"location":{"title":"Location","type":"string"}},"required":["location"],"type":"object","additionalProperties":false},"strict":false}}]}'
    headers: {}
    method: POST
    uri: https://api.openai.com/v1/chat/completions
  response:
    body:
      string: "{\n \"id\": \"chatcmpl-AgcDNYhR5NPYhy2hmtnkm6CP8GFAN\",\n \"object\":
        \"chat.completion\",\n \"created\": 1734720037,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
        \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
        \"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n
        \ \"id\": \"call_FjpIANozIfaXzuQQnmhK0yD3\",\n \"type\":
        \"function\",\n \"function\": {\n \"name\": \"get_weather\",\n
        \ \"arguments\": \"{\\\"location\\\":\\\"San Francisco\\\"}\"\n
        \ }\n }\n ],\n \"refusal\": null\n },\n
        \ \"logprobs\": null,\n \"finish_reason\": \"tool_calls\"\n }\n
        \ ],\n \"usage\": {\n \"prompt_tokens\": 68,\n \"completion_tokens\":
        16,\n \"total_tokens\": 84,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
        0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
        {\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
        0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"system_fingerprint\":
        \"fp_0aa8d3e20b\"\n}\n"
    headers: {}
    status:
      code: 200
      message: OK
version: 1
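This second cassette records the equivalent OpenAI exchange. When the OpenAI SDK parses this response, each `tool_calls` entry becomes an attribute-style object (`.id`, `.function.name`, `.function.arguments`), so this cassette exercises the `getattr` branch of `_get_tool_call` rather than the dict branch covered above.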
@@ -28,7 +28,7 @@
 from llama_index.core.base.response.schema import StreamingResponse
 from llama_index.core.callbacks import CallbackManager
 from llama_index.core.schema import TextNode
-from llama_index.llms.openai import OpenAI  # type: ignore
+from llama_index.llms.openai import OpenAI
 from opentelemetry import trace as trace_api
 from opentelemetry.sdk import trace as trace_sdk
 from opentelemetry.sdk.trace import ReadableSpan
@@ -30,7 +30,7 @@
 from llama_index.core import Document, ListIndex, Settings
 from llama_index.core.callbacks import CallbackManager
 from llama_index.core.schema import TextNode
-from llama_index.llms.openai import OpenAI  # type: ignore
+from llama_index.llms.openai import OpenAI
 from opentelemetry import trace as trace_api
 from opentelemetry.sdk import trace as trace_sdk
 from opentelemetry.sdk.trace import ReadableSpan
@@ -23,7 +23,7 @@
 from httpx import Response
 from llama_index.core.base.llms.types import ChatMessage, MessageRole
 from llama_index.core.multi_modal_llms.generic_utils import load_image_urls
-from llama_index.llms.openai import OpenAI  # type: ignore
+from llama_index.llms.openai import OpenAI
 from llama_index.multi_modal_llms.openai import OpenAIMultiModal  # type: ignore
 from llama_index.multi_modal_llms.openai import utils as openai_utils
 from opentelemetry import trace as trace_api
@@ -30,7 +30,7 @@
 from llama_index.core.base.llms.types import ChatMessage
 from llama_index.core.multi_modal_llms.generic_utils import load_image_urls
 from llama_index.core.schema import TextNode
-from llama_index.llms.openai import OpenAI  # type: ignore
+from llama_index.llms.openai import OpenAI
 from llama_index.multi_modal_llms.openai import OpenAIMultiModal  # type: ignore
 from llama_index.multi_modal_llms.openai import utils as openai_utils
 from opentelemetry import trace as trace_api
(new file: tool-call test module)
@@ -0,0 +1,108 @@
from importlib.metadata import version
from json import loads
from typing import Iterator, Tuple, cast

import pytest
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.tools import FunctionTool
from llama_index.llms.anthropic import Anthropic
from llama_index.llms.openai import OpenAI
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry.trace import TracerProvider

from openinference.instrumentation.llama_index import LlamaIndexInstrumentor
from openinference.semconv.trace import MessageAttributes, SpanAttributes, ToolCallAttributes

LLAMA_INDEX_LLMS_OPENAI_VERSION = cast(
    Tuple[int, int], tuple(map(int, version("llama_index.llms.openai").split(".")[:2]))
)
LLAMA_INDEX_LLMS_ANTHROPIC_VERSION = cast(
    Tuple[int, int], tuple(map(int, version("llama_index.llms.anthropic").split(".")[:2]))
)


def get_weather(location: str) -> str:
    """Useful for getting the weather for a given location."""
    raise NotImplementedError


TOOL = FunctionTool.from_defaults(get_weather)


class TestToolCallsInChatResponse:
    @pytest.mark.skipif(
        LLAMA_INDEX_LLMS_OPENAI_VERSION < (0, 3),
        reason="ignore older versions to simplify test upkeep",
    )
    @pytest.mark.vcr(
        decode_compressed_response=True,
        before_record_request=lambda _: _.headers.clear() or _,
        before_record_response=lambda _: {**_, "headers": {}},
    )
    async def test_openai(
        self,
        in_memory_span_exporter: InMemorySpanExporter,
    ) -> None:
        llm = OpenAI(model="gpt-4o-mini", api_key="sk-")
        await self._test(llm, in_memory_span_exporter)

    @pytest.mark.skipif(
        LLAMA_INDEX_LLMS_ANTHROPIC_VERSION < (0, 6),
        reason="ignore older versions to simplify test upkeep",
    )
    @pytest.mark.vcr(
        decode_compressed_response=True,
        before_record_request=lambda _: _.headers.clear() or _,
        before_record_response=lambda _: {**_, "headers": {}},
    )
    async def test_anthropic(
        self,
        in_memory_span_exporter: InMemorySpanExporter,
    ) -> None:
        llm = Anthropic(model="claude-3-5-haiku-20241022", api_key="sk-")
        await self._test(llm, in_memory_span_exporter)

    @classmethod
    async def _test(
        cls,
        llm: FunctionCallingLLM,
        in_memory_span_exporter: InMemorySpanExporter,
    ) -> None:
        await llm.achat(
            **llm._prepare_chat_with_tools([TOOL], "what's the weather in San Francisco?"),
        )
        spans = in_memory_span_exporter.get_finished_spans()
        span = spans[-1]
        assert span.attributes
        assert span.attributes.get(f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.0.{TOOL_CALL_ID}")
        assert (
            span.attributes.get(
                f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.0.{TOOL_CALL_FUNCTION_NAME}"
            )
            == "get_weather"
        )
        assert isinstance(
            arguments := span.attributes.get(
                f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.0.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}"
            ),
            str,
        )
        assert loads(arguments) == {"location": "San Francisco"}


@pytest.fixture(autouse=True)
def instrument(
    tracer_provider: TracerProvider,
    in_memory_span_exporter: InMemorySpanExporter,
) -> Iterator[None]:
    LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider)
    yield
    LlamaIndexInstrumentor().uninstrument()


LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS
MESSAGE_TOOL_CALL_ID = MessageAttributes.MESSAGE_TOOL_CALL_ID
TOOL_CALL_ID = ToolCallAttributes.TOOL_CALL_ID
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
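A side note on the `@pytest.mark.vcr` hooks above: `_.headers.clear()` returns `None`, so `None or _` evaluates to the now header-less request, letting a single lambda both mutate and return it. A more explicit equivalent (a sketch, not part of the diff):

from typing import Any, Dict

def scrub_request(request: Any) -> Any:
    # Drop all request headers (including API keys) before the cassette is written.
    request.headers.clear()
    return request

def scrub_response(response: Dict[str, Any]) -> Dict[str, Any]:
    # Return a copy of the response with its headers emptied.
    return {**response, "headers": {}}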
python/tox.ini (1 addition, 1 deletion)
@@ -58,7 +58,7 @@ commands_pre =
     vertexai: uv pip install --reinstall {toxinidir}/instrumentation/openinference-instrumentation-vertexai[test]
     vertexai-latest: uv pip install -U vertexai 'httpx<0.28'
     llama_index: uv pip install --reinstall {toxinidir}/instrumentation/openinference-instrumentation-llama-index[test] 'httpx<0.28'
-    llama_index-latest: uv pip install -U llama-index llama-index-core 'httpx<0.28'
+    llama_index-latest: uv pip install -U llama-index llama-index-core llama-index-llms-openai openai llama-index-llms-anthropic anthropic 'httpx<0.28'
     dspy: uv pip install --reinstall {toxinidir}/instrumentation/openinference-instrumentation-dspy[test]
     dspy-latest: uv pip install -U dspy-ai 'httpx<0.28'
     langchain: uv pip install --reinstall {toxinidir}/instrumentation/openinference-instrumentation-langchain[test]
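The `llama_index-latest` line now upgrades the OpenAI and Anthropic integration packages and their underlying SDKs explicitly, presumably so the latest-version test environment exercises current releases of both rather than whatever the resolver previously pinned.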