Skip to content

Commit

Permalink
sdk: Update thread state methods for subgraphs
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Sep 20, 2024
1 parent 8b95adb commit e3d2f23
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 72 deletions.
3 changes: 0 additions & 3 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10463,9 +10463,6 @@ def weather_node(state: SubGraphState, writer: StreamWriter):
class RouterState(MessagesState):
route: Literal["weather", "other"]

class Router(TypedDict):
route: Literal["weather", "other"]

router_model = FakeMessagesListChatModel(
responses=[
AIMessage(
Expand Down
117 changes: 50 additions & 67 deletions libs/sdk-py/langgraph_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from langgraph_sdk.schema import (
Assistant,
AssistantVersion,
Checkpoint,
Config,
Cron,
DisconnectMode,
Expand Down Expand Up @@ -804,13 +805,19 @@ async def copy(self, thread_id: str) -> None:
return await self.http.post(f"/threads/{thread_id}/copy", json=None)

async def get_state(
self, thread_id: str, checkpoint_id: Optional[str] = None
self,
thread_id: str,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None, # deprecated
*,
subgraphs: bool = False,
) -> ThreadState:
"""Get the state of a thread.
Args:
thread_id: The ID of the thread to get the state of.
checkpoint_id: The ID of the checkpoint to get the state of.
checkpoint: The checkpoint to get the state of.
subgraphs: Include subgraphs in the state.
Returns:
ThreadState: the thread of the state.
Expand Down Expand Up @@ -852,15 +859,12 @@ async def get_state(
]
},
'next': [],
'config':
'checkpoint':
{
'configurable':
{
'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2',
'checkpoint_ns': '',
'checkpoint_id': '1ef4a9b8-e6fb-67b1-8001-abd5184439d1'
}
},
'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2',
'checkpoint_ns': '',
'checkpoint_id': '1ef4a9b8-e6fb-67b1-8001-abd5184439d1'
}
'metadata':
{
'step': 1,
Expand All @@ -870,20 +874,20 @@ async def get_state(
{
'agent':
{
'messages': [
{
'id': 'run-159b782c-b679-4830-83c6-cef87798fe8b',
'name': None,
'type': 'ai',
'content': "I'm doing well, thanks for asking! I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.",
'example': False,
'tool_calls': [],
'usage_metadata': None,
'additional_kwargs': {},
'response_metadata': {},
'invalid_tool_calls': []
}
]
'messages': [
{
'id': 'run-159b782c-b679-4830-83c6-cef87798fe8b',
'name': None,
'type': 'ai',
'content': "I'm doing well, thanks for asking! I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.",
'example': False,
'tool_calls': [],
'usage_metadata': None,
'additional_kwargs': {},
'response_metadata': {},
'invalid_tool_calls': []
}
]
}
},
'user_id': None,
Expand All @@ -894,36 +898,45 @@ async def get_state(
'created_at': '2024-07-25T15:35:44.184703+00:00',
'parent_config':
{
'configurable':
{
'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2',
'checkpoint_ns': '',
'checkpoint_id': '1ef4a9b8-d80d-6fa7-8000-9300467fad0f'
}
'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2',
'checkpoint_ns': '',
'checkpoint_id': '1ef4a9b8-d80d-6fa7-8000-9300467fad0f'
}
}
""" # noqa: E501
if checkpoint_id:
return await self.http.get(f"/threads/{thread_id}/state/{checkpoint_id}")
if checkpoint:
return await self.http.post(
f"/threads/{thread_id}/state/checkpoint",
json={"checkpoint": checkpoint, "subgraphs": subgraphs},
)
elif checkpoint_id:
return await self.http.get(
f"/threads/{thread_id}/state/{checkpoint_id}",
params={"subgraphs": subgraphs},
)
else:
return await self.http.get(f"/threads/{thread_id}/state")
return await self.http.get(
f"/threads/{thread_id}/state",
params={"subgraphs": subgraphs},
)

async def update_state(
self,
thread_id: str,
values: dict,
*,
as_node: Optional[str] = None,
checkpoint_id: Optional[str] = None,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None, # deprecated
) -> None:
"""Update the state of a thread.
Args:
thread_id: The ID of the thread to update.
values: The values to update to the state.
as_node: Update the state as if this node had just executed.
checkpoint_id: The ID of the checkpoint to update the state of.
checkpoint: The checkpoint to update the state of.
Returns:
None
Expand All @@ -934,7 +947,6 @@ async def update_state(
thread_id="my_thread_id",
values={"messages":[{"role": "user", "content": "hello!"}]},
as_node="my_node",
checkpoint_id="my_checkpoint_id"
)
""" # noqa: E501
Expand All @@ -943,41 +955,12 @@ async def update_state(
}
if checkpoint_id:
payload["checkpoint_id"] = checkpoint_id
if checkpoint:
payload["checkpoint"] = checkpoint
if as_node:
payload["as_node"] = as_node
return await self.http.post(f"/threads/{thread_id}/state", json=payload)

async def patch_state(
self,
thread_id: Union[str, Config],
metadata: dict,
) -> None:
"""Patch the state of a thread.
Args:
thread_id: The ID of the thread to get the state of.
metadata: The metadata to assign to the state.
Returns:
None
Example Usage:
await client.threads.patch_state(
thread_id="my_thread_id",
metadata={"name":"new_name"},
)
""" # noqa: E501
if isinstance(thread_id, dict):
thread_id_: str = thread_id["configurable"]["thread_id"]
else:
thread_id_ = thread_id
return await self.http.patch(
f"/threads/{thread_id_}/state",
json={"metadata": metadata},
)

async def get_history(
self,
thread_id: str,
Expand Down
12 changes: 10 additions & 2 deletions libs/sdk-py/langgraph_sdk/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ class Config(TypedDict, total=False):
"""


class Checkpoint(TypedDict):
"""Checkpoint model."""

checkpoint_id: str
checkpoint_ns: str
checkpoint_map: dict[str, Any]


class GraphSchema(TypedDict):
"""Graph model."""

Expand Down Expand Up @@ -110,13 +118,13 @@ class ThreadState(TypedDict):
next: Sequence[str]
"""The next nodes to execute. If empty, the thread is done until new input is
received."""
checkpoint_id: str
checkpoint: Checkpoint
"""The ID of the checkpoint."""
metadata: Json
"""Metadata for this state"""
created_at: Optional[str]
"""Timestamp of state creation"""
parent_checkpoint_id: Optional[str]
parent_checkpoint: Optional[Checkpoint]
"""The ID of the parent checkpoint. If missing, this is the root checkpoint."""


Expand Down

0 comments on commit e3d2f23

Please sign in to comment.