Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rfc] langgraph: check if model passed as runnable binding with tools in create_react_agent #1647

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions libs/langgraph/langgraph/prebuilt/chat_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
BaseMessage,
SystemMessage,
)
from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda
from langchain_core.runnables import (
Runnable,
RunnableBinding,
RunnableConfig,
RunnableLambda,
)
from langchain_core.tools import BaseTool

from langgraph._api.deprecation import deprecated_parameter
Expand Down Expand Up @@ -127,6 +132,40 @@ def _get_model_preprocessing_runnable(
return _get_state_modifier_runnable(state_modifier)


def _should_bind_tools(model: LanguageModelLike, tools: Sequence[BaseTool]) -> bool:
if not isinstance(model, RunnableBinding):
return False

if "tools" not in model.kwargs:
return False

bound_tools = model.kwargs["tools"]
if len(tools) != len(bound_tools):
raise ValueError(
"Number of tools in the model.bind_tools() and tools passed to create_react_agent must match"
)

tool_names = set(tool.name for tool in tools)
bound_tool_names = set()
for bound_tool in bound_tools:
# OpenAI-style tool
if bound_tool.get("type") == "function":
bound_tool_name = bound_tool["function"]["name"]
# Anthropic-style tool
elif bound_tool.get("name"):
bound_tool_name = bound_tool["name"]
else:
# unknown tool type so we'll ignore it
continue

bound_tool_names.add(bound_tool_name)

if missing_tools := tool_names - bound_tool_names:
raise ValueError(f"Missing tools '{missing_tools}' in the model.bind_tools()")

return True


@deprecated_parameter("messages_modifier", "0.1.9", "state_modifier", removal="0.3.0")
def create_react_agent(
model: LanguageModelLike,
Expand Down Expand Up @@ -426,7 +465,9 @@ class Agent,Tools otherClass
else:
tool_classes = tools
tool_node = ToolNode(tool_classes)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to this but thoughts on getting the tool_classes from the tool node here too so we can let you pass in raw functions ?

        tool_classes = tools.tools_by_name.values()

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think in tools_by_name it will also be a tool as we're calling create_tool under the hood https://github.com/langchain-ai/langgraph/blob/main/libs/langgraph/langgraph/prebuilt/tool_node.py#L95

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh nevermind, i misread first, yes, we can definitely do that!

model = model.bind_tools(tool_classes)

if _should_bind_tools(model, tool_classes):
model = model.bind_tools(tool_classes)

# Define the function that determines whether to continue or not
def should_continue(state: AgentState):
Expand Down
Loading