From 6d3bd67f9c3546b1096e6250784c74165af9f97c Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Fri, 20 Sep 2024 10:10:09 -0700 Subject: [PATCH] Add return_direct support --- .../langgraph/prebuilt/chat_agent_executor.py | 45 ++++--- .../tests/__snapshots__/test_pregel.ambr | 8 +- libs/langgraph/tests/test_prebuilt.py | 112 +++++++++++++++++- 3 files changed, 133 insertions(+), 32 deletions(-) diff --git a/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py b/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py index dd2cddb5e..1e2209bd7 100644 --- a/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py +++ b/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py @@ -11,17 +11,13 @@ ) from langchain_core.language_models import BaseChatModel -from langchain_core.messages import ( - AIMessage, - BaseMessage, - SystemMessage, -) +from langchain_core.messages import AIMessage, BaseMessage, SystemMessage, ToolMessage from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda from langchain_core.tools import BaseTool from langgraph._api.deprecation import deprecated_parameter from langgraph.checkpoint.base import BaseCheckpointSaver -from langgraph.graph import END, StateGraph +from langgraph.graph import StateGraph from langgraph.graph.graph import CompiledGraph from langgraph.graph.message import add_messages from langgraph.managed import IsLastStep @@ -430,15 +426,15 @@ class Agent,Tools otherClass model = model.bind_tools(tool_classes) # Define the function that determines whether to continue or not - def should_continue(state: AgentState) -> Literal["continue", "end"]: + def should_continue(state: AgentState) -> Literal["tools", "__end__"]: messages = state["messages"] last_message = messages[-1] # If there is no function call, then we finish if not isinstance(last_message, AIMessage) or not last_message.tool_calls: - return "end" + return "__end__" # Otherwise if there is, we continue else: - return "continue" + return "tools" preprocessor = _get_model_preprocessing_runnable(state_modifier, messages_modifier) model_runnable = preprocessor | model @@ -498,23 +494,24 @@ async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState: "agent", # Next, we pass in the function that will determine which node is called next. should_continue, - # Finally we pass in a mapping. - # The keys are strings, and the values are other nodes. - # END is a special node marking that the graph should finish. - # What will happen is we will call `should_continue`, and then the output of that - # will be matched against the keys in this mapping. - # Based on which one it matches, that node will then be called. - { - # If `tools`, then we call the tool node. - "continue": "tools", - # Otherwise we finish. - "end": END, - }, ) - # We now add a normal edge from `tools` to `agent`. - # This means that after `tools` is called, `agent` node is called next. - workflow.add_edge("tools", "agent") + # If any of the tools are configured to return_directly after running, + # our graph needs to check if these were called + should_return_direct = {t.name for t in tool_classes if t.return_direct} + + def route_tool_responses(state: AgentState) -> Literal["agent", "__end__"]: + for m in reversed(state["messages"]): + if not isinstance(m, ToolMessage): + break + if m.name in should_return_direct: + return "__end__" + return "agent" + + if should_return_direct: + workflow.add_conditional_edges("tools", route_tool_responses) + else: + workflow.add_edge("tools", "agent") # Finally, we compile it! # This compiles it into a LangChain Runnable, diff --git a/libs/langgraph/tests/__snapshots__/test_pregel.ambr b/libs/langgraph/tests/__snapshots__/test_pregel.ambr index 9354409a7..7928f2628 100644 --- a/libs/langgraph/tests/__snapshots__/test_pregel.ambr +++ b/libs/langgraph/tests/__snapshots__/test_pregel.ambr @@ -4857,13 +4857,11 @@ { "source": "agent", "target": "tools", - "data": "continue", "conditional": true }, { "source": "agent", "target": "__end__", - "data": "end", "conditional": true } ] @@ -4875,8 +4873,8 @@ graph TD; __start__ --> agent; tools --> agent; - agent -.  continue  .-> tools; - agent -.  end  .-> __end__; + agent -.-> tools; + agent -.-> __end__; ''' # --- @@ -5017,7 +5015,7 @@ '{"title": "LangGraphOutput", "type": "object", "properties": {"input": {"title": "Input", "type": "string"}, "agent_outcome": {"title": "Agent Outcome", "anyOf": [{"$ref": "#/definitions/AgentAction"}, {"$ref": "#/definitions/AgentFinish"}]}, "intermediate_steps": {"title": "Intermediate Steps", "type": "array", "items": {"type": "array", "minItems": 2, "maxItems": 2, "items": [{"$ref": "#/definitions/AgentAction"}, {"type": "string"}]}}}, "definitions": {"AgentAction": {"title": "AgentAction", "description": "Represents a request to execute an action by an agent.\\n\\nThe action consists of the name of the tool to execute and the input to pass\\nto the tool. The log is used to pass along extra information about the action.", "type": "object", "properties": {"tool": {"title": "Tool", "type": "string"}, "tool_input": {"title": "Tool Input", "anyOf": [{"type": "string"}, {"type": "object"}]}, "log": {"title": "Log", "type": "string"}, "type": {"title": "Type", "default": "AgentAction", "enum": ["AgentAction"], "type": "string"}}, "required": ["tool", "tool_input", "log"]}, "AgentFinish": {"title": "AgentFinish", "description": "Final return value of an ActionAgent.\\n\\nAgents return an AgentFinish when they have reached a stopping condition.", "type": "object", "properties": {"return_values": {"title": "Return Values", "type": "object"}, "log": {"title": "Log", "type": "string"}, "type": {"title": "Type", "default": "AgentFinish", "enum": ["AgentFinish"], "type": "string"}}, "required": ["return_values", "log"]}}}' # --- # name: test_state_graph_w_config_inherited_state_keys - '{"$defs": {"Configurable": {"properties": {"tools": {"default": null, "items": {"type": "string"}, "title": "Tools", "type": "array"}}, "title": "Configurable", "type": "object"}}, "properties": {"configurable": {"allOf": [{"$ref": "#/$defs/Configurable"}], "default": null}}, "title": "LangGraphConfig", "type": "object"}' + '{"$defs": {"Configurable": {"properties": {"tools": {"default": null, "items": {"type": "string"}, "title": "Tools", "type": "array"}}, "title": "Configurable", "type": "object"}}, "properties": {"configurable": {"$ref": "#/$defs/Configurable", "default": null}}, "title": "LangGraphConfig", "type": "object"}' # --- # name: test_state_graph_w_config_inherited_state_keys.1 '{"$defs": {"AgentAction": {"description": "Represents a request to execute an action by an agent.\\n\\nThe action consists of the name of the tool to execute and the input to pass\\nto the tool. The log is used to pass along extra information about the action.", "properties": {"tool": {"title": "Tool", "type": "string"}, "tool_input": {"anyOf": [{"type": "string"}, {"type": "object"}], "title": "Tool Input"}, "log": {"title": "Log", "type": "string"}, "type": {"const": "AgentAction", "default": "AgentAction", "enum": ["AgentAction"], "title": "Type", "type": "string"}}, "required": ["tool", "tool_input", "log"], "title": "AgentAction", "type": "object"}, "AgentFinish": {"description": "Final return value of an ActionAgent.\\n\\nAgents return an AgentFinish when they have reached a stopping condition.", "properties": {"return_values": {"title": "Return Values", "type": "object"}, "log": {"title": "Log", "type": "string"}, "type": {"const": "AgentFinish", "default": "AgentFinish", "enum": ["AgentFinish"], "title": "Type", "type": "string"}}, "required": ["return_values", "log"], "title": "AgentFinish", "type": "object"}}, "properties": {"input": {"title": "Input", "type": "string"}, "agent_outcome": {"anyOf": [{"$ref": "#/$defs/AgentAction"}, {"$ref": "#/$defs/AgentFinish"}, {"type": "null"}], "default": null, "title": "Agent Outcome"}, "intermediate_steps": {"default": null, "items": {"maxItems": 2, "minItems": 2, "prefixItems": [{"$ref": "#/$defs/AgentAction"}, {"type": "string"}], "type": "array"}, "title": "Intermediate Steps", "type": "array"}}, "required": ["input"], "title": "LangGraphInput", "type": "object"}' diff --git a/libs/langgraph/tests/test_prebuilt.py b/libs/langgraph/tests/test_prebuilt.py index 42dca7ce9..4efaa35e4 100644 --- a/libs/langgraph/tests/test_prebuilt.py +++ b/libs/langgraph/tests/test_prebuilt.py @@ -47,6 +47,9 @@ class FakeToolCallingModel(BaseChatModel): + tool_calls: Optional[list[list[ToolCall]]] = None + index: int = 0 + def _generate( self, messages: List[BaseMessage], @@ -56,7 +59,15 @@ def _generate( ) -> ChatResult: """Top Level call""" messages_string = "-".join([m.content for m in messages]) - message = AIMessage(content=messages_string, id="0") + tool_calls = ( + self.tool_calls[self.index % len(self.tool_calls)] + if self.tool_calls + else [] + ) + message = AIMessage( + content=messages_string, id=str(self.index), tool_calls=tool_calls.copy() + ) + self.index += 1 return ChatResult(generations=[ChatGeneration(message=message)]) @property @@ -68,8 +79,6 @@ def bind_tools( tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: - if len(tools) > 0: - raise ValueError("Not supported yet!") return self @@ -584,3 +593,100 @@ def get_day_list(days: list[str]) -> list[str]: [AIMessage(content="", tool_calls=tool_calls)] ) assert outputs[0].content == json.dumps(data, ensure_ascii=False) + + +async def test_return_direct() -> None: + @dec_tool(return_direct=True) + def tool_return_direct(input: str) -> str: + """A tool that returns directly.""" + return f"Direct result: {input}" + + @dec_tool + def tool_normal(input: str) -> str: + """A normal tool.""" + return f"Normal result: {input}" + + first_tool_call = [ + ToolCall( + name="tool_return_direct", + args={"input": "Test direct"}, + id="1", + ), + ] + expected_ai = AIMessage( + content="Test direct", + id="0", + tool_calls=first_tool_call, + ) + model = FakeToolCallingModel(tool_calls=[first_tool_call, []]) + agent = create_react_agent(model, [tool_return_direct, tool_normal]) + + # Test direct return for tool_return_direct + result = agent.invoke( + {"messages": [HumanMessage(content="Test direct", id="hum0")]} + ) + assert result["messages"] == [ + HumanMessage(content="Test direct", id="hum0"), + expected_ai, + ToolMessage( + content="Direct result: Test direct", + name="tool_return_direct", + tool_call_id="1", + id=result["messages"][2].id, + ), + ] + second_tool_call = [ + ToolCall( + name="tool_normal", + args={"input": "Test normal"}, + id="2", + ), + ] + model = FakeToolCallingModel(tool_calls=[second_tool_call, []]) + agent = create_react_agent(model, [tool_return_direct, tool_normal]) + result = agent.invoke( + {"messages": [HumanMessage(content="Test normal", id="hum1")]} + ) + assert result["messages"] == [ + HumanMessage(content="Test normal", id="hum1"), + AIMessage(content="Test normal", id="0", tool_calls=second_tool_call), + ToolMessage( + content="Normal result: Test normal", + name="tool_normal", + tool_call_id="2", + id=result["messages"][2].id, + ), + AIMessage(content="Test normal-Test normal-Normal result: Test normal", id="1"), + ] + + both_tool_calls = [ + ToolCall( + name="tool_return_direct", + args={"input": "Test both direct"}, + id="3", + ), + ToolCall( + name="tool_normal", + args={"input": "Test both normal"}, + id="4", + ), + ] + model = FakeToolCallingModel(tool_calls=[both_tool_calls, []]) + agent = create_react_agent(model, [tool_return_direct, tool_normal]) + result = agent.invoke({"messages": [HumanMessage(content="Test both", id="hum2")]}) + assert result["messages"] == [ + HumanMessage(content="Test both", id="hum2"), + AIMessage(content="Test both", id="0", tool_calls=both_tool_calls), + ToolMessage( + content="Direct result: Test both direct", + name="tool_return_direct", + tool_call_id="3", + id=result["messages"][2].id, + ), + ToolMessage( + content="Normal result: Test both normal", + name="tool_normal", + tool_call_id="4", + id=result["messages"][3].id, + ), + ]