Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

langgraph[patch]: remove stringify of tool msg contnt #1810

Merged
merged 6 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions libs/langgraph/langgraph/prebuilt/tool_node.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import annotations

Check notice on line 1 in libs/langgraph/langgraph/prebuilt/tool_node.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 59.2 ms +- 1.2 ms ......................................... WARNING: the benchmark result may be unstable * the standard deviation (7.10 ms) is 12% of the mean (57.0 ms) Try to rerun the benchmark with more runs, values and/or loops. Run 'python -m pyperf system tune' command to reduce the system jitter. Use pyperf stats, pyperf dump and pyperf hist to analyze results. Use --quiet option to hide these warnings. fanout_to_subgraph_10x_sync: Mean +- std dev: 57.0 ms +- 7.1 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 76.8 ms +- 1.7 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 81.4 ms +- 1.2 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 562 ms +- 16 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 505 ms +- 6 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 815 ms +- 32 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 792 ms +- 8 ms ......................................... react_agent_10x: Mean +- std dev: 41.4 ms +- 3.3 ms ......................................... react_agent_10x_sync: Mean +- std dev: 29.8 ms +- 0.6 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 52.7 ms +- 1.3 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 43.2 ms +- 3.7 ms ......................................... react_agent_100x: Mean +- std dev: 420 ms +- 10 ms ......................................... react_agent_100x_sync: Mean +- std dev: 337 ms +- 5 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 946 ms +- 19 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 842 ms +- 19 ms ......................................... wide_state_25x300: Mean +- std dev: 20.6 ms +- 0.3 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 13.0 ms +- 0.3 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 238 ms +- 7 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 240 ms +- 15 ms ......................................... wide_state_15x600: Mean +- std dev: 23.8 ms +- 0.3 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 15.0 ms +- 0.3 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 414 ms +- 10 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 417 ms +- 18 ms ......................................... wide_state_9x1200: Mean +- std dev: 23.7 ms +- 0.3 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 14.8 ms +- 0.5 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 268 ms +- 10 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 267 ms +- 15 ms

Check notice on line 1 in libs/langgraph/langgraph/prebuilt/tool_node.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+-----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +=========================================+=========+=======================+ | fanout_to_subgraph_100x | 592 ms | 562 ms: 1.05x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint | 851 ms | 815 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_sync | 59.3 ms | 57.0 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x | 43.0 ms | 41.4 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint | 54.2 ms | 52.7 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint | 79.0 ms | 76.8 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint_sync | 44.3 ms | 43.2 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint | 970 ms | 946 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint_sync | 810 ms | 792 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 83.2 ms | 81.4 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint_sync | 860 ms | 842 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_checkpoint | 241 ms | 238 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint | 272 ms | 268 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint_sync | 271 ms | 267 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x | 60.1 ms | 59.2 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 15.0 ms | 14.8 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_sync | 30.2 ms | 29.8 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_sync | 511 ms | 505 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint | 419 ms | 414 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200 | 23.9 ms | 23.7 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.02x faster | +-----------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (8): wide_state_15x600_checkpoint_sync, wide_state_25x300_checkpoint_sync, react_agent_100x, react_agent_100x_sync, wide_state_15x600, wide_state_15x600_sync, wide_state_25x300, wide_state_25x300_sync

import asyncio
import json
Expand Down Expand Up @@ -43,9 +43,20 @@
TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."


def str_output(output: Any) -> str:
def msg_content_output(output: Any) -> str | List[dict]:
recognized_content_block_types = ("image", "image_url", "text", "json")
if isinstance(output, str):
return output
elif all(
[
baskaryan marked this conversation as resolved.
Show resolved Hide resolved
isinstance(x, dict) and x.get("type") in recognized_content_block_types
for x in output
]
):
return output
# Technically a list of strings is also valid message content but it's not currently
# well tested that all chat models support this. And for backwards compatibility
# we want to make sure we don't break any existing ToolNode usage.
else:
try:
return json.dumps(output, ensure_ascii=False)
Expand Down Expand Up @@ -138,8 +149,9 @@
tool_message: ToolMessage = self.tools_by_name[call["name"]].invoke(
input, config
)
# TODO: handle this properly in core
tool_message.content = str_output(tool_message.content)
tool_message.content = cast(
Union[str, list], msg_content_output(tool_message.content)
)
return tool_message
except Exception as e:
if not self.handle_tool_errors:
Expand All @@ -155,8 +167,9 @@
tool_message: ToolMessage = await self.tools_by_name[call["name"]].ainvoke(
input, config
)
# TODO: handle this properly in core
tool_message.content = str_output(tool_message.content)
tool_message.content = cast(
Union[str, list], msg_content_output(tool_message.content)
)
return tool_message
except Exception as e:
if not self.handle_tool_errors:
Expand Down
28 changes: 28 additions & 0 deletions libs/langgraph/tests/test_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,12 @@ async def tool3(some_val: int, some_other_val: str) -> str:
{"key_1": some_other_val, "key_2": "baz"},
]

async def tool4(some_val: int, some_other_val: str) -> str:
"""Tool 4 docstring."""
return [
{"type": "image_url", "image_url": {"url": "abdc"}},
]

result = ToolNode([tool1]).invoke(
{
"messages": [
Expand Down Expand Up @@ -397,6 +403,28 @@ async def tool3(some_val: int, some_other_val: str) -> str:
)
assert tool_message.tool_call_id == "some 0"

# list of content blocks tool content
result4 = await ToolNode([tool4]).ainvoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool4",
"args": {"some_val": 2, "some_other_val": "bar"},
"id": "some 0",
}
],
)
]
}
)
tool_message: ToolMessage = result4["messages"][-1]
assert tool_message.type == "tool"
assert tool_message.content == [{"type": "image_url", "image_url": {"url": "abdc"}}]
assert tool_message.tool_call_id == "some 0"


def my_function(some_val: int, some_other_val: str) -> str:
return f"{some_val} - {some_other_val}"
Expand Down
Loading