Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

langgraph[patch]: refactor ToolNode #1066

Merged
merged 1 commit into from
Jul 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 51 additions & 62 deletions libs/langgraph/langgraph/prebuilt/tool_node.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio
from typing import Any, Callable, Dict, Literal, Optional, Sequence, Union
from typing import Any, Callable, Dict, Literal, Optional, Sequence, Tuple, Union, cast

from langchain_core.messages import AIMessage, AnyMessage, ToolCall, ToolMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import get_executor_for_config
from langchain_core.runnables.config import get_config_list, get_executor_for_config
from langchain_core.tools import BaseTool
from langchain_core.tools import tool as create_tool

Expand Down Expand Up @@ -60,47 +60,49 @@ def __init__(
def _func(
self, input: Union[list[AnyMessage], dict[str, Any]], config: RunnableConfig
) -> Any:
if isinstance(input, list):
output_type = "list"
message: AnyMessage = input[-1]
elif messages := input.get("messages", []):
output_type = "dict"
message = messages[-1]
else:
raise ValueError("No message found in input")

if not isinstance(message, AIMessage):
raise ValueError("Last message is not an AIMessage")

def run_one(call: ToolCall):
if (requested_tool := call["name"]) not in self.tools_by_name:
content = INVALID_TOOL_NAME_ERROR_TEMPLATE.format(
requested_tool=requested_tool,
available_tools=", ".join(self.tools_by_name.keys()),
)
return ToolMessage(
content, name=requested_tool, tool_call_id=call["id"]
)

try:
input = {**call, **{"type": "tool_call"}}
return self.tools_by_name[call["name"]].invoke(input, config)
except Exception as e:
if not self.handle_tool_errors:
raise e
content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))
return ToolMessage(content, name=call["name"], tool_call_id=call["id"])

message, output_type = self._parse_input(input)
config_list = get_config_list(config, len(message.tool_calls))
with get_executor_for_config(config) as executor:
outputs = [*executor.map(run_one, message.tool_calls)]
if output_type == "list":
return outputs
else:
return {"messages": outputs}
outputs = [*executor.map(self._run_one, message.tool_calls, config_list)]
return outputs if output_type == "list" else {"messages": outputs}

async def _afunc(
self, input: Union[list[AnyMessage], dict[str, Any]], config: RunnableConfig
) -> Any:
message, output_type = self._parse_input(input)
outputs = await asyncio.gather(
*(self._arun_one(call, config) for call in message.tool_calls)
)
return outputs if output_type == "list" else {"messages": outputs}

def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage:
if invalid_tool_message := self._validate_tool_call(call):
return invalid_tool_message

try:
input = {**call, **{"type": "tool_call"}}
return self.tools_by_name[call["name"]].invoke(input, config)
except Exception as e:
if not self.handle_tool_errors:
raise e
content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))
return ToolMessage(content, name=call["name"], tool_call_id=call["id"])

async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage:
if invalid_tool_message := self._validate_tool_call(call):
return invalid_tool_message
try:
input = {**call, **{"type": "tool_call"}}
return await self.tools_by_name[call["name"]].ainvoke(input, config)
except Exception as e:
if not self.handle_tool_errors:
raise e
content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))
return ToolMessage(content, name=call["name"], tool_call_id=call["id"])

def _parse_input(
self, input: Union[list[AnyMessage], dict[str, Any]]
) -> Tuple[AIMessage, Literal["list", "dict"]]:
if isinstance(input, list):
output_type = "list"
message: AnyMessage = input[-1]
Expand All @@ -112,31 +114,18 @@ async def _afunc(

if not isinstance(message, AIMessage):
raise ValueError("Last message is not an AIMessage")

async def run_one(call: ToolCall):
if (requested_tool := call["name"]) not in self.tools_by_name:
content = INVALID_TOOL_NAME_ERROR_TEMPLATE.format(
requested_tool=requested_tool,
available_tools=", ".join(self.tools_by_name.keys()),
)
return ToolMessage(
content, name=requested_tool, tool_call_id=call["id"]
)

try:
input = {**call, **{"type": "tool_call"}}
return await self.tools_by_name[call["name"]].ainvoke(input, config)
except Exception as e:
if not self.handle_tool_errors:
raise e
content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))
return ToolMessage(content, name=call["name"], tool_call_id=call["id"])

outputs = await asyncio.gather(*(run_one(call) for call in message.tool_calls))
if output_type == "list":
return outputs
else:
return {"messages": outputs}
return cast(AIMessage, message), output_type

def _validate_tool_call(self, call: ToolCall) -> Optional[ToolMessage]:
if (requested_tool := call["name"]) not in self.tools_by_name:
content = INVALID_TOOL_NAME_ERROR_TEMPLATE.format(
requested_tool=requested_tool,
available_tools=", ".join(self.tools_by_name.keys()),
)
return ToolMessage(content, name=requested_tool, tool_call_id=call["id"])
else:
return None


def tools_condition(
Expand Down
Loading