From 3d03d512f1b4464eb158ba7928c5789e9f803605 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Mon, 23 Sep 2024 13:20:08 -0700 Subject: [PATCH 1/5] langgraph[patch]: remove stringify of tool msg contnt --- .../langgraph/langgraph/prebuilt/tool_node.py | 14 ---------- libs/langgraph/tests/test_prebuilt.py | 28 +++++++++++++++++++ 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/libs/langgraph/langgraph/prebuilt/tool_node.py b/libs/langgraph/langgraph/prebuilt/tool_node.py index a596e829f..cf10f5b0b 100644 --- a/libs/langgraph/langgraph/prebuilt/tool_node.py +++ b/libs/langgraph/langgraph/prebuilt/tool_node.py @@ -43,16 +43,6 @@ TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes." -def str_output(output: Any) -> str: - if isinstance(output, str): - return output - else: - try: - return json.dumps(output, ensure_ascii=False) - except Exception: - return str(output) - - class ToolNode(RunnableCallable): """A node that runs the tools called in the last AIMessage. @@ -138,8 +128,6 @@ def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage: tool_message: ToolMessage = self.tools_by_name[call["name"]].invoke( input, config ) - # TODO: handle this properly in core - tool_message.content = str_output(tool_message.content) return tool_message except Exception as e: if not self.handle_tool_errors: @@ -155,8 +143,6 @@ async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage tool_message: ToolMessage = await self.tools_by_name[call["name"]].ainvoke( input, config ) - # TODO: handle this properly in core - tool_message.content = str_output(tool_message.content) return tool_message except Exception as e: if not self.handle_tool_errors: diff --git a/libs/langgraph/tests/test_prebuilt.py b/libs/langgraph/tests/test_prebuilt.py index 42dca7ce9..faf3f38ef 100644 --- a/libs/langgraph/tests/test_prebuilt.py +++ b/libs/langgraph/tests/test_prebuilt.py @@ -247,6 +247,12 @@ async def tool3(some_val: int, some_other_val: str) -> str: {"key_1": some_other_val, "key_2": "baz"}, ] + async def tool4(some_val: int, some_other_val: str) -> str: + """Tool 3 docstring.""" + return [ + {"type": "image_url", "image_url": {"url": "abdc"}}, + ] + result = ToolNode([tool1]).invoke( { "messages": [ @@ -382,6 +388,28 @@ async def tool3(some_val: int, some_other_val: str) -> str: ) assert tool_message.tool_call_id == "some 0" + # list of content blocks tool content + result3 = await ToolNode([tool4]).ainvoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool4", + "args": {"some_val": 2, "some_other_val": "bar"}, + "id": "some 0", + } + ], + ) + ] + } + ) + tool_message: ToolMessage = result3["messages"][-1] + assert tool_message.type == "tool" + assert tool_message.content == [{"type": "image_url", "image_url": {"url": "abdc"}}] + assert tool_message.tool_call_id == "some 0" + def my_function(some_val: int, some_other_val: str) -> str: return f"{some_val} - {some_other_val}" From dc5e6df8ed39a6606c12383e328cf8fb6254b464 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Mon, 23 Sep 2024 13:46:55 -0700 Subject: [PATCH 2/5] fix --- .../langgraph/langgraph/prebuilt/tool_node.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/libs/langgraph/langgraph/prebuilt/tool_node.py b/libs/langgraph/langgraph/prebuilt/tool_node.py index cf10f5b0b..8464ab61a 100644 --- a/libs/langgraph/langgraph/prebuilt/tool_node.py +++ b/libs/langgraph/langgraph/prebuilt/tool_node.py @@ -43,6 +43,27 @@ TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes." +def msg_content_output(output: Any) -> str | List[dict]: + recognized_content_block_types = ("image", "image_url", "text", "json") + if isinstance(output, str): + return output + elif isinstance(output, list) and all( + [ + isinstance(x, dict) and x.get("type") in recognized_content_block_types + for x in output + ] + ): + return output + # Technically a list of strings is also valid message content but it's not currently + # well tested that all chat models support this. And for backwards compatibility + # we want to make sure we don't break any existing ToolNode usage. + else: + try: + return json.dumps(output, ensure_ascii=False) + except Exception: + return str(output) + + class ToolNode(RunnableCallable): """A node that runs the tools called in the last AIMessage. @@ -128,6 +149,8 @@ def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage: tool_message: ToolMessage = self.tools_by_name[call["name"]].invoke( input, config ) + # TODO: handle this properly in core + tool_message.content = msg_content_output(tool_message.content) return tool_message except Exception as e: if not self.handle_tool_errors: @@ -143,6 +166,8 @@ async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage tool_message: ToolMessage = await self.tools_by_name[call["name"]].ainvoke( input, config ) + # TODO: handle this properly in core + tool_message.content = msg_content_output(tool_message.content) return tool_message except Exception as e: if not self.handle_tool_errors: From 5bba6baf9629337d451144f31805f379afaabc66 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Mon, 23 Sep 2024 13:49:33 -0700 Subject: [PATCH 3/5] lint --- libs/langgraph/langgraph/prebuilt/tool_node.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/libs/langgraph/langgraph/prebuilt/tool_node.py b/libs/langgraph/langgraph/prebuilt/tool_node.py index 8464ab61a..113ca8c81 100644 --- a/libs/langgraph/langgraph/prebuilt/tool_node.py +++ b/libs/langgraph/langgraph/prebuilt/tool_node.py @@ -149,8 +149,9 @@ def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage: tool_message: ToolMessage = self.tools_by_name[call["name"]].invoke( input, config ) - # TODO: handle this properly in core - tool_message.content = msg_content_output(tool_message.content) + tool_message.content = cast( + Union[str, list], msg_content_output(tool_message.content) + ) return tool_message except Exception as e: if not self.handle_tool_errors: @@ -166,8 +167,9 @@ async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage tool_message: ToolMessage = await self.tools_by_name[call["name"]].ainvoke( input, config ) - # TODO: handle this properly in core - tool_message.content = msg_content_output(tool_message.content) + tool_message.content = cast( + Union[str, list], msg_content_output(tool_message.content) + ) return tool_message except Exception as e: if not self.handle_tool_errors: From 58018e2b7788e6eace394bf9bc39f0571b6cbebf Mon Sep 17 00:00:00 2001 From: Bagatur Date: Mon, 23 Sep 2024 13:52:31 -0700 Subject: [PATCH 4/5] nit --- libs/langgraph/tests/test_prebuilt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/langgraph/tests/test_prebuilt.py b/libs/langgraph/tests/test_prebuilt.py index 60da49cd2..6e1cd081a 100644 --- a/libs/langgraph/tests/test_prebuilt.py +++ b/libs/langgraph/tests/test_prebuilt.py @@ -263,7 +263,7 @@ async def tool3(some_val: int, some_other_val: str) -> str: ] async def tool4(some_val: int, some_other_val: str) -> str: - """Tool 3 docstring.""" + """Tool 4 docstring.""" return [ {"type": "image_url", "image_url": {"url": "abdc"}}, ] @@ -404,7 +404,7 @@ async def tool4(some_val: int, some_other_val: str) -> str: assert tool_message.tool_call_id == "some 0" # list of content blocks tool content - result3 = await ToolNode([tool4]).ainvoke( + result4 = await ToolNode([tool4]).ainvoke( { "messages": [ AIMessage( @@ -420,7 +420,7 @@ async def tool4(some_val: int, some_other_val: str) -> str: ] } ) - tool_message: ToolMessage = result3["messages"][-1] + tool_message: ToolMessage = result4["messages"][-1] assert tool_message.type == "tool" assert tool_message.content == [{"type": "image_url", "image_url": {"url": "abdc"}}] assert tool_message.tool_call_id == "some 0" From dca48e02ba5581cf2bc36333708e1ac181a401ee Mon Sep 17 00:00:00 2001 From: Bagatur Date: Mon, 23 Sep 2024 13:56:23 -0700 Subject: [PATCH 5/5] nit --- libs/langgraph/langgraph/prebuilt/tool_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/prebuilt/tool_node.py b/libs/langgraph/langgraph/prebuilt/tool_node.py index 4de3f0ce0..be87c0f0f 100644 --- a/libs/langgraph/langgraph/prebuilt/tool_node.py +++ b/libs/langgraph/langgraph/prebuilt/tool_node.py @@ -47,7 +47,7 @@ def msg_content_output(output: Any) -> str | List[dict]: recognized_content_block_types = ("image", "image_url", "text", "json") if isinstance(output, str): return output - elif isinstance(output, list) and all( + elif all( [ isinstance(x, dict) and x.get("type") in recognized_content_block_types for x in output