From 9b2f51e40a83cbc88bf47b84cf2afbbefb85bd8b Mon Sep 17 00:00:00 2001 From: Bagatur Date: Thu, 18 Jul 2024 15:24:41 -0700 Subject: [PATCH] langgraph[patch]: refactor ToolNode --- .../langgraph/langgraph/prebuilt/tool_node.py | 113 ++++++++---------- 1 file changed, 51 insertions(+), 62 deletions(-) diff --git a/libs/langgraph/langgraph/prebuilt/tool_node.py b/libs/langgraph/langgraph/prebuilt/tool_node.py index 1aec005fb..fc70d261b 100644 --- a/libs/langgraph/langgraph/prebuilt/tool_node.py +++ b/libs/langgraph/langgraph/prebuilt/tool_node.py @@ -1,9 +1,9 @@ import asyncio -from typing import Any, Callable, Dict, Literal, Optional, Sequence, Union +from typing import Any, Callable, Dict, Literal, Optional, Sequence, Tuple, Union, cast from langchain_core.messages import AIMessage, AnyMessage, ToolCall, ToolMessage from langchain_core.runnables import RunnableConfig -from langchain_core.runnables.config import get_executor_for_config +from langchain_core.runnables.config import get_config_list, get_executor_for_config from langchain_core.tools import BaseTool from langchain_core.tools import tool as create_tool @@ -60,47 +60,49 @@ def __init__( def _func( self, input: Union[list[AnyMessage], dict[str, Any]], config: RunnableConfig ) -> Any: - if isinstance(input, list): - output_type = "list" - message: AnyMessage = input[-1] - elif messages := input.get("messages", []): - output_type = "dict" - message = messages[-1] - else: - raise ValueError("No message found in input") - - if not isinstance(message, AIMessage): - raise ValueError("Last message is not an AIMessage") - - def run_one(call: ToolCall): - if (requested_tool := call["name"]) not in self.tools_by_name: - content = INVALID_TOOL_NAME_ERROR_TEMPLATE.format( - requested_tool=requested_tool, - available_tools=", ".join(self.tools_by_name.keys()), - ) - return ToolMessage( - content, name=requested_tool, tool_call_id=call["id"] - ) - - try: - input = {**call, **{"type": "tool_call"}} - return self.tools_by_name[call["name"]].invoke(input, config) - except Exception as e: - if not self.handle_tool_errors: - raise e - content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e)) - return ToolMessage(content, name=call["name"], tool_call_id=call["id"]) - + message, output_type = self._parse_input(input) + config_list = get_config_list(config, len(message.tool_calls)) with get_executor_for_config(config) as executor: - outputs = [*executor.map(run_one, message.tool_calls)] - if output_type == "list": - return outputs - else: - return {"messages": outputs} + outputs = [*executor.map(self._run_one, message.tool_calls, config_list)] + return outputs if output_type == "list" else {"messages": outputs} async def _afunc( self, input: Union[list[AnyMessage], dict[str, Any]], config: RunnableConfig ) -> Any: + message, output_type = self._parse_input(input) + outputs = await asyncio.gather( + *(self._arun_one(call, config) for call in message.tool_calls) + ) + return outputs if output_type == "list" else {"messages": outputs} + + def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage: + if invalid_tool_message := self._validate_tool_call(call): + return invalid_tool_message + + try: + input = {**call, **{"type": "tool_call"}} + return self.tools_by_name[call["name"]].invoke(input, config) + except Exception as e: + if not self.handle_tool_errors: + raise e + content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e)) + return ToolMessage(content, name=call["name"], tool_call_id=call["id"]) + + async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage: + if invalid_tool_message := self._validate_tool_call(call): + return invalid_tool_message + try: + input = {**call, **{"type": "tool_call"}} + return await self.tools_by_name[call["name"]].ainvoke(input, config) + except Exception as e: + if not self.handle_tool_errors: + raise e + content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e)) + return ToolMessage(content, name=call["name"], tool_call_id=call["id"]) + + def _parse_input( + self, input: Union[list[AnyMessage], dict[str, Any]] + ) -> Tuple[AIMessage, Literal["list", "dict"]]: if isinstance(input, list): output_type = "list" message: AnyMessage = input[-1] @@ -112,31 +114,18 @@ async def _afunc( if not isinstance(message, AIMessage): raise ValueError("Last message is not an AIMessage") - - async def run_one(call: ToolCall): - if (requested_tool := call["name"]) not in self.tools_by_name: - content = INVALID_TOOL_NAME_ERROR_TEMPLATE.format( - requested_tool=requested_tool, - available_tools=", ".join(self.tools_by_name.keys()), - ) - return ToolMessage( - content, name=requested_tool, tool_call_id=call["id"] - ) - - try: - input = {**call, **{"type": "tool_call"}} - return await self.tools_by_name[call["name"]].ainvoke(input, config) - except Exception as e: - if not self.handle_tool_errors: - raise e - content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e)) - return ToolMessage(content, name=call["name"], tool_call_id=call["id"]) - - outputs = await asyncio.gather(*(run_one(call) for call in message.tool_calls)) - if output_type == "list": - return outputs else: - return {"messages": outputs} + return cast(AIMessage, message), output_type + + def _validate_tool_call(self, call: ToolCall) -> Optional[ToolMessage]: + if (requested_tool := call["name"]) not in self.tools_by_name: + content = INVALID_TOOL_NAME_ERROR_TEMPLATE.format( + requested_tool=requested_tool, + available_tools=", ".join(self.tools_by_name.keys()), + ) + return ToolMessage(content, name=requested_tool, tool_call_id=call["id"]) + else: + return None def tools_condition(