Add return_direct support (#1780)
hinthornw authored Sep 20, 2024
1 parent 7b94f1e commit 10a66ac
Showing 3 changed files with 146 additions and 39 deletions.
45 changes: 21 additions & 24 deletions libs/langgraph/langgraph/prebuilt/chat_agent_executor.py
@@ -11,17 +11,13 @@
)

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
SystemMessage,
)
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage, ToolMessage
from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool

from langgraph._api.deprecation import deprecated_parameter
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph import END, StateGraph
from langgraph.graph import StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.message import add_messages
from langgraph.managed import IsLastStep
@@ -430,15 +426,15 @@ class Agent,Tools otherClass
model = model.bind_tools(tool_classes)

# Define the function that determines whether to continue or not
def should_continue(state: AgentState) -> Literal["continue", "end"]:
def should_continue(state: AgentState) -> Literal["tools", "__end__"]:
messages = state["messages"]
last_message = messages[-1]
# If there is no function call, then we finish
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
return "end"
return "__end__"
# Otherwise if there is, we continue
else:
return "continue"
return "tools"

preprocessor = _get_model_preprocessing_runnable(state_modifier, messages_modifier)
model_runnable = preprocessor | model
@@ -498,23 +494,24 @@ async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
"agent",
# Next, we pass in the function that will determine which node is called next.
should_continue,
# Finally we pass in a mapping.
# The keys are strings, and the values are other nodes.
# END is a special node marking that the graph should finish.
# What will happen is we will call `should_continue`, and then the output of that
# will be matched against the keys in this mapping.
# Based on which one it matches, that node will then be called.
{
# If `tools`, then we call the tool node.
"continue": "tools",
# Otherwise we finish.
"end": END,
},
)

# We now add a normal edge from `tools` to `agent`.
# This means that after `tools` is called, `agent` node is called next.
workflow.add_edge("tools", "agent")
    # If any of the tools are configured to return directly after running,
# our graph needs to check if these were called
should_return_direct = {t.name for t in tool_classes if t.return_direct}

def route_tool_responses(state: AgentState) -> Literal["agent", "__end__"]:
for m in reversed(state["messages"]):
if not isinstance(m, ToolMessage):
break
if m.name in should_return_direct:
return "__end__"
return "agent"

if should_return_direct:
workflow.add_conditional_edges("tools", route_tool_responses)
else:
workflow.add_edge("tools", "agent")

# Finally, we compile it!
# This compiles it into a LangChain Runnable,
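
For orientation, here is a minimal usage sketch of the behavior this hunk enables. The tool name `lookup_order`, the placeholder model, and the commented-out invocation are illustrative assumptions, not part of this commit; any chat model that supports tool calling would work.

from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent

@tool(return_direct=True)
def lookup_order(order_id: str) -> str:
    """Look up an order by id."""
    return f"Order {order_id}: shipped"

# `model` is a placeholder for any tool-calling chat model; the invocation is
# left commented out so the sketch has no credential requirements.
# agent = create_react_agent(model, [lookup_order])
# result = agent.invoke({"messages": [("user", "Where is order 42?")]})
#
# Because `lookup_order` declares return_direct=True, route_tool_responses
# sends the graph from `tools` straight to `__end__`, so the final entry in
# result["messages"] is the ToolMessage itself; the model is not called again.
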
6 changes: 2 additions & 4 deletions libs/langgraph/tests/__snapshots__/test_pregel.ambr
@@ -4857,13 +4857,11 @@
{
"source": "agent",
"target": "tools",
"data": "continue",
"conditional": true
},
{
"source": "agent",
"target": "__end__",
"data": "end",
"conditional": true
}
]
@@ -4875,8 +4873,8 @@
graph TD;
__start__ --> agent;
tools --> agent;
agent -.  continue  .-> tools;
agent -.  end  .-> __end__;
agent -.-> tools;
agent -.-> __end__;

'''
# ---
134 changes: 123 additions & 11 deletions libs/langgraph/tests/test_prebuilt.py
@@ -47,6 +47,9 @@


class FakeToolCallingModel(BaseChatModel):
tool_calls: Optional[list[list[ToolCall]]] = None
index: int = 0

def _generate(
self,
messages: List[BaseMessage],
@@ -56,7 +59,15 @@ def _generate(
) -> ChatResult:
"""Top Level call"""
messages_string = "-".join([m.content for m in messages])
message = AIMessage(content=messages_string, id="0")
tool_calls = (
self.tool_calls[self.index % len(self.tool_calls)]
if self.tool_calls
else []
)
message = AIMessage(
content=messages_string, id=str(self.index), tool_calls=tool_calls.copy()
)
self.index += 1
return ChatResult(generations=[ChatGeneration(message=message)])

@property
@@ -68,8 +79,6 @@ def bind_tools(
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
if len(tools) > 0:
raise ValueError("Not supported yet!")
return self


@@ -144,29 +153,35 @@ def test_passing_two_modifiers():


def test_system_message_modifier():
model = FakeToolCallingModel()
messages_modifier = SystemMessage(content="Foo")
agent_1 = create_react_agent(model, [], messages_modifier=messages_modifier)
agent_2 = create_react_agent(model, [], state_modifier=messages_modifier)
agent_1 = create_react_agent(
FakeToolCallingModel(), [], messages_modifier=messages_modifier
)
agent_2 = create_react_agent(
FakeToolCallingModel(), [], state_modifier=messages_modifier
)
for agent in [agent_1, agent_2]:
inputs = [HumanMessage("hi?")]
response = agent.invoke({"messages": inputs})
expected_response = {
"messages": inputs + [AIMessage(content="Foo-hi?", id="0")]
"messages": inputs + [AIMessage(content="Foo-hi?", id="0", tool_calls=[])]
}
assert response == expected_response


def test_system_message_string_modifier():
model = FakeToolCallingModel()
messages_modifier = "Foo"
agent_1 = create_react_agent(model, [], messages_modifier=messages_modifier)
agent_2 = create_react_agent(model, [], state_modifier=messages_modifier)
agent_1 = create_react_agent(
FakeToolCallingModel(), [], messages_modifier=messages_modifier
)
agent_2 = create_react_agent(
FakeToolCallingModel(), [], state_modifier=messages_modifier
)
for agent in [agent_1, agent_2]:
inputs = [HumanMessage("hi?")]
response = agent.invoke({"messages": inputs})
expected_response = {
"messages": inputs + [AIMessage(content="Foo-hi?", id="0")]
"messages": inputs + [AIMessage(content="Foo-hi?", id="0", tool_calls=[])]
}
assert response == expected_response

@@ -584,3 +599,100 @@ def get_day_list(days: list[str]) -> list[str]:
[AIMessage(content="", tool_calls=tool_calls)]
)
assert outputs[0].content == json.dumps(data, ensure_ascii=False)


async def test_return_direct() -> None:
@dec_tool(return_direct=True)
def tool_return_direct(input: str) -> str:
"""A tool that returns directly."""
return f"Direct result: {input}"

@dec_tool
def tool_normal(input: str) -> str:
"""A normal tool."""
return f"Normal result: {input}"

first_tool_call = [
ToolCall(
name="tool_return_direct",
args={"input": "Test direct"},
id="1",
),
]
expected_ai = AIMessage(
content="Test direct",
id="0",
tool_calls=first_tool_call,
)
model = FakeToolCallingModel(tool_calls=[first_tool_call, []])
agent = create_react_agent(model, [tool_return_direct, tool_normal])

# Test direct return for tool_return_direct
result = agent.invoke(
{"messages": [HumanMessage(content="Test direct", id="hum0")]}
)
assert result["messages"] == [
HumanMessage(content="Test direct", id="hum0"),
expected_ai,
ToolMessage(
content="Direct result: Test direct",
name="tool_return_direct",
tool_call_id="1",
id=result["messages"][2].id,
),
]
second_tool_call = [
ToolCall(
name="tool_normal",
args={"input": "Test normal"},
id="2",
),
]
model = FakeToolCallingModel(tool_calls=[second_tool_call, []])
agent = create_react_agent(model, [tool_return_direct, tool_normal])
result = agent.invoke(
{"messages": [HumanMessage(content="Test normal", id="hum1")]}
)
assert result["messages"] == [
HumanMessage(content="Test normal", id="hum1"),
AIMessage(content="Test normal", id="0", tool_calls=second_tool_call),
ToolMessage(
content="Normal result: Test normal",
name="tool_normal",
tool_call_id="2",
id=result["messages"][2].id,
),
AIMessage(content="Test normal-Test normal-Normal result: Test normal", id="1"),
]

both_tool_calls = [
ToolCall(
name="tool_return_direct",
args={"input": "Test both direct"},
id="3",
),
ToolCall(
name="tool_normal",
args={"input": "Test both normal"},
id="4",
),
]
model = FakeToolCallingModel(tool_calls=[both_tool_calls, []])
agent = create_react_agent(model, [tool_return_direct, tool_normal])
result = agent.invoke({"messages": [HumanMessage(content="Test both", id="hum2")]})
assert result["messages"] == [
HumanMessage(content="Test both", id="hum2"),
AIMessage(content="Test both", id="0", tool_calls=both_tool_calls),
ToolMessage(
content="Direct result: Test both direct",
name="tool_return_direct",
tool_call_id="3",
id=result["messages"][2].id,
),
ToolMessage(
content="Normal result: Test both normal",
name="tool_normal",
tool_call_id="4",
id=result["messages"][3].id,
),
]
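
As a side note, the routing rule that the mixed-call assertion above exercises can be restated in a few lines of plain Python. The `Msg` class and `ends_run` helper below are stand-ins invented for illustration only; they are not part of the library.

from dataclasses import dataclass

@dataclass
class Msg:
    """Stand-in for a message; `is_tool` marks ToolMessages."""
    name: str = ""
    is_tool: bool = False

def ends_run(messages: list, return_direct_names: set) -> bool:
    # Mirror of route_tool_responses: walk the trailing ToolMessages and end
    # the run if any of them came from a tool declared with return_direct=True.
    for m in reversed(messages):
        if not m.is_tool:
            break
        if m.name in return_direct_names:
            return True
    return False

# The "both tools" case from the test: one direct-return tool and one normal
# tool called in the same turn still end the run after the tools node.
history = [
    Msg(name="ai"),                                # AIMessage with two tool calls
    Msg(name="tool_return_direct", is_tool=True),  # ToolMessage from the direct tool
    Msg(name="tool_normal", is_tool=True),         # ToolMessage from the normal tool
]
assert ends_run(history, {"tool_return_direct"})
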
