From f73fbf5a4d7e29b1c4544f92cf7c21db39303ff7 Mon Sep 17 00:00:00 2001 From: Siddharth Narayanan Date: Mon, 9 Sep 2024 15:19:45 -0700 Subject: [PATCH] adding proper tree data structure for tree search --- ldp/alg/tree_search.py | 101 ++++++++++++++++------------- ldp/data_structures.py | 142 ++++++++++++++++++++++++++++++++++++++++- tests/test_rollouts.py | 55 +++++++++++++++- 3 files changed, 250 insertions(+), 48 deletions(-) diff --git a/ldp/alg/tree_search.py b/ldp/alg/tree_search.py index 27b532c2..f1b81aa2 100644 --- a/ldp/alg/tree_search.py +++ b/ldp/alg/tree_search.py @@ -1,9 +1,8 @@ import asyncio -import itertools import logging import uuid from collections.abc import Awaitable, Callable, Sequence -from typing import Any, cast +from typing import Any from aviary.message import Message from aviary.utils import is_coroutine_callable @@ -18,7 +17,7 @@ TEnv, reraise_exc_as, ) -from ldp.data_structures import Trajectory +from ldp.data_structures import TransitionTree logger = logging.getLogger(__name__) @@ -55,13 +54,14 @@ async def sample_trees( self, environments: Sequence[TEnv], max_depth: int | None = None, - ) -> list[list[Trajectory]]: + ) -> list[TransitionTree]: return await asyncio.gather(*[ self.sample_tree(env, max_depth) for env in environments ]) - async def sample_tree(self, env: TEnv, max_depth: int | None) -> list[Trajectory]: + async def sample_tree(self, env: TEnv, max_depth: int | None) -> TransitionTree: max_depth_f = max_depth if max_depth is not None else float("inf") + tree = TransitionTree(root_id=str(uuid.uuid4())) try: with reraise_exc_as(EnvError, enabled=self.catch_env_failures): @@ -69,79 +69,90 @@ async def sample_tree(self, env: TEnv, max_depth: int | None) -> list[Trajectory with reraise_exc_as(AgentError, enabled=self.catch_agent_failures): agent_state = await self.agent.init_state(tools) - - root_traj = Trajectory(traj_id=str(uuid.uuid4())) - return await self._descend(root_traj, env, agent_state, obs, max_depth_f) - except CaughtError: - return [] + return tree + + await self._descend( + tree=tree, + prev_step_id=tree.root_id, + env=env, + agent_state=agent_state, + obs=obs, + prev_timestep=-1, + prev_cumulative_reward=0.0, + max_depth=max_depth_f, + ) + + return tree async def _descend( self, - branch: Trajectory, + tree: TransitionTree, + prev_step_id: str, env: TEnv, agent_state: Any, obs: list[Message], + prev_timestep: int, + prev_cumulative_reward: float, max_depth: float, - ) -> list[Trajectory]: + ) -> None: # Descend one level in the tree, by adding branching_factor children to the branch # Then, recurse on each child - root_traj_id = cast(str, branch.traj_id).split(":")[0] - if root_traj_id in self.target_reward_hit: - return [branch] - timestep = len(branch.steps) + if tree.root_id in self.target_reward_hit: + # If at least one branch hit the target reward, stop descending + return - async def inner_descend(idx: int) -> list[Trajectory]: + timestep = prev_timestep + 1 + + async def inner_descend(idx: int) -> None: if is_coroutine_callable(self.env_clone_fn): cloned_env = await self.env_clone_fn(env) # type: ignore[arg-type, misc] else: cloned_env = self.env_clone_fn(env) # type: ignore[arg-type] # Descend one step - traj_id = f"{branch.traj_id}:{idx}" + step_id = f"{prev_step_id}:{idx}" try: step = await self._take_step( - timestep, traj_id, cloned_env, agent_state, obs + timestep, step_id, cloned_env, agent_state, obs ) except CaughtError: - # If we failed, do not extend the branch - just return an empty list - return [] + # If we failed, do not extend the branch - just give up on this path + return await asyncio.gather(*[ - callback.after_transition(traj_id, self.agent, cloned_env, step) + callback.after_transition(step_id, self.agent, cloned_env, step) for callback in self.callbacks ]) - # The original branch plus one step - extended_branch = Trajectory(traj_id=traj_id, steps=[*branch.steps, step]) + tree.add_transition(step_id, step) + + if step.done: + return - if ( - step.done # Trajectory is over - or len(extended_branch.steps) >= max_depth # Hit max depth - ): - return [extended_branch] + if timestep + 1 >= max_depth: + step.truncated = True + return - if ( - sum(step_.reward for step_ in extended_branch.steps) - >= self.target_reward - ): + cumulative_reward = prev_cumulative_reward + step.reward + if cumulative_reward >= self.target_reward: # signal other descents to stop too - self.target_reward_hit.add(root_traj_id) - return [extended_branch] + self.target_reward_hit.add(tree.root_id) # Recurse - return await self._descend( - extended_branch, - cloned_env, - step.next_agent_state, - step.next_observation, - max_depth, + await self._descend( + tree=tree, + prev_step_id=step_id, + env=cloned_env, + agent_state=step.next_agent_state, + obs=step.next_observation, + prev_timestep=timestep, + prev_cumulative_reward=cumulative_reward, + max_depth=max_depth, ) - # Add branching_factory children - branches = await asyncio.gather(*[ + # Add branching_factor children + await asyncio.gather(*[ inner_descend(idx) for idx in range(self.branching_factor) ]) - - return list(itertools.chain.from_iterable(branches)) diff --git a/ldp/data_structures.py b/ldp/data_structures.py index 8273e257..d96a5670 100644 --- a/ldp/data_structures.py +++ b/ldp/data_structures.py @@ -3,8 +3,9 @@ import json import logging import os -from typing import Any, ClassVar, Self +from typing import Any, ClassVar, Self, cast +import networkx as nx from aviary.message import Message from aviary.tools import ToolRequestMessage, ToolResponseMessage from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator @@ -128,3 +129,142 @@ def compute_discounted_returns(self, discount: float = 1.0) -> list[float]: terminated=[step.truncated for step in self.steps], discount=discount, ) + + +class TransitionTree: + def __init__(self, root_id: str): + """A tree of transitions. + + If A->B is an edge in this tree, then A and B are consecutive + transitions in an LDP. Any path from the root node to a terminal + node constitutes a complete LDP. + + Args: + root_id: A unique identifier for the root node of the tree. + All IDs of transitions added to this tree must begin with + the same root_id. + """ + self.root_id = root_id + + self.tree = nx.DiGraph() # the actual tree + self.rev_tree = nx.DiGraph() # the same as self.tree, but with reversed edges + + self._add_node(root_id, transition=None) + + def _add_node(self, step_id: str, transition: Transition | None): + self.tree.add_node(step_id, transition=transition) + self.rev_tree.add_node(step_id) + + def _add_edge(self, parent_step_id: str, child_step_id: str): + self.tree.add_edge(parent_step_id, child_step_id) + self.rev_tree.add_edge(child_step_id, parent_step_id) + + def get_transition(self, step_id: str) -> Transition: + if step_id == self.root_id: + raise ValueError("Root node has no transition.") + + return cast(Transition, self.tree.nodes[step_id]["transition"]) + + def add_transition(self, step_id: str, step: Transition): + """Add a transition to the tree. + + Args: + step_id: A unique identifier for the root node of the tree. + The expected form of the step ID is "{parent step ID}:{step index}". + step: The transition to add. + """ + root_id, *step_ids = step_id.split(":") + assert ( + root_id == self.root_id + ), f"Step ID {step_id} does not start with root ID {self.root_id}" + assert step_ids, "Step ID cannot be the same as the root ID." + # TODO: maybe this should be warning? + assert ( + step_id not in self.tree + ), f"Step ID {step_id} already exists in the tree." + + self._add_node(step_id, transition=step) + + parent_id = ":".join([root_id, *step_ids[:-1]]) + if parent_id in self.tree: + self._add_edge(parent_id, step_id) + + def get_trajectories(self) -> list[Trajectory]: + """Return a list of trajectories. + + Since each path from the root node to a terminal node defines + a unique trajectory, N(terminal node) trajectories will be returned. + The trajectory ID will be set to the ID of the terminal step. + + Note that we include failed and truncated trajectories; it is up to the + caller to decide what to do them. + + Returns: + All trajectories in this tree. + """ + trajs = [] + step: Transition | None + + for step_id, step in self.tree.nodes(data="transition"): + if not step: + # root node + continue + + is_terminal = ( + # check terminal conditions in increasing order of expense + step.done + or step.truncated + or step.failed + or self.tree.out_degree(step_id) == 0 + ) + + if not is_terminal: + continue + + # set the ID to the terminal node, which uniquely identifies the path + traj = Trajectory(traj_id=step_id) + # Build the trajectory up from a terminal node + current_step: Transition | None = step + current_step_id = step_id + + # Walk backwards towards the root (current_step=None) + while current_step: + traj.steps.append(current_step) + + parent_step_id, *extra = list(self.rev_tree.successors(current_step_id)) + assert not extra, f"Expected a single parent, but got {len(extra) + 1}" + + current_step_id = parent_step_id + current_step = self.tree.nodes[parent_step_id]["transition"] + + # would've added things in reverse order, so fix that here + traj.steps.sort(key=lambda x: x.timestep) + trajs.append(traj) + + return trajs + + def assign_mc_value_estimates(self, discount_factor: float = 1.0): + """Assign Monte Carlo state-action value estimates to each transition (in-place). + + Args: + discount_factor: The discount factor to use when computing cumulative + future rewards. + """ + for step_id in nx.topological_sort(self.rev_tree): + step: Transition | None = self.tree.nodes[step_id]["transition"] + if step is None: + continue + + if children := list(self.tree.successors(step_id)): + # V_{t+1}(s') = sum_{a'} p(a'|s') * Q_{t+1}(s', a') + # Here we assume p(a'|s') is uniform. + # TODO: don't make that assumption where a logprob is available + v_tp1 = sum( + self.get_transition(child_id).value for child_id in children + ) / len(children) + else: + v_tp1 = 0.0 + + # Q_t(s_t, a_t) = r_{t+1} + gamma * V_{t+1}(s_{t+1}) + # (we are assuming the environment is deterministic) + step.value = step.reward + discount_factor * v_tp1 diff --git a/tests/test_rollouts.py b/tests/test_rollouts.py index c3ae628f..59c5189d 100644 --- a/tests/test_rollouts.py +++ b/tests/test_rollouts.py @@ -1,3 +1,4 @@ +import itertools import random import tempfile from copy import deepcopy @@ -14,7 +15,7 @@ from ldp.alg.callbacks import Callback from ldp.alg.rollout import RolloutManager from ldp.alg.tree_search import TreeSearchRollout -from ldp.data_structures import Trajectory, Transition +from ldp.data_structures import Trajectory, Transition, TransitionTree from ldp.graph.common_ops import FxnOp from ldp.graph.op_utils import compute_graph, set_training_mode from ldp.graph.ops import OpResult @@ -240,9 +241,16 @@ async def test_tree_search(): concurrency_limit=1, callbacks=[callback], ) - trajs = await rollout_manager.sample_tree(env, max_depth=3) + tree = await rollout_manager.sample_tree(env, max_depth=3) + trajs = tree.get_trajectories() assert len(trajs) == 8 + traj_ids_wo_root = { + cast(str, traj.traj_id).replace(tree.root_id, "").lstrip(":") for traj in trajs + } + # IDs should be 0:0:0, 0:0:1, ... 1:1:1 (order doesn't matter) + assert traj_ids_wo_root == {":".join(x) for x in itertools.product("01", repeat=3)} + observations = {} # type: ignore[var-annotated] for traj in trajs: branch_path = tuple(cast(str, traj.traj_id).split(":")[1:]) @@ -266,3 +274,46 @@ async def test_tree_search(): # - branching factor = 2, depth = 3 # - root node isn't sampled, so no i=0 term in sum assert all(v == 14 for v in callback.fn_invocations.values()) + + +def test_tree_mc_value(): + root_id = "dummy" + tree = TransitionTree(root_id=root_id) + + kw = { + "agent_state": None, + "next_agent_state": None, + "observation": Transition.NO_OBSERVATION, + "next_observation": Transition.NO_OBSERVATION, + "action": None, + } + + # Construct a tree with some rewards scattered about + tree.add_transition(f"{root_id}:0", Transition(timestep=0, reward=0.0, **kw)) + + tree.add_transition(f"{root_id}:0:0", Transition(timestep=1, reward=1.0, **kw)) + for i in range(3): + tree.add_transition( + f"{root_id}:0:0:{i}", + Transition(timestep=2, reward=float(i), done=True, **kw), + ) + + tree.add_transition( + f"{root_id}:0:1", Transition(timestep=1, reward=-1.0, done=True, **kw) + ) + + tree.assign_mc_value_estimates(discount_factor=0.9) + + # Now make sure the value estimates are as expected + # First, check the terminal nodes: Q==reward + for i in range(3): + assert tree.get_transition(f"{root_id}:0:0:{i}").value == float(i) + assert tree.get_transition(f"{root_id}:0:1").value == -1.0 + + # Then go up the tree + assert tree.get_transition(f"{root_id}:0:0").value == pytest.approx( + 1.9, rel=0.001 + ) # 1 + 0.9 * avg(0, 1, 2) + assert tree.get_transition(f"{root_id}:0").value == pytest.approx( + 0.405, rel=0.001 + ) # 0 + 0.9 * avg(1.9, -1)