Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect multiple subgraphs in single node #1793

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions libs/langgraph/langgraph/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from types import MappingProxyType

Check notice on line 1 in libs/langgraph/langgraph/constants.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 58.9 ms +- 1.4 ms ......................................... WARNING: the benchmark result may be unstable * the standard deviation (6.79 ms) is 12% of the mean (56.3 ms) Try to rerun the benchmark with more runs, values and/or loops. Run 'python -m pyperf system tune' command to reduce the system jitter. Use pyperf stats, pyperf dump and pyperf hist to analyze results. Use --quiet option to hide these warnings. fanout_to_subgraph_10x_sync: Mean +- std dev: 56.3 ms +- 6.8 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 76.0 ms +- 1.1 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 80.7 ms +- 1.4 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 550 ms +- 19 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 500 ms +- 6 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 786 ms +- 36 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 783 ms +- 6 ms ......................................... react_agent_10x: Mean +- std dev: 41.9 ms +- 3.7 ms ......................................... react_agent_10x_sync: Mean +- std dev: 29.9 ms +- 0.6 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 53.8 ms +- 1.3 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 42.9 ms +- 3.5 ms ......................................... react_agent_100x: Mean +- std dev: 410 ms +- 7 ms ......................................... react_agent_100x_sync: Mean +- std dev: 330 ms +- 3 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 915 ms +- 10 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 817 ms +- 10 ms ......................................... wide_state_25x300: Mean +- std dev: 20.5 ms +- 0.3 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 13.0 ms +- 0.3 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 241 ms +- 8 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 240 ms +- 18 ms ......................................... wide_state_15x600: Mean +- std dev: 23.6 ms +- 0.6 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 14.7 ms +- 0.1 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 416 ms +- 13 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 421 ms +- 21 ms ......................................... wide_state_9x1200: Mean +- std dev: 23.9 ms +- 0.5 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 14.7 ms +- 0.1 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 267 ms +- 7 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 265 ms +- 15 ms

Check notice on line 1 in libs/langgraph/langgraph/constants.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+-----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +=========================================+=========+=======================+ | fanout_to_subgraph_10x_sync | 59.3 ms | 56.3 ms: 1.05x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint | 79.5 ms | 76.0 ms: 1.05x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_sync | 343 ms | 330 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 84.1 ms | 80.7 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint | 274 ms | 267 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint | 426 ms | 416 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint_sync | 801 ms | 783 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600 | 24.1 ms | 23.6 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x | 419 ms | 410 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_sync | 508 ms | 500 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x | 59.8 ms | 58.9 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300 | 20.8 ms | 20.5 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint_sync | 827 ms | 817 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint | 926 ms | 915 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200 | 24.1 ms | 23.9 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_sync | 14.9 ms | 14.7 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_checkpoint | 243 ms | 241 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x | 554 ms | 550 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 14.7 ms | 14.7 ms: 1.00x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint | 53.3 ms | 53.8 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_sync | 12.8 ms | 13.0 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint | 767 ms | 786 ms: 1.02x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_checkpoint_sync | 234 ms | 240 ms: 1.03x slower | +-----------------------------------------+---------+-----------------------+ | react_agent_10x | 39.6 ms | 41.9 ms: 1.06x slower | +-----------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.01x faster | +---------------------------------------
from typing import Any, Mapping

from langgraph.types import Interrupt, Send # noqa: F401
Expand All @@ -10,6 +10,14 @@
EMPTY_MAP: Mapping[str, Any] = MappingProxyType({})
EMPTY_SEQ: tuple[str, ...] = tuple()

# --- Public constants ---
TAG_HIDDEN = "langsmith:hidden"
# tag to hide a node/edge from certain tracing/streaming environments
START = "__start__"
# the first (maybe virtual) node in graph-style Pregel
END = "__end__"
# the last (maybe virtual) node in graph-style Pregel

# --- Reserved write keys ---
INPUT = "__input__"
# for values passed as input to the graph
Expand All @@ -23,10 +31,6 @@
# marker to signal node was scheduled (in distributed mode)
TASKS = "__pregel_tasks"
# for Send objects returned by nodes/edges, corresponds to PUSH below
START = "__start__"
# marker for the first (maybe virtual) node in graph-style Pregel
END = "__end__"
# marker for the last (maybe virtual) node in graph-style Pregel

# --- Reserved config.configurable keys ---
CONFIG_KEY_SEND = "__pregel_send"
Expand All @@ -43,8 +47,6 @@
# holds a `BaseStore` made available to managed values
CONFIG_KEY_RESUMING = "__pregel_resuming"
# holds a boolean indicating if subgraphs should resume from a previous checkpoint
CONFIG_KEY_GRAPH_COUNT = "__pregel_graph_count"
# holds the number of subgraphs executed in a given task, used to raise errors
CONFIG_KEY_TASK_ID = "__pregel_task_id"
# holds the task ID for the current task
CONFIG_KEY_DEDUPE_TASKS = "__pregel_dedupe_tasks"
Expand All @@ -68,14 +70,13 @@
# denotes pull-style tasks, ie. those triggered by edges
RUNTIME_PLACEHOLDER = "__pregel_runtime_placeholder__"
# placeholder for managed values replaced at runtime
TAG_HIDDEN = "langsmith:hidden"
# tag to hide a node/edge from certain tracing/streaming environments
NS_SEP = "|"
# for checkpoint_ns, separates each level (ie. graph|subgraph|subsubgraph)
NS_END = ":"
# for checkpoint_ns, for each level, separates the namespace from the task_id

RESERVED = {
TAG_HIDDEN,
# reserved write keys
INPUT,
INTERRUPT,
Expand Down Expand Up @@ -103,7 +104,6 @@
PUSH,
PULL,
RUNTIME_PLACEHOLDER,
TAG_HIDDEN,
NS_SEP,
NS_END,
}
10 changes: 10 additions & 0 deletions libs/langgraph/langgraph/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,13 @@ class CheckpointNotLatest(Exception):
"""Raised when the checkpoint is not the latest version (for distributed mode)."""

pass


class MultipleSubgraphsError(Exception):
"""Raised when multiple subgraphs are called inside the same node."""

pass


_SEEN_CHECKPOINT_NS: set[str] = set()
"""Used for subgraph detection."""
5 changes: 2 additions & 3 deletions libs/langgraph/langgraph/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from typing_extensions import Self

from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.constants import (
END,
NS_END,
Expand All @@ -39,7 +38,7 @@
from langgraph.pregel import Channel, Pregel
from langgraph.pregel.read import PregelNode
from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry
from langgraph.types import All
from langgraph.types import All, Checkpointer
from langgraph.utils.runnable import RunnableCallable, coerce_to_runnable

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -406,7 +405,7 @@ def validate(self, interrupt: Optional[Sequence[str]] = None) -> Self:

def compile(
self,
checkpointer: Optional[BaseCheckpointSaver] = None,
checkpointer: Checkpointer = None,
interrupt_before: Optional[Union[All, list[str]]] = None,
interrupt_after: Optional[Union[All, list[str]]] = None,
debug: bool = False,
Expand Down
7 changes: 3 additions & 4 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.channels.last_value import LastValue
from langgraph.channels.named_barrier_value import NamedBarrierValue
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.constants import NS_END, NS_SEP, TAG_HIDDEN
from langgraph.errors import InvalidUpdateError
from langgraph.graph.graph import END, START, Branch, CompiledGraph, Graph, Send
Expand All @@ -47,7 +46,7 @@
from langgraph.pregel.read import ChannelRead, PregelNode
from langgraph.pregel.write import SKIP_WRITE, ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.types import All, RetryPolicy
from langgraph.types import All, Checkpointer, RetryPolicy
from langgraph.utils.fields import get_field_default
from langgraph.utils.pydantic import create_model
from langgraph.utils.runnable import coerce_to_runnable
Expand Down Expand Up @@ -400,7 +399,7 @@ def add_edge(self, start_key: Union[str, list[str]], end_key: str) -> Self:

def compile(
self,
checkpointer: Optional[BaseCheckpointSaver] = None,
checkpointer: Checkpointer = None,
*,
store: Optional[BaseStore] = None,
interrupt_before: Optional[Union[All, list[str]]] = None,
Expand All @@ -413,7 +412,7 @@ def compile(
streamed, batched, and run asynchronously.

Args:
checkpointer (Optional[BaseCheckpointSaver]): An optional checkpoint saver object.
checkpointer (Checkpointer): An optional checkpoint saver object.
This serves as a fully versioned "memory" for the graph, allowing
the graph to be paused and resumed, and replayed from any point.
interrupt_before (Optional[Sequence[str]]): An optional list of node names to interrupt before.
Expand Down
4 changes: 2 additions & 2 deletions libs/langgraph/langgraph/prebuilt/chat_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
from langchain_core.tools import BaseTool

from langgraph._api.deprecation import deprecated_parameter
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph import StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.message import add_messages
from langgraph.managed import IsLastStep
from langgraph.prebuilt.tool_executor import ToolExecutor
from langgraph.prebuilt.tool_node import ToolNode
from langgraph.types import Checkpointer


# We create the AgentState that we will pass around
Expand Down Expand Up @@ -132,7 +132,7 @@ def create_react_agent(
state_schema: Optional[StateSchemaType] = None,
messages_modifier: Optional[MessagesModifier] = None,
state_modifier: Optional[StateModifier] = None,
checkpointer: Optional[BaseCheckpointSaver] = None,
checkpointer: Checkpointer = None,
interrupt_before: Optional[list[str]] = None,
interrupt_after: Optional[list[str]] = None,
debug: bool = False,
Expand Down
32 changes: 13 additions & 19 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
from langgraph.pregel.validate import validate_graph, validate_keys
from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.types import All, StateSnapshot, StreamMode
from langgraph.types import All, Checkpointer, StateSnapshot, StreamMode
from langgraph.utils.config import (
ensure_config,
merge_configs,
Expand Down Expand Up @@ -197,7 +197,7 @@ class Pregel(Runnable[Union[dict[str, Any], Any], Union[dict[str, Any], Any]]):
debug: bool
"""Whether to print debug information during execution. Defaults to False."""

checkpointer: Optional[BaseCheckpointSaver] = None
checkpointer: Checkpointer = None
"""Checkpointer used to save and load graph state. Defaults to None."""

store: Optional[BaseStore] = None
Expand Down Expand Up @@ -281,7 +281,7 @@ def config_specs(self) -> list[ConfigurableFieldSpec]:
[spec for node in self.nodes.values() for spec in node.config_specs]
+ (
self.checkpointer.config_specs
if self.checkpointer is not None
if isinstance(self.checkpointer, BaseCheckpointSaver)
else []
)
+ (
Expand Down Expand Up @@ -1059,6 +1059,8 @@ def _defaults(
Union[All, Sequence[str]],
Optional[BaseCheckpointSaver],
]:
if config["recursion_limit"] < 1:
raise ValueError("recursion_limit must be at least 1")
debug = debug if debug is not None else self.debug
if output_keys is None:
output_keys = self.stream_channels_asis
Expand All @@ -1072,12 +1074,16 @@ def _defaults(
if CONFIG_KEY_TASK_ID in config.get("configurable", {}):
# if being called as a node in another graph, always use values mode
stream_mode = ["values"]
if CONFIG_KEY_CHECKPOINTER in config.get("configurable", {}):
checkpointer: Optional[BaseCheckpointSaver] = config["configurable"][
CONFIG_KEY_CHECKPOINTER
]
if self.checkpointer is False:
checkpointer: Optional[BaseCheckpointSaver] = None
elif CONFIG_KEY_CHECKPOINTER in config.get("configurable", {}):
checkpointer = config["configurable"][CONFIG_KEY_CHECKPOINTER]
else:
checkpointer = self.checkpointer
if checkpointer and not config.get("configurable"):
raise ValueError(
f"Checkpointer requires one or more of the following 'configurable' keys: {[s.id for s in checkpointer.config_specs]}"
)
return (
debug,
set(stream_mode),
Expand Down Expand Up @@ -1193,12 +1199,6 @@ def output() -> Iterator:
run_id=config.get("run_id"),
)
try:
if config["recursion_limit"] < 1:
raise ValueError("recursion_limit must be at least 1")
if self.checkpointer and not config.get("configurable"):
raise ValueError(
f"Checkpointer requires one or more of the following 'configurable' keys: {[s.id for s in self.checkpointer.config_specs]}"
)
# assign defaults
(
debug,
Expand Down Expand Up @@ -1414,12 +1414,6 @@ def output() -> Iterator:
None,
)
try:
if config["recursion_limit"] < 1:
raise ValueError("recursion_limit must be at least 1")
if self.checkpointer and not config.get("configurable"):
raise ValueError(
f"Checkpointer requires one or more of the following 'configurable' keys: {[s.id for s in self.checkpointer.config_specs]}"
)
# assign defaults
(
debug,
Expand Down
12 changes: 12 additions & 0 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,12 @@
TASKS,
)
from langgraph.errors import (
_SEEN_CHECKPOINT_NS,
CheckpointNotLatest,
EmptyInputError,
GraphDelegate,
GraphInterrupt,
MultipleSubgraphsError,
)
from langgraph.managed.base import (
ManagedValueMapping,
Expand Down Expand Up @@ -195,6 +197,7 @@ def __init__(
specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]],
output_keys: Union[str, Sequence[str]],
stream_keys: Union[str, Sequence[str]],
check_subgraphs: bool = True,
debug: bool = False,
) -> None:
self.stream = stream
Expand All @@ -220,6 +223,11 @@ def __init__(
self.config = patch_configurable(
self.config, {"checkpoint_ns": "", "checkpoint_id": None}
)
if check_subgraphs and self.is_nested and self.checkpointer is not None:
if self.config["configurable"]["checkpoint_ns"] in _SEEN_CHECKPOINT_NS:
raise MultipleSubgraphsError
else:
_SEEN_CHECKPOINT_NS.add(self.config["configurable"]["checkpoint_ns"])
if (
CONFIG_KEY_CHECKPOINT_MAP in self.config["configurable"]
and self.config["configurable"].get("checkpoint_ns")
Expand Down Expand Up @@ -634,6 +642,7 @@ def __init__(
specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]],
output_keys: Union[str, Sequence[str]] = EMPTY_SEQ,
stream_keys: Union[str, Sequence[str]] = EMPTY_SEQ,
check_subgraphs: bool = True,
debug: bool = False,
) -> None:
super().__init__(
Expand All @@ -646,6 +655,7 @@ def __init__(
specs=specs,
output_keys=output_keys,
stream_keys=stream_keys,
check_subgraphs=check_subgraphs,
debug=debug,
)
self.stack = ExitStack()
Expand Down Expand Up @@ -755,6 +765,7 @@ def __init__(
specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]],
output_keys: Union[str, Sequence[str]] = EMPTY_SEQ,
stream_keys: Union[str, Sequence[str]] = EMPTY_SEQ,
check_subgraphs: bool = True,
debug: bool = False,
) -> None:
super().__init__(
Expand All @@ -767,6 +778,7 @@ def __init__(
specs=specs,
output_keys=output_keys,
stream_keys=stream_keys,
check_subgraphs=check_subgraphs,
debug=debug,
)
self.store = AsyncBatchedStore(self.store) if self.store else None
Expand Down
26 changes: 18 additions & 8 deletions libs/langgraph/langgraph/pregel/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import time
from typing import Optional, Sequence

from langgraph.constants import CONFIG_KEY_GRAPH_COUNT, CONFIG_KEY_RESUMING
from langgraph.errors import GraphInterrupt
from langgraph.constants import CONFIG_KEY_RESUMING
from langgraph.errors import _SEEN_CHECKPOINT_NS, GraphInterrupt
from langgraph.types import PregelExecutableTask, RetryPolicy
from langgraph.utils.config import patch_configurable

Expand Down Expand Up @@ -70,9 +70,14 @@ def run_with_retry(
exc_info=exc,
)
# signal subgraphs to resume (if available)
config = patch_configurable(
config, {CONFIG_KEY_RESUMING: True, CONFIG_KEY_GRAPH_COUNT: 0}
)
config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})
# clear checkpoint_ns seen (for subgraph detection)
if checkpoint_ns := config["configurable"].get("checkpoint_ns"):
_SEEN_CHECKPOINT_NS.discard(checkpoint_ns)
finally:
# clear checkpoint_ns seen (for subgraph detection)
if checkpoint_ns := config["configurable"].get("checkpoint_ns"):
_SEEN_CHECKPOINT_NS.discard(checkpoint_ns)


async def arun_with_retry(
Expand Down Expand Up @@ -138,6 +143,11 @@ async def arun_with_retry(
exc_info=exc,
)
# signal subgraphs to resume (if available)
config = patch_configurable(
config, {CONFIG_KEY_RESUMING: True, CONFIG_KEY_GRAPH_COUNT: 0}
)
config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})
# clear checkpoint_ns seen (for subgraph detection)
if checkpoint_ns := config["configurable"].get("checkpoint_ns"):
_SEEN_CHECKPOINT_NS.discard(checkpoint_ns)
finally:
# clear checkpoint_ns seen (for subgraph detection)
if checkpoint_ns := config["configurable"].get("checkpoint_ns"):
_SEEN_CHECKPOINT_NS.discard(checkpoint_ns)
18 changes: 16 additions & 2 deletions libs/langgraph/langgraph/types.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
from collections import deque
from dataclasses import dataclass
from typing import Any, Callable, Literal, NamedTuple, Optional, Sequence, Type, Union
from typing import (
Any,
Callable,
Literal,
NamedTuple,
Optional,
Sequence,
Type,
Union,
)

from langchain_core.runnables import Runnable, RunnableConfig

from langgraph.checkpoint.base import CheckpointMetadata
from langgraph.checkpoint.base import BaseCheckpointSaver, CheckpointMetadata

All = Literal["*"]
"""Special value to indicate that graph should interrupt on all nodes."""

Checkpointer = Union[None, Literal[False], BaseCheckpointSaver]
"""Type of the checkpointer to use for a subgraph. False disables checkpointing,
even if the parent graph has a checkpointer. None inherits checkpointer."""

StreamMode = Literal["values", "updates", "debug", "messages", "custom"]
"""How the stream method should emit outputs.
Expand Down
Loading
Loading