Skip to content

Commit

Permalink
Use a different strategy, add test
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Sep 22, 2024
1 parent 2fccd09 commit 81a9a2f
Show file tree
Hide file tree
Showing 12 changed files with 114 additions and 60 deletions.
18 changes: 9 additions & 9 deletions libs/langgraph/langgraph/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@
EMPTY_MAP: Mapping[str, Any] = MappingProxyType({})
EMPTY_SEQ: tuple[str, ...] = tuple()

# --- Public constants ---
TAG_HIDDEN = "langsmith:hidden"
# tag to hide a node/edge from certain tracing/streaming environments
START = "__start__"
# the first (maybe virtual) node in graph-style Pregel
END = "__end__"
# the last (maybe virtual) node in graph-style Pregel

# --- Reserved write keys ---
INPUT = "__input__"
# for values passed as input to the graph
Expand All @@ -23,10 +31,6 @@
# marker to signal node was scheduled (in distributed mode)
TASKS = "__pregel_tasks"
# for Send objects returned by nodes/edges, corresponds to PUSH below
START = "__start__"
# marker for the first (maybe virtual) node in graph-style Pregel
END = "__end__"
# marker for the last (maybe virtual) node in graph-style Pregel

# --- Reserved config.configurable keys ---
CONFIG_KEY_SEND = "__pregel_send"
Expand All @@ -43,8 +47,6 @@
# holds a `BaseStore` made available to managed values
CONFIG_KEY_RESUMING = "__pregel_resuming"
# holds a boolean indicating if subgraphs should resume from a previous checkpoint
CONFIG_KEY_GRAPH_COUNT = "__pregel_graph_count"
# holds the number of subgraphs executed in a given task, used to raise errors
CONFIG_KEY_TASK_ID = "__pregel_task_id"
# holds the task ID for the current task
CONFIG_KEY_DEDUPE_TASKS = "__pregel_dedupe_tasks"
Expand All @@ -68,14 +70,13 @@
# denotes pull-style tasks, ie. those triggered by edges
RUNTIME_PLACEHOLDER = "__pregel_runtime_placeholder__"
# placeholder for managed values replaced at runtime
TAG_HIDDEN = "langsmith:hidden"
# tag to hide a node/edge from certain tracing/streaming environments
NS_SEP = "|"
# for checkpoint_ns, separates each level (ie. graph|subgraph|subsubgraph)
NS_END = ":"
# for checkpoint_ns, for each level, separates the namespace from the task_id

RESERVED = {
TAG_HIDDEN,
# reserved write keys
INPUT,
INTERRUPT,
Expand Down Expand Up @@ -103,7 +104,6 @@
PUSH,
PULL,
RUNTIME_PLACEHOLDER,
TAG_HIDDEN,
NS_SEP,
NS_END,
}
10 changes: 10 additions & 0 deletions libs/langgraph/langgraph/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,13 @@ class CheckpointNotLatest(Exception):
"""Raised when the checkpoint is not the latest version (for distributed mode)."""

pass


class MultipleSubgraphsError(Exception):
"""Raised when multiple subgraphs are called inside the same node."""

pass


_SEEN_CHECKPOINT_NS: set[str] = set()
"""Used for subgraph detection."""
5 changes: 2 additions & 3 deletions libs/langgraph/langgraph/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from typing_extensions import Self

from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.constants import (
END,
NS_END,
Expand All @@ -39,7 +38,7 @@
from langgraph.pregel import Channel, Pregel
from langgraph.pregel.read import PregelNode
from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry
from langgraph.types import All
from langgraph.types import All, Checkpointer
from langgraph.utils.runnable import RunnableCallable, coerce_to_runnable

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -406,7 +405,7 @@ def validate(self, interrupt: Optional[Sequence[str]] = None) -> Self:

def compile(
self,
checkpointer: Optional[BaseCheckpointSaver] = None,
checkpointer: Checkpointer = None,
interrupt_before: Optional[Union[All, list[str]]] = None,
interrupt_after: Optional[Union[All, list[str]]] = None,
debug: bool = False,
Expand Down
7 changes: 3 additions & 4 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.channels.last_value import LastValue
from langgraph.channels.named_barrier_value import NamedBarrierValue
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.constants import NS_END, NS_SEP, TAG_HIDDEN
from langgraph.errors import InvalidUpdateError
from langgraph.graph.graph import END, START, Branch, CompiledGraph, Graph, Send
Expand All @@ -47,7 +46,7 @@
from langgraph.pregel.read import ChannelRead, PregelNode
from langgraph.pregel.write import SKIP_WRITE, ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.types import All, RetryPolicy
from langgraph.types import All, Checkpointer, RetryPolicy
from langgraph.utils.fields import get_field_default
from langgraph.utils.pydantic import create_model
from langgraph.utils.runnable import coerce_to_runnable
Expand Down Expand Up @@ -400,7 +399,7 @@ def add_edge(self, start_key: Union[str, list[str]], end_key: str) -> Self:

def compile(
self,
checkpointer: Optional[BaseCheckpointSaver] = None,
checkpointer: Checkpointer = None,
*,
store: Optional[BaseStore] = None,
interrupt_before: Optional[Union[All, list[str]]] = None,
Expand All @@ -413,7 +412,7 @@ def compile(
streamed, batched, and run asynchronously.
Args:
checkpointer (Optional[BaseCheckpointSaver]): An optional checkpoint saver object.
checkpointer (Checkpointer): An optional checkpoint saver object.
This serves as a fully versioned "memory" for the graph, allowing
the graph to be paused and resumed, and replayed from any point.
interrupt_before (Optional[Sequence[str]]): An optional list of node names to interrupt before.
Expand Down
4 changes: 2 additions & 2 deletions libs/langgraph/langgraph/prebuilt/chat_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
from langchain_core.tools import BaseTool

from langgraph._api.deprecation import deprecated_parameter
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph import StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.message import add_messages
from langgraph.managed import IsLastStep
from langgraph.prebuilt.tool_executor import ToolExecutor
from langgraph.prebuilt.tool_node import ToolNode
from langgraph.types import Checkpointer


# We create the AgentState that we will pass around
Expand Down Expand Up @@ -132,7 +132,7 @@ def create_react_agent(
state_schema: Optional[StateSchemaType] = None,
messages_modifier: Optional[MessagesModifier] = None,
state_modifier: Optional[StateModifier] = None,
checkpointer: Optional[BaseCheckpointSaver] = None,
checkpointer: Checkpointer = None,
interrupt_before: Optional[list[str]] = None,
interrupt_after: Optional[list[str]] = None,
debug: bool = False,
Expand Down
32 changes: 13 additions & 19 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
from langgraph.pregel.validate import validate_graph, validate_keys
from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.types import All, StateSnapshot, StreamMode
from langgraph.types import All, Checkpointer, StateSnapshot, StreamMode
from langgraph.utils.config import (
ensure_config,
merge_configs,
Expand Down Expand Up @@ -197,7 +197,7 @@ class Pregel(Runnable[Union[dict[str, Any], Any], Union[dict[str, Any], Any]]):
debug: bool
"""Whether to print debug information during execution. Defaults to False."""

checkpointer: Optional[BaseCheckpointSaver] = None
checkpointer: Checkpointer = None
"""Checkpointer used to save and load graph state. Defaults to None."""

store: Optional[BaseStore] = None
Expand Down Expand Up @@ -281,7 +281,7 @@ def config_specs(self) -> list[ConfigurableFieldSpec]:
[spec for node in self.nodes.values() for spec in node.config_specs]
+ (
self.checkpointer.config_specs
if self.checkpointer is not None
if isinstance(self.checkpointer, BaseCheckpointSaver)
else []
)
+ (
Expand Down Expand Up @@ -1059,6 +1059,8 @@ def _defaults(
Union[All, Sequence[str]],
Optional[BaseCheckpointSaver],
]:
if config["recursion_limit"] < 1:
raise ValueError("recursion_limit must be at least 1")
debug = debug if debug is not None else self.debug
if output_keys is None:
output_keys = self.stream_channels_asis
Expand All @@ -1072,12 +1074,16 @@ def _defaults(
if CONFIG_KEY_TASK_ID in config.get("configurable", {}):
# if being called as a node in another graph, always use values mode
stream_mode = ["values"]
if CONFIG_KEY_CHECKPOINTER in config.get("configurable", {}):
checkpointer: Optional[BaseCheckpointSaver] = config["configurable"][
CONFIG_KEY_CHECKPOINTER
]
if self.checkpointer is False:
checkpointer: Optional[BaseCheckpointSaver] = None
elif CONFIG_KEY_CHECKPOINTER in config.get("configurable", {}):
checkpointer = config["configurable"][CONFIG_KEY_CHECKPOINTER]
else:
checkpointer = self.checkpointer
if checkpointer and not config.get("configurable"):
raise ValueError(
f"Checkpointer requires one or more of the following 'configurable' keys: {[s.id for s in checkpointer.config_specs]}"
)
return (
debug,
set(stream_mode),
Expand Down Expand Up @@ -1193,12 +1199,6 @@ def output() -> Iterator:
run_id=config.get("run_id"),
)
try:
if config["recursion_limit"] < 1:
raise ValueError("recursion_limit must be at least 1")
if self.checkpointer and not config.get("configurable"):
raise ValueError(
f"Checkpointer requires one or more of the following 'configurable' keys: {[s.id for s in self.checkpointer.config_specs]}"
)
# assign defaults
(
debug,
Expand Down Expand Up @@ -1414,12 +1414,6 @@ def output() -> Iterator:
None,
)
try:
if config["recursion_limit"] < 1:
raise ValueError("recursion_limit must be at least 1")
if self.checkpointer and not config.get("configurable"):
raise ValueError(
f"Checkpointer requires one or more of the following 'configurable' keys: {[s.id for s in self.checkpointer.config_specs]}"
)
# assign defaults
(
debug,
Expand Down
3 changes: 0 additions & 3 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from langgraph.constants import (
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_CHECKPOINTER,
CONFIG_KEY_GRAPH_COUNT,
CONFIG_KEY_READ,
CONFIG_KEY_SEND,
CONFIG_KEY_TASK_ID,
Expand Down Expand Up @@ -430,7 +429,6 @@ def prepare_single_task(
manager.get_child(f"graph:step:{step}") if manager else None
),
configurable={
CONFIG_KEY_GRAPH_COUNT: 0,
CONFIG_KEY_TASK_ID: task_id,
# deque.extend is thread-safe
CONFIG_KEY_SEND: partial(
Expand Down Expand Up @@ -541,7 +539,6 @@ def prepare_single_task(
else None
),
configurable={
CONFIG_KEY_GRAPH_COUNT: 0,
CONFIG_KEY_TASK_ID: task_id,
# deque.extend is thread-safe
CONFIG_KEY_SEND: partial(
Expand Down
12 changes: 6 additions & 6 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
CONFIG_KEY_DEDUPE_TASKS,
CONFIG_KEY_DELEGATE,
CONFIG_KEY_ENSURE_LATEST,
CONFIG_KEY_GRAPH_COUNT,
CONFIG_KEY_RESUMING,
CONFIG_KEY_STREAM,
CONFIG_KEY_TASK_ID,
Expand All @@ -55,10 +54,12 @@
TASKS,
)
from langgraph.errors import (
_SEEN_CHECKPOINT_NS,
CheckpointNotLatest,
EmptyInputError,
GraphDelegate,
GraphInterrupt,
MultipleSubgraphsError,
)
from langgraph.managed.base import (
ManagedValueMapping,
Expand Down Expand Up @@ -221,12 +222,11 @@ def __init__(
self.config = patch_configurable(
self.config, {"checkpoint_ns": "", "checkpoint_id": None}
)
if self.is_nested:
if config["configurable"].get(CONFIG_KEY_GRAPH_COUNT, 0) > 0:
raise ValueError("Detected multiple subgraphs called in a single node.")
if self.is_nested and self.checkpointer is not None:
if self.config["configurable"]["checkpoint_ns"] in _SEEN_CHECKPOINT_NS:
raise MultipleSubgraphsError
else:
# mutate config so that sibling subgraphs can be detected
self.config["configurable"][CONFIG_KEY_GRAPH_COUNT] = 1
_SEEN_CHECKPOINT_NS.add(self.config["configurable"]["checkpoint_ns"])
if (
CONFIG_KEY_CHECKPOINT_MAP in self.config["configurable"]
and self.config["configurable"].get("checkpoint_ns")
Expand Down
26 changes: 18 additions & 8 deletions libs/langgraph/langgraph/pregel/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import time
from typing import Optional, Sequence

from langgraph.constants import CONFIG_KEY_GRAPH_COUNT, CONFIG_KEY_RESUMING
from langgraph.errors import GraphInterrupt
from langgraph.constants import CONFIG_KEY_RESUMING
from langgraph.errors import _SEEN_CHECKPOINT_NS, GraphInterrupt
from langgraph.types import PregelExecutableTask, RetryPolicy
from langgraph.utils.config import patch_configurable

Expand Down Expand Up @@ -70,9 +70,14 @@ def run_with_retry(
exc_info=exc,
)
# signal subgraphs to resume (if available)
config = patch_configurable(
config, {CONFIG_KEY_RESUMING: True, CONFIG_KEY_GRAPH_COUNT: 0}
)
config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})
# clear checkpoint_ns seen (for subgraph detection)
if checkpoint_ns := config["configurable"].get("checkpoint_ns"):
_SEEN_CHECKPOINT_NS.discard(checkpoint_ns)
finally:
# clear checkpoint_ns seen (for subgraph detection)
if checkpoint_ns := config["configurable"].get("checkpoint_ns"):
_SEEN_CHECKPOINT_NS.discard(checkpoint_ns)


async def arun_with_retry(
Expand Down Expand Up @@ -138,6 +143,11 @@ async def arun_with_retry(
exc_info=exc,
)
# signal subgraphs to resume (if available)
config = patch_configurable(
config, {CONFIG_KEY_RESUMING: True, CONFIG_KEY_GRAPH_COUNT: 0}
)
config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})
# clear checkpoint_ns seen (for subgraph detection)
if checkpoint_ns := config["configurable"].get("checkpoint_ns"):
_SEEN_CHECKPOINT_NS.discard(checkpoint_ns)
finally:
# clear checkpoint_ns seen (for subgraph detection)
if checkpoint_ns := config["configurable"].get("checkpoint_ns"):
_SEEN_CHECKPOINT_NS.discard(checkpoint_ns)
18 changes: 16 additions & 2 deletions libs/langgraph/langgraph/types.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
from collections import deque
from dataclasses import dataclass
from typing import Any, Callable, Literal, NamedTuple, Optional, Sequence, Type, Union
from typing import (
Any,
Callable,
Literal,
NamedTuple,
Optional,
Sequence,
Type,
Union,
)

from langchain_core.runnables import Runnable, RunnableConfig

from langgraph.checkpoint.base import CheckpointMetadata
from langgraph.checkpoint.base import BaseCheckpointSaver, CheckpointMetadata

All = Literal["*"]
"""Special value to indicate that graph should interrupt on all nodes."""

Checkpointer = Union[None, Literal[False], BaseCheckpointSaver]
"""Type of the checkpointer to use for a subgraph. False disables checkpointing,
even if the parent graph has a checkpointer. None inherits checkpointer."""

StreamMode = Literal["values", "updates", "debug", "messages", "custom"]
"""How the stream method should emit outputs.
Expand Down
Loading

0 comments on commit 81a9a2f

Please sign in to comment.