Skip to content

Commit

Permalink
feat(anthropic): add tool json schema attributes to anthropic instrum…
Browse files Browse the repository at this point in the history
…entation (#1087)
  • Loading branch information
axiomofjoy authored Oct 30, 2024
1 parent 473701d commit 907b6e5
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
OpenInferenceMimeTypeValues,
OpenInferenceSpanKindValues,
SpanAttributes,
ToolAttributes,
ToolCallAttributes,
)

Expand Down Expand Up @@ -98,6 +99,7 @@ def __call__(
_get_llm_prompts(llm_prompt),
_get_inputs(arguments),
_get_llm_invocation_parameters(llm_invocation_parameters),
_get_llm_tools(llm_invocation_parameters),
)
),
) as span:
Expand Down Expand Up @@ -150,6 +152,7 @@ async def __call__(
_get_llm_prompts(llm_prompt),
_get_inputs(arguments),
_get_llm_invocation_parameters(invocation_parameters),
_get_llm_tools(invocation_parameters),
)
),
) as span:
Expand Down Expand Up @@ -201,6 +204,7 @@ def __call__(
_get_llm_span_kind(),
_get_llm_input_messages(llm_input_messages),
_get_llm_invocation_parameters(invocation_parameters),
_get_llm_tools(invocation_parameters),
_get_inputs(arguments),
)
),
Expand Down Expand Up @@ -260,6 +264,7 @@ async def __call__(
_get_llm_span_kind(),
_get_llm_input_messages(llm_input_messages),
_get_llm_invocation_parameters(invocation_parameters),
_get_llm_tools(invocation_parameters),
_get_inputs(arguments),
)
),
Expand Down Expand Up @@ -298,6 +303,12 @@ def _get_outputs(response: "BaseModel") -> Iterator[Tuple[str, Any]]:
yield OUTPUT_MIME_TYPE, JSON


def _get_llm_tools(invocation_parameters: Mapping[str, Any]) -> Iterator[Tuple[str, Any]]:
if isinstance(tools := invocation_parameters.get("tools"), list):
for tool_index, tool_schema in enumerate(tools):
yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", safe_json_dumps(tool_schema)


def _get_llm_span_kind() -> Iterator[Tuple[str, Any]]:
yield OPENINFERENCE_SPAN_KIND, LLM

Expand Down Expand Up @@ -464,6 +475,7 @@ def _validate_invocation_parameter(parameter: Any) -> bool:
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
LLM_TOOLS = SpanAttributes.LLM_TOOLS
MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
MESSAGE_CONTENTS = MessageAttributes.MESSAGE_CONTENTS
MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON = MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON
Expand All @@ -479,6 +491,7 @@ def _validate_invocation_parameter(parameter: Any) -> bool:
TAG_TAGS = SpanAttributes.TAG_TAGS
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
USER_ID = SpanAttributes.USER_ID
LLM_PROVIDER = SpanAttributes.LLM_PROVIDER
LLM_SYSTEM = SpanAttributes.LLM_SYSTEM
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
MessageParam,
TextBlock,
TextBlockParam,
ToolParam,
ToolResultBlockParam,
ToolUseBlock,
ToolUseBlockParam,
Expand All @@ -36,6 +37,7 @@
OpenInferenceMimeTypeValues,
OpenInferenceSpanKindValues,
SpanAttributes,
ToolAttributes,
ToolCallAttributes,
)

Expand Down Expand Up @@ -564,46 +566,43 @@ def test_anthropic_instrumentation_multiple_tool_calling(
"What is the weather like right now in New York?"
" Also what time is it there? Use necessary tools simultaneously."
)

client.messages.create(
model="claude-3-5-sonnet-20240620",
max_tokens=1024,
tools=[
{
"name": "get_weather",
"description": "Get the current weather in a given location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The unit of temperature,"
" either 'celsius' or 'fahrenheit'",
},
},
"required": ["location"],
get_weather_tool_schema = ToolParam(
name="get_weather",
description="Get the current weather in a given location",
input_schema={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
},
{
"name": "get_time",
"description": "Get the current time in a given time zone",
"input_schema": {
"type": "object",
"properties": {
"timezone": {
"type": "string",
"description": "The IANA time zone name, e.g. America/Los_Angeles",
}
},
"required": ["timezone"],
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The unit of temperature," " either 'celsius' or 'fahrenheit'",
},
},
],
"required": ["location"],
},
)
get_time_tool_schema = ToolParam(
name="get_time",
description="Get the current time in a given time zone",
input_schema={
"type": "object",
"properties": {
"timezone": {
"type": "string",
"description": "The IANA time zone name, e.g. America/Los_Angeles",
}
},
"required": ["timezone"],
},
)
client.messages.create(
model="claude-3-5-sonnet-20240620",
max_tokens=1024,
tools=[get_weather_tool_schema, get_time_tool_schema],
messages=[{"role": "user", "content": input_message}],
)

Expand All @@ -616,6 +615,10 @@ def test_anthropic_instrumentation_multiple_tool_calling(
assert attributes.pop(f"{LLM_INPUT_MESSAGES}.0.{MESSAGE_CONTENT}") == input_message
assert attributes.pop(f"{LLM_INPUT_MESSAGES}.0.{MESSAGE_ROLE}") == "user"
assert isinstance(attributes.pop(LLM_INVOCATION_PARAMETERS), str)
assert isinstance(tool_schema0 := attributes.pop(f"{LLM_TOOLS}.0.{TOOL_JSON_SCHEMA}"), str)
assert json.loads(tool_schema0) == get_weather_tool_schema
assert isinstance(tool_schema1 := attributes.pop(f"{LLM_TOOLS}.1.{TOOL_JSON_SCHEMA}"), str)
assert json.loads(tool_schema1) == get_time_tool_schema
assert isinstance(attributes.pop(INPUT_VALUE), str)
assert attributes.pop(INPUT_MIME_TYPE) == JSON
assert attributes.pop(f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}") == "assistant"
Expand Down Expand Up @@ -666,46 +669,43 @@ def test_anthropic_instrumentation_multiple_tool_calling_streaming(
"What is the weather like right now in New York?"
" Also what time is it there? Use necessary tools simultaneously."
)

stream = client.messages.create(
model="claude-3-5-sonnet-20240620",
max_tokens=1024,
tools=[
{
"name": "get_weather",
"description": "Get the current weather in a given location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The unit of temperature,"
" either 'celsius' or 'fahrenheit'",
},
},
"required": ["location"],
get_weather_tool_schema = ToolParam(
name="get_weather",
description="Get the current weather in a given location",
input_schema={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
},
{
"name": "get_time",
"description": "Get the current time in a given time zone",
"input_schema": {
"type": "object",
"properties": {
"timezone": {
"type": "string",
"description": "The IANA time zone name, e.g. America/Los_Angeles",
}
},
"required": ["timezone"],
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The unit of temperature," " either 'celsius' or 'fahrenheit'",
},
},
],
"required": ["location"],
},
)
get_time_tool_schema = ToolParam(
name="get_time",
description="Get the current time in a given time zone",
input_schema={
"type": "object",
"properties": {
"timezone": {
"type": "string",
"description": "The IANA time zone name, e.g. America/Los_Angeles",
}
},
"required": ["timezone"],
},
)
stream = client.messages.create(
model="claude-3-5-sonnet-20240620",
max_tokens=1024,
tools=[get_weather_tool_schema, get_time_tool_schema],
messages=[{"role": "user", "content": input_message}],
stream=True,
)
Expand All @@ -721,6 +721,10 @@ def test_anthropic_instrumentation_multiple_tool_calling_streaming(
assert attributes.pop(f"{LLM_INPUT_MESSAGES}.0.{MESSAGE_CONTENT}") == input_message
assert attributes.pop(f"{LLM_INPUT_MESSAGES}.0.{MESSAGE_ROLE}") == "user"
assert isinstance(attributes.pop(LLM_INVOCATION_PARAMETERS), str)
assert isinstance(tool_schema0 := attributes.pop(f"{LLM_TOOLS}.0.{TOOL_JSON_SCHEMA}"), str)
assert json.loads(tool_schema0) == get_weather_tool_schema
assert isinstance(tool_schema1 := attributes.pop(f"{LLM_TOOLS}.1.{TOOL_JSON_SCHEMA}"), str)
assert json.loads(tool_schema1) == get_time_tool_schema
assert isinstance(attributes.pop(INPUT_VALUE), str)
assert attributes.pop(INPUT_MIME_TYPE) == JSON
assert attributes.pop(f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}") == "assistant"
Expand Down Expand Up @@ -1000,6 +1004,7 @@ def test_oitracer(
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
LLM_TOOLS = SpanAttributes.LLM_TOOLS
MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON = MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON
MESSAGE_FUNCTION_CALL_NAME = MessageAttributes.MESSAGE_FUNCTION_CALL_NAME
Expand All @@ -1014,6 +1019,7 @@ def test_oitracer(
TAG_TAGS = SpanAttributes.TAG_TAGS
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
LLM_PROMPT_TEMPLATE = SpanAttributes.LLM_PROMPT_TEMPLATE
LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
USER_ID = SpanAttributes.USER_ID
Expand Down

0 comments on commit 907b6e5

Please sign in to comment.