From 90b2b9f408d0d97c9c4f9ecf3e452e2134bdbdfb Mon Sep 17 00:00:00 2001 From: Roger Yang <roger.yang@arize.com> Date: Wed, 30 Oct 2024 14:18:45 -0700 Subject: [PATCH 1/3] feat: add tool call id --- .../pyproject.toml | 6 +- .../openai/_request_attributes_extractor.py | 11 +- .../openai/_response_attributes_extractor.py | 7 + .../openai/cassettes/test_tool_calls.yaml | 45 +++++++ .../instrumentation/openai/test_tool_calls.py | 120 +++++++++++++++--- 5 files changed, 168 insertions(+), 21 deletions(-) create mode 100644 python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/cassettes/test_tool_calls.yaml diff --git a/python/instrumentation/openinference-instrumentation-openai/pyproject.toml b/python/instrumentation/openinference-instrumentation-openai/pyproject.toml index 0476d3ec2..724b098a8 100644 --- a/python/instrumentation/openinference-instrumentation-openai/pyproject.toml +++ b/python/instrumentation/openinference-instrumentation-openai/pyproject.toml @@ -28,8 +28,8 @@ dependencies = [ "opentelemetry-api", "opentelemetry-instrumentation", "opentelemetry-semantic-conventions", - "openinference-instrumentation>=0.1.17", - "openinference-semantic-conventions>=0.1.9", + "openinference-instrumentation>=0.1.18", + "openinference-semantic-conventions>=0.1.12", "typing-extensions", "wrapt", ] @@ -44,6 +44,7 @@ test = [ "opentelemetry-instrumentation-httpx", "respx", "numpy", + "pytest-vcr", ] [project.urls] @@ -63,6 +64,7 @@ packages = ["src/openinference"] [tool.pytest.ini_options] asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" testpaths = [ "tests", ] diff --git a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request_attributes_extractor.py b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request_attributes_extractor.py index 316d01360..2ff94ee61 100644 --- a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request_attributes_extractor.py +++ b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request_attributes_extractor.py @@ -104,7 +104,9 @@ def _get_attributes_from_message_param( MessageAttributes.MESSAGE_ROLE, role.value if isinstance(role, Enum) else role, ) - + if tool_call_id := message.get("tool_call_id"): + # https://github.com/openai/openai-python/blob/891e1c17b7fecbae34d1915ba90c15ddece807f9/src/openai/types/chat/chat_completion_tool_message_param.py#L20 + yield MessageAttributes.MESSAGE_TOOL_CALL_ID, tool_call_id if content := message.get("content"): if isinstance(content, str): yield MessageAttributes.MESSAGE_CONTENT, content @@ -140,6 +142,13 @@ def _get_attributes_from_message_param( # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion_message_tool_call_param.py#L23 # noqa: E501 if not hasattr(tool_call, "get"): continue + if (tool_call_id := tool_call.get("id")) is not None: + # https://github.com/openai/openai-python/blob/891e1c17b7fecbae34d1915ba90c15ddece807f9/src/openai/types/chat/chat_completion_message_tool_call_param.py#L24 + yield ( + f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}." 
+ f"{ToolCallAttributes.TOOL_CALL_ID}", + tool_call_id, + ) if (function := tool_call.get("function")) and hasattr(function, "get"): # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion_message_tool_call_param.py#L10 # noqa: E501 if name := function.get("name"): diff --git a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_response_attributes_extractor.py b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_response_attributes_extractor.py index 66c5bd262..7bb4cea19 100644 --- a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_response_attributes_extractor.py +++ b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_response_attributes_extractor.py @@ -195,6 +195,13 @@ def _get_attributes_from_chat_completion_message( ): # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion_message_tool_call.py#L23 # noqa: E501 for index, tool_call in enumerate(tool_calls): + if (tool_call_id := getattr(tool_call, "id", None)) is not None: + # https://github.com/openai/openai-python/blob/891e1c17b7fecbae34d1915ba90c15ddece807f9/src/openai/types/chat/chat_completion_message_tool_call.py#L24 + yield ( + f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}." + f"{ToolCallAttributes.TOOL_CALL_ID}", + tool_call_id, + ) if function := getattr(tool_call, "function", None): # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion_message_tool_call.py#L10 # noqa: E501 if name := getattr(function, "name", None): diff --git a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/cassettes/test_tool_calls.yaml b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/cassettes/test_tool_calls.yaml new file mode 100644 index 000000000..487b47e94 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/cassettes/test_tool_calls.yaml @@ -0,0 +1,45 @@ +interactions: +- request: + body: '{"messages": [{"role": "assistant", "tool_calls": [{"id": "call_62136355", + "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\": + \"New York\"}"}}, {"id": "call_62136356", "type": "function", "function": {"name": + "get_population", "arguments": "{\"city\": \"New York\"}"}}]}, {"role": "tool", + "tool_call_id": "call_62136355", "content": "{\"city\": \"New York\", \"weather\": + \"fine\"}"}, {"role": "tool", "tool_call_id": "call_62136356", "content": "{\"city\": + \"New York\", \"weather\": \"large\"}"}, {"role": "assistant", "content": "In + New York the weather is fine and the population is large."}, {"role": "user", + "content": "What''s the weather and population in San Francisco?"}], "model": + "gpt-4o-mini", "tools": [{"type": "function", "function": {"name": "get_weather", + "description": "finds the weather for a given city", "parameters": {"type": + "object", "properties": {"city": {"type": "string", "description": "The city + to find the weather for, e.g. 
''London''"}}, "required": ["city"]}}}, {"type": + "function", "function": {"name": "get_population", "description": "finds the + population for a given city", "parameters": {"type": "object", "properties": + {"city": {"type": "string", "description": "The city to find the population + for, e.g. ''London''"}}, "required": ["city"]}}}]}' + headers: {} + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: "{\n \"id\": \"chatcmpl-AOA1TtLtlb9GcK4LnNYnNcK2ETJpW\",\n \"object\": + \"chat.completion\",\n \"created\": 1730321763,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n + \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": + \"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n + \ \"id\": \"call_Yheo99FhD2nPbVfhJ2qwM1XK\",\n \"type\": + \"function\",\n \"function\": {\n \"name\": \"get_weather\",\n + \ \"arguments\": \"{\\\"city\\\": \\\"San Francisco\\\"}\"\n }\n + \ },\n {\n \"id\": \"call_qvrUTewqYFp3XTGIYS97SCwX\",\n + \ \"type\": \"function\",\n \"function\": {\n \"name\": + \"get_population\",\n \"arguments\": \"{\\\"city\\\": \\\"San + Francisco\\\"}\"\n }\n }\n ],\n \"refusal\": + null\n },\n \"logprobs\": null,\n \"finish_reason\": \"tool_calls\"\n + \ }\n ],\n \"usage\": {\n \"prompt_tokens\": 207,\n \"completion_tokens\": + 46,\n \"total_tokens\": 253,\n \"prompt_tokens_details\": {\n \"cached_tokens\": + 0\n },\n \"completion_tokens_details\": {\n \"reasoning_tokens\": + 0\n }\n },\n \"system_fingerprint\": \"fp_0ba0d124f1\"\n}\n" + headers: {} + status: + code: 200 + message: OK +version: 1 diff --git a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_tool_calls.py b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_tool_calls.py index bf74c65f3..c6566b9e0 100644 --- a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_tool_calls.py +++ b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_tool_calls.py @@ -1,5 +1,4 @@ import json -from contextlib import suppress from importlib import import_module from importlib.metadata import version from typing import Tuple, cast @@ -9,8 +8,12 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -@pytest.mark.disable_socket -def test_tool_call( +@pytest.mark.vcr( + decode_compressed_response=True, + before_record_request=lambda _: _.headers.clear() or _, + before_record_response=lambda _: {**_, "headers": {}}, +) +def test_tool_calls( in_memory_span_exporter: InMemorySpanExporter, tracer_provider: trace_api.TracerProvider, ) -> None: @@ -58,25 +61,106 @@ def test_tool_call( }, ), ] - with suppress(openai.APIConnectionError): - client.chat.completions.create( - model="gpt-4", - tools=input_tools, - messages=[ - { - "role": "user", - "content": "What's the weather like in San Francisco?", - }, - ], - ) + client.chat.completions.create( + extra_headers={"Accept-Encoding": "gzip"}, + model="gpt-4o-mini", + tools=input_tools, + messages=[ + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_62136355", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"city": "New York"}'}, + }, + { + "id": "call_62136356", + "type": "function", + "function": {"name": "get_population", "arguments": '{"city": "New York"}'}, + }, + ], + }, + { + "role": "tool", + "tool_call_id": "call_62136355", + "content": 
'{"city": "New York", "weather": "fine"}', + }, + { + "role": "tool", + "tool_call_id": "call_62136356", + "content": '{"city": "New York", "weather": "large"}', + }, + { + "role": "assistant", + "content": "In New York the weather is fine and the population is large.", + }, + { + "role": "user", + "content": "What's the weather and population in San Francisco?", + }, + ], + ) spans = in_memory_span_exporter.get_finished_spans() - assert len(spans) == 4 - span = spans[3] - attributes = span.attributes or dict() + assert len(spans) == 1 + span = spans[0] + attributes = dict(span.attributes or {}) for i in range(len(input_tools)): - json_schema = attributes.get(f"llm.tools.{i}.tool.json_schema") + json_schema = attributes.pop(f"llm.tools.{i}.tool.json_schema") assert isinstance(json_schema, str) assert json.loads(json_schema) + assert ( + attributes.pop("llm.input_messages.0.message.tool_calls.0.tool_call.id") == "call_62136355" + ) + assert ( + attributes.pop("llm.input_messages.0.message.tool_calls.0.tool_call.function.name") + == "get_weather" + ) + assert ( + attributes.pop("llm.input_messages.0.message.tool_calls.0.tool_call.function.arguments") + == '{"city": "New York"}' + ) + assert ( + attributes.pop("llm.input_messages.0.message.tool_calls.1.tool_call.id") == "call_62136356" + ) + assert ( + attributes.pop("llm.input_messages.0.message.tool_calls.1.tool_call.function.name") + == "get_population" + ) + assert ( + attributes.pop("llm.input_messages.0.message.tool_calls.1.tool_call.function.arguments") + == '{"city": "New York"}' + ) + assert attributes.pop("llm.input_messages.1.message.role") == "tool" + assert attributes.pop("llm.input_messages.1.message.tool_call_id") == "call_62136355" + assert ( + attributes.pop("llm.input_messages.1.message.content") + == '{"city": "New York", "weather": "fine"}' + ) + assert attributes.pop("llm.input_messages.2.message.role") == "tool" + assert attributes.pop("llm.input_messages.2.message.tool_call_id") == "call_62136356" + assert ( + attributes.pop("llm.input_messages.2.message.content") + == '{"city": "New York", "weather": "large"}' + ) + assert attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.id") + assert ( + attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.function.name") + == "get_weather" + ) + assert ( + attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.function.arguments") + == '{"city": "San Francisco"}' + ) + assert attributes.pop("llm.output_messages.0.message.tool_calls.1.tool_call.id") + assert ( + attributes.pop("llm.output_messages.0.message.tool_calls.1.tool_call.function.name") + == "get_population" + ) + assert ( + attributes.pop("llm.output_messages.0.message.tool_calls.1.tool_call.function.arguments") + == '{"city": "San Francisco"}' + ) def _openai_version() -> Tuple[int, int, int]: From 6baa7047f6c40f3e3395f25d8b4845cc94923c78 Mon Sep 17 00:00:00 2001 From: Roger Yang <roger.yang@arize.com> Date: Wed, 30 Oct 2024 15:02:06 -0700 Subject: [PATCH 2/3] fix tests --- .../openai/test_instrumentor.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py index 61e5af613..a32e9043c 100644 --- 
a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py +++ b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py @@ -225,6 +225,15 @@ async def task() -> None: ) # We left out model_name from our mock stream. assert attributes.pop(LLM_MODEL_NAME, None) == model_name + else: + assert ( + attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.id") + == "call_amGrubFmr2FSPHeC5OPgwcNs" + ) + assert ( + attributes.pop("llm.output_messages.0.message.tool_calls.1.tool_call.id") + == "call_6QTP4mLSYYzZwt3ZWj77vfZf" + ) if use_context_attributes: _check_context_attributes( attributes, @@ -651,6 +660,15 @@ async def task() -> None: ) # We left out model_name from our mock stream. assert attributes.pop(LLM_MODEL_NAME, None) == model_name + else: + assert ( + attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.id") + == "call_amGrubFmr2FSPHeC5OPgwcNs" + ) + assert ( + attributes.pop("llm.output_messages.0.message.tool_calls.1.tool_call.id") + == "call_6QTP4mLSYYzZwt3ZWj77vfZf" + ) if use_context_attributes: _check_context_attributes( attributes, From caa46444a9789d31d986b95ac41f1569c2f8951e Mon Sep 17 00:00:00 2001 From: Roger Yang <80478925+RogerHYang@users.noreply.github.com> Date: Wed, 30 Oct 2024 15:26:28 -0700 Subject: [PATCH 3/3] fix tests --- .../openinference/instrumentation/openai/test_instrumentor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py index a32e9043c..0e3baf40a 100644 --- a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py +++ b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py @@ -225,7 +225,7 @@ async def task() -> None: ) # We left out model_name from our mock stream. assert attributes.pop(LLM_MODEL_NAME, None) == model_name - else: + elif _openai_version() >= (1, 12, 0): assert ( attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.id") == "call_amGrubFmr2FSPHeC5OPgwcNs" @@ -660,7 +660,7 @@ async def task() -> None: ) # We left out model_name from our mock stream. assert attributes.pop(LLM_MODEL_NAME, None) == model_name - else: + elif _openai_version() >= (1, 12, 0): assert ( attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.id") == "call_amGrubFmr2FSPHeC5OPgwcNs"
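
---

For reference, the attribute flattening these patches exercise can be illustrated with a small standalone sketch. The helper below is not part of the instrumentation package and its name (`flatten_input_messages`) is invented for illustration; it only mirrors the key scheme asserted in `test_tool_calls.py` above (`llm.input_messages.{i}.message.tool_call_id`, `llm.input_messages.{i}.message.tool_calls.{j}.tool_call.id`, and the matching `function.name` / `function.arguments` keys), using plain dicts in place of the OpenAI client types.

```python
# Illustrative sketch only -- not the instrumentor. It flattens OpenAI-style
# chat messages into the span attribute keys asserted in test_tool_calls.py.
from typing import Any, Iterator, List, Tuple


def flatten_input_messages(messages: List[dict]) -> Iterator[Tuple[str, Any]]:
    for i, message in enumerate(messages):
        prefix = f"llm.input_messages.{i}.message"
        if role := message.get("role"):
            yield f"{prefix}.role", role
        # Tool-role messages carry the id of the tool call they answer.
        if tool_call_id := message.get("tool_call_id"):
            yield f"{prefix}.tool_call_id", tool_call_id
        if content := message.get("content"):
            yield f"{prefix}.content", content
        # Assistant messages may carry one attribute group per tool call.
        for j, tool_call in enumerate(message.get("tool_calls") or ()):
            call_prefix = f"{prefix}.tool_calls.{j}.tool_call"
            if (call_id := tool_call.get("id")) is not None:
                yield f"{call_prefix}.id", call_id
            if function := tool_call.get("function"):
                if name := function.get("name"):
                    yield f"{call_prefix}.function.name", name
                if arguments := function.get("arguments"):
                    yield f"{call_prefix}.function.arguments", arguments


if __name__ == "__main__":
    example = [
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_62136355",
                    "type": "function",
                    "function": {"name": "get_weather", "arguments": '{"city": "New York"}'},
                },
            ],
        },
        {"role": "tool", "tool_call_id": "call_62136355", "content": '{"weather": "fine"}'},
    ]
    for key, value in flatten_input_messages(example):
        print(key, "=", value)
    # e.g. llm.input_messages.0.message.tool_calls.0.tool_call.id = call_62136355
    #      llm.input_messages.1.message.tool_call_id = call_62136355
```

The real extractors build these keys from the `MessageAttributes` and `ToolCallAttributes` constants in openinference-semantic-conventions (hence the bump to `>=0.1.12` in pyproject.toml) rather than literal strings; the sketch spells the keys out only to match the test expectations.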