From 90b2b9f408d0d97c9c4f9ecf3e452e2134bdbdfb Mon Sep 17 00:00:00 2001
From: Roger Yang <roger.yang@arize.com>
Date: Wed, 30 Oct 2024 14:18:45 -0700
Subject: [PATCH 1/3] feat: add tool call id

---
 .../pyproject.toml                            |   6 +-
 .../openai/_request_attributes_extractor.py   |  11 +-
 .../openai/_response_attributes_extractor.py  |   7 +
 .../openai/cassettes/test_tool_calls.yaml     |  45 +++++++
 .../instrumentation/openai/test_tool_calls.py | 120 +++++++++++++++---
 5 files changed, 168 insertions(+), 21 deletions(-)
 create mode 100644 python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/cassettes/test_tool_calls.yaml

diff --git a/python/instrumentation/openinference-instrumentation-openai/pyproject.toml b/python/instrumentation/openinference-instrumentation-openai/pyproject.toml
index 0476d3ec2..724b098a8 100644
--- a/python/instrumentation/openinference-instrumentation-openai/pyproject.toml
+++ b/python/instrumentation/openinference-instrumentation-openai/pyproject.toml
@@ -28,8 +28,8 @@ dependencies = [
   "opentelemetry-api",
   "opentelemetry-instrumentation",
   "opentelemetry-semantic-conventions",
-  "openinference-instrumentation>=0.1.17",
-  "openinference-semantic-conventions>=0.1.9",
+  "openinference-instrumentation>=0.1.18",
+  "openinference-semantic-conventions>=0.1.12",
   "typing-extensions",
   "wrapt",
 ]
@@ -44,6 +44,7 @@ test = [
   "opentelemetry-instrumentation-httpx",
   "respx",
   "numpy",
+  "pytest-vcr",
 ]
 
 [project.urls]
@@ -63,6 +64,7 @@ packages = ["src/openinference"]
 
 [tool.pytest.ini_options]
 asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "function"
 testpaths = [
   "tests",
 ]
diff --git a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request_attributes_extractor.py b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request_attributes_extractor.py
index 316d01360..2ff94ee61 100644
--- a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request_attributes_extractor.py
+++ b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request_attributes_extractor.py
@@ -104,7 +104,9 @@ def _get_attributes_from_message_param(
                 MessageAttributes.MESSAGE_ROLE,
                 role.value if isinstance(role, Enum) else role,
             )
-
+        if tool_call_id := message.get("tool_call_id"):
+            # Tool-role message params carry `tool_call_id`; see https://github.com/openai/openai-python/blob/891e1c17b7fecbae34d1915ba90c15ddece807f9/src/openai/types/chat/chat_completion_tool_message_param.py#L20
+            yield MessageAttributes.MESSAGE_TOOL_CALL_ID, tool_call_id
         if content := message.get("content"):
             if isinstance(content, str):
                 yield MessageAttributes.MESSAGE_CONTENT, content
@@ -140,6 +142,13 @@ def _get_attributes_from_message_param(
                 # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion_message_tool_call_param.py#L23  # noqa: E501
                 if not hasattr(tool_call, "get"):
                     continue
+                if (tool_call_id := tool_call.get("id")) is not None:
+                    # Each tool call param has an `id` field; see https://github.com/openai/openai-python/blob/891e1c17b7fecbae34d1915ba90c15ddece807f9/src/openai/types/chat/chat_completion_message_tool_call_param.py#L24
+                    yield (
+                        f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
+                        f"{ToolCallAttributes.TOOL_CALL_ID}",
+                        tool_call_id,
+                    )
                 if (function := tool_call.get("function")) and hasattr(function, "get"):
                     # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion_message_tool_call_param.py#L10  # noqa: E501
                     if name := function.get("name"):
diff --git a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_response_attributes_extractor.py b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_response_attributes_extractor.py
index 66c5bd262..7bb4cea19 100644
--- a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_response_attributes_extractor.py
+++ b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_response_attributes_extractor.py
@@ -195,6 +195,13 @@ def _get_attributes_from_chat_completion_message(
         ):
             # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion_message_tool_call.py#L23  # noqa: E501
             for index, tool_call in enumerate(tool_calls):
+                if (tool_call_id := getattr(tool_call, "id", None)) is not None:
+                    # Each response tool call has an `id` field; see https://github.com/openai/openai-python/blob/891e1c17b7fecbae34d1915ba90c15ddece807f9/src/openai/types/chat/chat_completion_message_tool_call.py#L24
+                    yield (
+                        f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
+                        f"{ToolCallAttributes.TOOL_CALL_ID}",
+                        tool_call_id,
+                    )
                 if function := getattr(tool_call, "function", None):
                     # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion_message_tool_call.py#L10  # noqa: E501
                     if name := getattr(function, "name", None):
diff --git a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/cassettes/test_tool_calls.yaml b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/cassettes/test_tool_calls.yaml
new file mode 100644
index 000000000..487b47e94
--- /dev/null
+++ b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/cassettes/test_tool_calls.yaml
@@ -0,0 +1,45 @@
+interactions:
+- request:
+    body: '{"messages": [{"role": "assistant", "tool_calls": [{"id": "call_62136355",
+      "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":
+      \"New York\"}"}}, {"id": "call_62136356", "type": "function", "function": {"name":
+      "get_population", "arguments": "{\"city\": \"New York\"}"}}]}, {"role": "tool",
+      "tool_call_id": "call_62136355", "content": "{\"city\": \"New York\", \"weather\":
+      \"fine\"}"}, {"role": "tool", "tool_call_id": "call_62136356", "content": "{\"city\":
+      \"New York\", \"weather\": \"large\"}"}, {"role": "assistant", "content": "In
+      New York the weather is fine and the population is large."}, {"role": "user",
+      "content": "What''s the weather and population in San Francisco?"}], "model":
+      "gpt-4o-mini", "tools": [{"type": "function", "function": {"name": "get_weather",
+      "description": "finds the weather for a given city", "parameters": {"type":
+      "object", "properties": {"city": {"type": "string", "description": "The city
+      to find the weather for, e.g. ''London''"}}, "required": ["city"]}}}, {"type":
+      "function", "function": {"name": "get_population", "description": "finds the
+      population for a given city", "parameters": {"type": "object", "properties":
+      {"city": {"type": "string", "description": "The city to find the population
+      for, e.g. ''London''"}}, "required": ["city"]}}}]}'
+    headers: {}
+    method: POST
+    uri: https://api.openai.com/v1/chat/completions
+  response:
+    body:
+      string: "{\n  \"id\": \"chatcmpl-AOA1TtLtlb9GcK4LnNYnNcK2ETJpW\",\n  \"object\":
+        \"chat.completion\",\n  \"created\": 1730321763,\n  \"model\": \"gpt-4o-mini-2024-07-18\",\n
+        \ \"choices\": [\n    {\n      \"index\": 0,\n      \"message\": {\n        \"role\":
+        \"assistant\",\n        \"content\": null,\n        \"tool_calls\": [\n          {\n
+        \           \"id\": \"call_Yheo99FhD2nPbVfhJ2qwM1XK\",\n            \"type\":
+        \"function\",\n            \"function\": {\n              \"name\": \"get_weather\",\n
+        \             \"arguments\": \"{\\\"city\\\": \\\"San Francisco\\\"}\"\n            }\n
+        \         },\n          {\n            \"id\": \"call_qvrUTewqYFp3XTGIYS97SCwX\",\n
+        \           \"type\": \"function\",\n            \"function\": {\n              \"name\":
+        \"get_population\",\n              \"arguments\": \"{\\\"city\\\": \\\"San
+        Francisco\\\"}\"\n            }\n          }\n        ],\n        \"refusal\":
+        null\n      },\n      \"logprobs\": null,\n      \"finish_reason\": \"tool_calls\"\n
+        \   }\n  ],\n  \"usage\": {\n    \"prompt_tokens\": 207,\n    \"completion_tokens\":
+        46,\n    \"total_tokens\": 253,\n    \"prompt_tokens_details\": {\n      \"cached_tokens\":
+        0\n    },\n    \"completion_tokens_details\": {\n      \"reasoning_tokens\":
+        0\n    }\n  },\n  \"system_fingerprint\": \"fp_0ba0d124f1\"\n}\n"
+    headers: {}
+    status:
+      code: 200
+      message: OK
+version: 1
diff --git a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_tool_calls.py b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_tool_calls.py
index bf74c65f3..c6566b9e0 100644
--- a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_tool_calls.py
+++ b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_tool_calls.py
@@ -1,5 +1,4 @@
 import json
-from contextlib import suppress
 from importlib import import_module
 from importlib.metadata import version
 from typing import Tuple, cast
@@ -9,8 +8,12 @@
 from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
 
 
-@pytest.mark.disable_socket
-def test_tool_call(
+@pytest.mark.vcr(
+    decode_compressed_response=True,
+    before_record_request=lambda _: _.headers.clear() or _,
+    before_record_response=lambda _: {**_, "headers": {}},
+)
+def test_tool_calls(
     in_memory_span_exporter: InMemorySpanExporter,
     tracer_provider: trace_api.TracerProvider,
 ) -> None:
@@ -58,25 +61,106 @@ def test_tool_call(
             },
         ),
     ]
-    with suppress(openai.APIConnectionError):
-        client.chat.completions.create(
-            model="gpt-4",
-            tools=input_tools,
-            messages=[
-                {
-                    "role": "user",
-                    "content": "What's the weather like in San Francisco?",
-                },
-            ],
-        )
+    client.chat.completions.create(
+        extra_headers={"Accept-Encoding": "gzip"},
+        model="gpt-4o-mini",
+        tools=input_tools,
+        messages=[
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {
+                        "id": "call_62136355",
+                        "type": "function",
+                        "function": {"name": "get_weather", "arguments": '{"city": "New York"}'},
+                    },
+                    {
+                        "id": "call_62136356",
+                        "type": "function",
+                        "function": {"name": "get_population", "arguments": '{"city": "New York"}'},
+                    },
+                ],
+            },
+            {
+                "role": "tool",
+                "tool_call_id": "call_62136355",
+                "content": '{"city": "New York", "weather": "fine"}',
+            },
+            {
+                "role": "tool",
+                "tool_call_id": "call_62136356",
+                "content": '{"city": "New York", "weather": "large"}',
+            },
+            {
+                "role": "assistant",
+                "content": "In New York the weather is fine and the population is large.",
+            },
+            {
+                "role": "user",
+                "content": "What's the weather and population in San Francisco?",
+            },
+        ],
+    )
     spans = in_memory_span_exporter.get_finished_spans()
-    assert len(spans) == 4
-    span = spans[3]
-    attributes = span.attributes or dict()
+    assert len(spans) == 1
+    span = spans[0]
+    attributes = dict(span.attributes or {})
     for i in range(len(input_tools)):
-        json_schema = attributes.get(f"llm.tools.{i}.tool.json_schema")
+        json_schema = attributes.pop(f"llm.tools.{i}.tool.json_schema")
         assert isinstance(json_schema, str)
         assert json.loads(json_schema)
+    assert (
+        attributes.pop("llm.input_messages.0.message.tool_calls.0.tool_call.id") == "call_62136355"
+    )
+    assert (
+        attributes.pop("llm.input_messages.0.message.tool_calls.0.tool_call.function.name")
+        == "get_weather"
+    )
+    assert (
+        attributes.pop("llm.input_messages.0.message.tool_calls.0.tool_call.function.arguments")
+        == '{"city": "New York"}'
+    )
+    assert (
+        attributes.pop("llm.input_messages.0.message.tool_calls.1.tool_call.id") == "call_62136356"
+    )
+    assert (
+        attributes.pop("llm.input_messages.0.message.tool_calls.1.tool_call.function.name")
+        == "get_population"
+    )
+    assert (
+        attributes.pop("llm.input_messages.0.message.tool_calls.1.tool_call.function.arguments")
+        == '{"city": "New York"}'
+    )
+    assert attributes.pop("llm.input_messages.1.message.role") == "tool"
+    assert attributes.pop("llm.input_messages.1.message.tool_call_id") == "call_62136355"
+    assert (
+        attributes.pop("llm.input_messages.1.message.content")
+        == '{"city": "New York", "weather": "fine"}'
+    )
+    assert attributes.pop("llm.input_messages.2.message.role") == "tool"
+    assert attributes.pop("llm.input_messages.2.message.tool_call_id") == "call_62136356"
+    assert (
+        attributes.pop("llm.input_messages.2.message.content")
+        == '{"city": "New York", "weather": "large"}'
+    )
+    assert attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.id")
+    assert (
+        attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.function.name")
+        == "get_weather"
+    )
+    assert (
+        attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.function.arguments")
+        == '{"city": "San Francisco"}'
+    )
+    assert attributes.pop("llm.output_messages.0.message.tool_calls.1.tool_call.id")
+    assert (
+        attributes.pop("llm.output_messages.0.message.tool_calls.1.tool_call.function.name")
+        == "get_population"
+    )
+    assert (
+        attributes.pop("llm.output_messages.0.message.tool_calls.1.tool_call.function.arguments")
+        == '{"city": "San Francisco"}'
+    )
 
 
 def _openai_version() -> Tuple[int, int, int]:

From 6baa7047f6c40f3e3395f25d8b4845cc94923c78 Mon Sep 17 00:00:00 2001
From: Roger Yang <roger.yang@arize.com>
Date: Wed, 30 Oct 2024 15:02:06 -0700
Subject: [PATCH 2/3] test: assert tool call ids on output messages

---
 .../openai/test_instrumentor.py                | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py
index 61e5af613..a32e9043c 100644
--- a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py
+++ b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py
@@ -225,6 +225,15 @@ async def task() -> None:
             )
             # We left out model_name from our mock stream.
             assert attributes.pop(LLM_MODEL_NAME, None) == model_name
+        else:
+            assert (
+                attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.id")
+                == "call_amGrubFmr2FSPHeC5OPgwcNs"
+            )
+            assert (
+                attributes.pop("llm.output_messages.0.message.tool_calls.1.tool_call.id")
+                == "call_6QTP4mLSYYzZwt3ZWj77vfZf"
+            )
     if use_context_attributes:
         _check_context_attributes(
             attributes,
@@ -651,6 +660,15 @@ async def task() -> None:
             )
             # We left out model_name from our mock stream.
             assert attributes.pop(LLM_MODEL_NAME, None) == model_name
+        else:
+            assert (
+                attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.id")
+                == "call_amGrubFmr2FSPHeC5OPgwcNs"
+            )
+            assert (
+                attributes.pop("llm.output_messages.0.message.tool_calls.1.tool_call.id")
+                == "call_6QTP4mLSYYzZwt3ZWj77vfZf"
+            )
     if use_context_attributes:
         _check_context_attributes(
             attributes,

From caa46444a9789d31d986b95ac41f1569c2f8951e Mon Sep 17 00:00:00 2001
From: Roger Yang <80478925+RogerHYang@users.noreply.github.com>
Date: Wed, 30 Oct 2024 15:26:28 -0700
Subject: [PATCH 3/3] test: gate tool call id assertions on openai>=1.12.0

---
 .../openinference/instrumentation/openai/test_instrumentor.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py
index a32e9043c..0e3baf40a 100644
--- a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py
+++ b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py
@@ -225,7 +225,7 @@ async def task() -> None:
             )
             # We left out model_name from our mock stream.
             assert attributes.pop(LLM_MODEL_NAME, None) == model_name
-        else:
+        elif _openai_version() >= (1, 12, 0):
             assert (
                 attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.id")
                 == "call_amGrubFmr2FSPHeC5OPgwcNs"
@@ -660,7 +660,7 @@ async def task() -> None:
             )
             # We left out model_name from our mock stream.
             assert attributes.pop(LLM_MODEL_NAME, None) == model_name
-        else:
+        elif _openai_version() >= (1, 12, 0):
             assert (
                 attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.id")
                 == "call_amGrubFmr2FSPHeC5OPgwcNs"