Add return_direct support #1780

Merged · 2 commits · Sep 20, 2024
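This PR teaches the prebuilt `create_react_agent` to honor a tool's `return_direct` flag: when the model calls a tool declared with `return_direct=True`, the graph now ends the run on that tool's output instead of looping back to the model. Here is a minimal sketch of the resulting behavior; the weather tool and model choice are illustrative, not part of this PR:

```python
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI  # any tool-calling chat model works here
from langgraph.prebuilt import create_react_agent

@tool(return_direct=True)
def get_weather(city: str) -> str:
    """Look up the weather for a city."""
    return f"It is always sunny in {city}"

agent = create_react_agent(ChatOpenAI(model="gpt-4o-mini"), [get_weather])

# If the model calls get_weather, the run ends on that ToolMessage;
# the model is not invoked a second time to rephrase the tool output.
result = agent.invoke({"messages": [("user", "What's the weather in SF?")]})
print(type(result["messages"][-1]).__name__)  # ToolMessage, not AIMessage
```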
45 changes: 21 additions & 24 deletions libs/langgraph/langgraph/prebuilt/chat_agent_executor.py
@@ -1,4 +1,4 @@
GitHub Actions / benchmark (check notice on line 1): Benchmark results

fanout_to_subgraph_10x: Mean +- std dev: 59.8 ms +- 1.2 ms
fanout_to_subgraph_10x_sync: Mean +- std dev: 57.8 ms +- 7.4 ms
    WARNING: this result may be unstable; the standard deviation (7.35 ms) is 13% of the mean (57.8 ms). Try rerunning the benchmark with more runs, values and/or loops; run 'python -m pyperf system tune' to reduce system jitter; use pyperf stats, pyperf dump and pyperf hist to analyze results; use --quiet to hide these warnings.
fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 77.8 ms +- 1.4 ms
fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 82.0 ms +- 0.6 ms
fanout_to_subgraph_100x: Mean +- std dev: 555 ms +- 9 ms
fanout_to_subgraph_100x_sync: Mean +- std dev: 505 ms +- 4 ms
fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 768 ms +- 25 ms
fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 792 ms +- 6 ms
react_agent_10x: Mean +- std dev: 39.3 ms +- 0.8 ms
react_agent_10x_sync: Mean +- std dev: 29.7 ms +- 0.2 ms
react_agent_10x_checkpoint: Mean +- std dev: 53.0 ms +- 1.3 ms
react_agent_10x_checkpoint_sync: Mean +- std dev: 43.2 ms +- 3.4 ms
react_agent_100x: Mean +- std dev: 415 ms +- 7 ms
react_agent_100x_sync: Mean +- std dev: 344 ms +- 13 ms
react_agent_100x_checkpoint: Mean +- std dev: 933 ms +- 9 ms
react_agent_100x_checkpoint_sync: Mean +- std dev: 832 ms +- 11 ms
wide_state_25x300: Mean +- std dev: 20.7 ms +- 0.4 ms
wide_state_25x300_sync: Mean +- std dev: 12.8 ms +- 0.1 ms
wide_state_25x300_checkpoint: Mean +- std dev: 241 ms +- 9 ms
wide_state_25x300_checkpoint_sync: Mean +- std dev: 230 ms +- 7 ms
wide_state_15x600: Mean +- std dev: 24.0 ms +- 0.4 ms
wide_state_15x600_sync: Mean +- std dev: 14.7 ms +- 0.1 ms
wide_state_15x600_checkpoint: Mean +- std dev: 418 ms +- 11 ms
wide_state_15x600_checkpoint_sync: Mean +- std dev: 412 ms +- 13 ms
wide_state_9x1200: Mean +- std dev: 23.9 ms +- 0.4 ms
wide_state_9x1200_sync: Mean +- std dev: 14.8 ms +- 0.2 ms
wide_state_9x1200_checkpoint: Mean +- std dev: 268 ms +- 8 ms
wide_state_9x1200_checkpoint_sync: Mean +- std dev: 260 ms +- 8 ms

GitHub Actions / benchmark (check notice on line 1): Comparison against main

| Benchmark | main | changes |
|---|---|---|
| react_agent_10x_checkpoint | 56.6 ms | 53.0 ms: 1.07x faster |
| react_agent_100x_checkpoint_sync | 881 ms | 832 ms: 1.06x faster |
| fanout_to_subgraph_100x_checkpoint | 802 ms | 768 ms: 1.04x faster |
| react_agent_10x_checkpoint_sync | 44.7 ms | 43.2 ms: 1.04x faster |
| fanout_to_subgraph_100x | 572 ms | 555 ms: 1.03x faster |
| react_agent_10x_sync | 30.6 ms | 29.7 ms: 1.03x faster |
| wide_state_25x300 | 21.2 ms | 20.7 ms: 1.03x faster |
| react_agent_100x_checkpoint | 959 ms | 933 ms: 1.03x faster |
| react_agent_100x | 425 ms | 415 ms: 1.02x faster |
| wide_state_25x300_sync | 13.1 ms | 12.8 ms: 1.02x faster |
| react_agent_10x | 39.9 ms | 39.3 ms: 1.02x faster |
| react_agent_100x_sync | 350 ms | 344 ms: 1.02x faster |
| wide_state_15x600_sync | 14.9 ms | 14.7 ms: 1.01x faster |
| wide_state_9x1200_sync | 14.9 ms | 14.8 ms: 1.01x faster |
| wide_state_9x1200 | 24.0 ms | 23.9 ms: 1.01x faster |
| wide_state_15x600_checkpoint_sync | 408 ms | 412 ms: 1.01x slower |
| Geometric mean | (ref) | 1.02x faster |

Benchmarks hidden because not significant (12): fanout_to_subgraph_10x_sync, wide_state_9x1200_checkpoint, wide_state_15x600, wide_state_25x300_checkpoint, fanout_to_subgraph_100x_checkpoint_sync, wide_state_15x600_checkpoint, fanout_to_subgraph_10x_checkpoint, fanout_to_subgraph_10x, fanout_to_subgraph_100x_sync, fanout_to_subgraph_10x_checkpoint_sync, wide_state_9x1200_checkpoint_sync, wide_state_25x300_checkpoint_sync

 from typing import (
     Annotated,
     Callable,
     Literal,
@@ -11,17 +11,13 @@
 )

from langchain_core.language_models import BaseChatModel
-from langchain_core.messages import (
-    AIMessage,
-    BaseMessage,
-    SystemMessage,
-)
+from langchain_core.messages import AIMessage, BaseMessage, SystemMessage, ToolMessage
from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool

from langgraph._api.deprecation import deprecated_parameter
from langgraph.checkpoint.base import BaseCheckpointSaver
-from langgraph.graph import END, StateGraph
+from langgraph.graph import StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.message import add_messages
from langgraph.managed import IsLastStep
@@ -430,15 +426,15 @@
     model = model.bind_tools(tool_classes)

     # Define the function that determines whether to continue or not
-    def should_continue(state: AgentState) -> Literal["continue", "end"]:
+    def should_continue(state: AgentState) -> Literal["tools", "__end__"]:
         messages = state["messages"]
         last_message = messages[-1]
         # If there is no function call, then we finish
         if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
-            return "end"
+            return "__end__"
         # Otherwise if there is, we continue
         else:
-            return "continue"
+            return "tools"

     preprocessor = _get_model_preprocessing_runnable(state_modifier, messages_modifier)
     model_runnable = preprocessor | model
@@ -498,23 +494,24 @@
"agent",
# Next, we pass in the function that will determine which node is called next.
should_continue,
# Finally we pass in a mapping.
# The keys are strings, and the values are other nodes.
# END is a special node marking that the graph should finish.
# What will happen is we will call `should_continue`, and then the output of that
# will be matched against the keys in this mapping.
# Based on which one it matches, that node will then be called.
{
# If `tools`, then we call the tool node.
"continue": "tools",
# Otherwise we finish.
"end": END,
},
)

-    # We now add a normal edge from `tools` to `agent`.
-    # This means that after `tools` is called, `agent` node is called next.
-    workflow.add_edge("tools", "agent")
+    # If any of the tools are configured to return directly after running,
+    # our graph needs to check if these were called
+    should_return_direct = {t.name for t in tool_classes if t.return_direct}
+
+    def route_tool_responses(state: AgentState) -> Literal["agent", "__end__"]:
+        for m in reversed(state["messages"]):
+            if not isinstance(m, ToolMessage):
+                break
+            if m.name in should_return_direct:
+                return "__end__"
+        return "agent"
+
+    if should_return_direct:
+        workflow.add_conditional_edges("tools", route_tool_responses)
+    else:
+        workflow.add_edge("tools", "agent")

     # Finally, we compile it!
     # This compiles it into a LangChain Runnable,
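Two changes in this file are worth spelling out. First, `should_continue` now returns node names ("tools" / "__end__") directly, so `add_conditional_edges` no longer needs the `{"continue": "tools", "end": END}` mapping; that is why the `END` import goes away and why the snapshot edges below lose their "continue"/"end" labels. Second, the fixed `tools -> agent` edge is only added when no tool is `return_direct`; otherwise `route_tool_responses` walks the trailing `ToolMessage`s from the last tools step and ends the run if any came from a `return_direct` tool. A standalone sketch of that router logic, using a hypothetical message list:

```python
from langchain_core.messages import AIMessage, ToolMessage

# Names of tools declared with return_direct=True (hypothetical set).
should_return_direct = {"tool_return_direct"}

def route_tool_responses(messages: list) -> str:
    # Walk backwards over the ToolMessages produced by the last tools step,
    # stopping at the first non-ToolMessage (the AIMessage that requested them).
    for m in reversed(messages):
        if not isinstance(m, ToolMessage):
            break
        if m.name in should_return_direct:
            return "__end__"
    return "agent"

# One tools step produced two ToolMessages; the return_direct one ends the run.
msgs = [
    AIMessage(content="", tool_calls=[]),
    ToolMessage(content="normal", name="tool_normal", tool_call_id="1"),
    ToolMessage(content="direct", name="tool_return_direct", tool_call_id="2"),
]
assert route_tool_responses(msgs) == "__end__"
```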
6 changes: 2 additions & 4 deletions libs/langgraph/tests/__snapshots__/test_pregel.ambr
@@ -4857,13 +4857,11 @@
       {
         "source": "agent",
         "target": "tools",
-        "data": "continue",
         "conditional": true
       },
       {
         "source": "agent",
         "target": "__end__",
-        "data": "end",
         "conditional": true
       }
     ]
@@ -4875,8 +4873,8 @@
graph TD;
__start__ --> agent;
tools --> agent;
-    agent -.  continue  .-> tools;
-    agent -.  end  .-> __end__;
+    agent -.-> tools;
+    agent -.-> __end__;

'''
# ---
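For reference, these snapshots come from rendering the compiled graph; once the router returns node names directly, there is no mapping key left to print on the dotted edges. A quick way to regenerate the drawing, assuming `agent` is a compiled `create_react_agent` graph:

```python
# Prints the mermaid source, e.g. "agent -.-> tools;" without edge labels
# now that no path map is passed to add_conditional_edges.
print(agent.get_graph().draw_mermaid())
```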
134 changes: 123 additions & 11 deletions libs/langgraph/tests/test_prebuilt.py
@@ -47,6 +47,9 @@


class FakeToolCallingModel(BaseChatModel):
+    tool_calls: Optional[list[list[ToolCall]]] = None
+    index: int = 0
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -56,7 +59,15 @@ def _generate(
     ) -> ChatResult:
         """Top Level call"""
         messages_string = "-".join([m.content for m in messages])
-        message = AIMessage(content=messages_string, id="0")
+        tool_calls = (
+            self.tool_calls[self.index % len(self.tool_calls)]
+            if self.tool_calls
+            else []
+        )
+        message = AIMessage(
+            content=messages_string, id=str(self.index), tool_calls=tool_calls.copy()
+        )
+        self.index += 1
         return ChatResult(generations=[ChatGeneration(message=message)])

@property
@@ -68,8 +79,6 @@ def bind_tools(
         tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
-        if len(tools) > 0:
-            raise ValueError("Not supported yet!")
         return self
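The net effect of these test-harness changes: the fake model can now script tool calls, emitting the next entry of `tool_calls` on each `_generate` call (cycling via `index % len`), and `bind_tools` becomes a no-op instead of raising. A small sketch of scripting a two-turn run; the tool name mirrors the tests below:

```python
from langchain_core.messages import ToolCall

scripted = [
    # Turn 1: the scripted "model" requests a single tool call.
    [ToolCall(name="tool_normal", args={"input": "hi"}, id="1")],
    # Turn 2: no tool calls, so the ReAct loop stops here.
    [],
]
model = FakeToolCallingModel(tool_calls=scripted)
```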


@@ -144,29 +153,35 @@ def test_passing_two_modifiers():


 def test_system_message_modifier():
-    model = FakeToolCallingModel()
     messages_modifier = SystemMessage(content="Foo")
-    agent_1 = create_react_agent(model, [], messages_modifier=messages_modifier)
-    agent_2 = create_react_agent(model, [], state_modifier=messages_modifier)
+    agent_1 = create_react_agent(
+        FakeToolCallingModel(), [], messages_modifier=messages_modifier
+    )
+    agent_2 = create_react_agent(
+        FakeToolCallingModel(), [], state_modifier=messages_modifier
+    )
     for agent in [agent_1, agent_2]:
         inputs = [HumanMessage("hi?")]
         response = agent.invoke({"messages": inputs})
         expected_response = {
-            "messages": inputs + [AIMessage(content="Foo-hi?", id="0")]
+            "messages": inputs + [AIMessage(content="Foo-hi?", id="0", tool_calls=[])]
         }
         assert response == expected_response


 def test_system_message_string_modifier():
-    model = FakeToolCallingModel()
     messages_modifier = "Foo"
-    agent_1 = create_react_agent(model, [], messages_modifier=messages_modifier)
-    agent_2 = create_react_agent(model, [], state_modifier=messages_modifier)
+    agent_1 = create_react_agent(
+        FakeToolCallingModel(), [], messages_modifier=messages_modifier
+    )
+    agent_2 = create_react_agent(
+        FakeToolCallingModel(), [], state_modifier=messages_modifier
+    )
     for agent in [agent_1, agent_2]:
         inputs = [HumanMessage("hi?")]
         response = agent.invoke({"messages": inputs})
         expected_response = {
-            "messages": inputs + [AIMessage(content="Foo-hi?", id="0")]
+            "messages": inputs + [AIMessage(content="Foo-hi?", id="0", tool_calls=[])]
         }
         assert response == expected_response

@@ -584,3 +599,100 @@ def get_day_list(days: list[str]) -> list[str]:
         [AIMessage(content="", tool_calls=tool_calls)]
     )
     assert outputs[0].content == json.dumps(data, ensure_ascii=False)
+
+
+async def test_return_direct() -> None:
+    @dec_tool(return_direct=True)
+    def tool_return_direct(input: str) -> str:
+        """A tool that returns directly."""
+        return f"Direct result: {input}"
+
+    @dec_tool
+    def tool_normal(input: str) -> str:
+        """A normal tool."""
+        return f"Normal result: {input}"
+
+    first_tool_call = [
+        ToolCall(
+            name="tool_return_direct",
+            args={"input": "Test direct"},
+            id="1",
+        ),
+    ]
+    expected_ai = AIMessage(
+        content="Test direct",
+        id="0",
+        tool_calls=first_tool_call,
+    )
+    model = FakeToolCallingModel(tool_calls=[first_tool_call, []])
+    agent = create_react_agent(model, [tool_return_direct, tool_normal])
+
+    # Test direct return for tool_return_direct
+    result = agent.invoke(
+        {"messages": [HumanMessage(content="Test direct", id="hum0")]}
+    )
+    assert result["messages"] == [
+        HumanMessage(content="Test direct", id="hum0"),
+        expected_ai,
+        ToolMessage(
+            content="Direct result: Test direct",
+            name="tool_return_direct",
+            tool_call_id="1",
+            id=result["messages"][2].id,
+        ),
+    ]
+    second_tool_call = [
+        ToolCall(
+            name="tool_normal",
+            args={"input": "Test normal"},
+            id="2",
+        ),
+    ]
+    model = FakeToolCallingModel(tool_calls=[second_tool_call, []])
+    agent = create_react_agent(model, [tool_return_direct, tool_normal])
+    result = agent.invoke(
+        {"messages": [HumanMessage(content="Test normal", id="hum1")]}
+    )
+    assert result["messages"] == [
+        HumanMessage(content="Test normal", id="hum1"),
+        AIMessage(content="Test normal", id="0", tool_calls=second_tool_call),
+        ToolMessage(
+            content="Normal result: Test normal",
+            name="tool_normal",
+            tool_call_id="2",
+            id=result["messages"][2].id,
+        ),
+        AIMessage(content="Test normal-Test normal-Normal result: Test normal", id="1"),
+    ]
+
+    both_tool_calls = [
+        ToolCall(
+            name="tool_return_direct",
+            args={"input": "Test both direct"},
+            id="3",
+        ),
+        ToolCall(
+            name="tool_normal",
+            args={"input": "Test both normal"},
+            id="4",
+        ),
+    ]
+    model = FakeToolCallingModel(tool_calls=[both_tool_calls, []])
+    agent = create_react_agent(model, [tool_return_direct, tool_normal])
+    result = agent.invoke({"messages": [HumanMessage(content="Test both", id="hum2")]})
+    assert result["messages"] == [
+        HumanMessage(content="Test both", id="hum2"),
+        AIMessage(content="Test both", id="0", tool_calls=both_tool_calls),
+        ToolMessage(
+            content="Direct result: Test both direct",
+            name="tool_return_direct",
+            tool_call_id="3",
+            id=result["messages"][2].id,
+        ),
+        ToolMessage(
+            content="Normal result: Test both normal",
+            name="tool_normal",
+            tool_call_id="4",
+            id=result["messages"][3].id,
+        ),
+    ]
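Taken together, the three scenarios pin down the contract: a `return_direct` call ends the run on its `ToolMessage`; a normal call loops back to the model, which appends a final `AIMessage`; and when a single model turn requests both kinds of tool, every requested tool still executes, but the presence of any `return_direct` result ends the run. A quick, illustrative way to check which path a run took (not part of the PR):

```python
from langchain_core.messages import ToolMessage

def ended_directly(result: dict) -> bool:
    """True if the agent run ended on a tool's raw output."""
    return isinstance(result["messages"][-1], ToolMessage)
```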