Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tree search refactor #18

Merged
merged 3 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 56 additions & 45 deletions ldp/alg/tree_search.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import asyncio
import itertools
import logging
import uuid
from collections.abc import Awaitable, Callable, Sequence
from typing import Any, cast
from typing import Any

from aviary.message import Message
from aviary.utils import is_coroutine_callable
Expand All @@ -18,7 +17,7 @@
TEnv,
reraise_exc_as,
)
from ldp.data_structures import Trajectory
from ldp.data_structures import TransitionTree

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -55,93 +54,105 @@ async def sample_trees(
self,
environments: Sequence[TEnv],
max_depth: int | None = None,
) -> list[list[Trajectory]]:
) -> list[TransitionTree]:
return await asyncio.gather(*[
self.sample_tree(env, max_depth) for env in environments
])

async def sample_tree(self, env: TEnv, max_depth: int | None) -> list[Trajectory]:
async def sample_tree(self, env: TEnv, max_depth: int | None) -> TransitionTree:
max_depth_f = max_depth if max_depth is not None else float("inf")
tree = TransitionTree(root_id=str(uuid.uuid4()))

try:
with reraise_exc_as(EnvError, enabled=self.catch_env_failures):
obs, tools = await env.reset()

with reraise_exc_as(AgentError, enabled=self.catch_agent_failures):
agent_state = await self.agent.init_state(tools)

root_traj = Trajectory(traj_id=str(uuid.uuid4()))
return await self._descend(root_traj, env, agent_state, obs, max_depth_f)

except CaughtError:
return []
return tree

await self._descend(
tree=tree,
prev_step_id=tree.root_id,
env=env,
agent_state=agent_state,
obs=obs,
prev_timestep=-1,
prev_cumulative_reward=0.0,
max_depth=max_depth_f,
)

return tree

async def _descend(
self,
branch: Trajectory,
tree: TransitionTree,
prev_step_id: str,
env: TEnv,
agent_state: Any,
obs: list[Message],
prev_timestep: int,
prev_cumulative_reward: float,
max_depth: float,
) -> list[Trajectory]:
) -> None:
# Descend one level in the tree, by adding branching_factor children to the branch
# Then, recurse on each child
root_traj_id = cast(str, branch.traj_id).split(":")[0]
if root_traj_id in self.target_reward_hit:
return [branch]

timestep = len(branch.steps)
if tree.root_id in self.target_reward_hit:
# If at least one branch hit the target reward, stop descending
return

async def inner_descend(idx: int) -> list[Trajectory]:
timestep = prev_timestep + 1

async def inner_descend(idx: int) -> None:
if is_coroutine_callable(self.env_clone_fn):
cloned_env = await self.env_clone_fn(env) # type: ignore[arg-type, misc]
else:
cloned_env = self.env_clone_fn(env) # type: ignore[arg-type]

# Descend one step
traj_id = f"{branch.traj_id}:{idx}"
step_id = f"{prev_step_id}:{idx}"
try:
step = await self._take_step(
timestep, traj_id, cloned_env, agent_state, obs
timestep, step_id, cloned_env, agent_state, obs
)
except CaughtError:
# If we failed, do not extend the branch - just return an empty list
return []
# If we failed, do not extend the branch - just give up on this path
return

await asyncio.gather(*[
callback.after_transition(traj_id, self.agent, cloned_env, step)
callback.after_transition(step_id, self.agent, cloned_env, step)
for callback in self.callbacks
])

# The original branch plus one step
extended_branch = Trajectory(traj_id=traj_id, steps=[*branch.steps, step])
tree.add_transition(step_id, step)

if step.done:
return

if (
step.done # Trajectory is over
or len(extended_branch.steps) >= max_depth # Hit max depth
):
return [extended_branch]
if timestep + 1 >= max_depth:
step.truncated = True
return

if (
sum(step_.reward for step_ in extended_branch.steps)
>= self.target_reward
):
cumulative_reward = prev_cumulative_reward + step.reward
if cumulative_reward >= self.target_reward:
# signal other descents to stop too
self.target_reward_hit.add(root_traj_id)
return [extended_branch]
self.target_reward_hit.add(tree.root_id)

# Recurse
return await self._descend(
extended_branch,
cloned_env,
step.next_agent_state,
step.next_observation,
max_depth,
await self._descend(
tree=tree,
prev_step_id=step_id,
env=cloned_env,
agent_state=step.next_agent_state,
obs=step.next_observation,
prev_timestep=timestep,
prev_cumulative_reward=cumulative_reward,
max_depth=max_depth,
)

# Add branching_factory children
branches = await asyncio.gather(*[
# Add branching_factor children
await asyncio.gather(*[
inner_descend(idx) for idx in range(self.branching_factor)
])

return list(itertools.chain.from_iterable(branches))
143 changes: 142 additions & 1 deletion ldp/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import json
import logging
import os
from typing import Any, ClassVar, Self
from typing import Any, ClassVar, Self, cast
from uuid import UUID

import networkx as nx
from aviary.message import Message
from aviary.tools import ToolRequestMessage, ToolResponseMessage
from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
Expand Down Expand Up @@ -128,3 +130,142 @@ def compute_discounted_returns(self, discount: float = 1.0) -> list[float]:
terminated=[step.truncated for step in self.steps],
discount=discount,
)


class TransitionTree:
def __init__(self, root_id: str | UUID):
"""A tree of transitions.

If A->B is an edge in this tree, then A and B are consecutive
transitions in an LDP. Any path from the root node to a terminal
node constitutes a complete LDP.

Args:
root_id: A unique identifier for the root node of the tree.
All IDs of transitions added to this tree must begin with
the same root_id.
"""
self.root_id = str(root_id)

self.tree = nx.DiGraph() # the actual tree
self.rev_tree = nx.DiGraph() # the same as self.tree, but with reversed edges

self._add_node(self.root_id, transition=None)

def _add_node(self, step_id: str, transition: Transition | None) -> None:
self.tree.add_node(step_id, transition=transition)
self.rev_tree.add_node(step_id)

def _add_edge(self, parent_step_id: str, child_step_id: str) -> None:
self.tree.add_edge(parent_step_id, child_step_id)
self.rev_tree.add_edge(child_step_id, parent_step_id)

def get_transition(self, step_id: str) -> Transition:
if step_id == self.root_id:
raise ValueError("Root node has no transition.")

return cast(Transition, self.tree.nodes[step_id]["transition"])

def add_transition(self, step_id: str, step: Transition) -> None:
"""Add a transition to the tree.

Args:
step_id: A unique identifier for the root node of the tree.
The expected form of the step ID is "{parent step ID}:{step index}".
step: The transition to add.
"""
root_id, *step_ids = step_id.split(":")
assert (
root_id == self.root_id
), f"Step ID {step_id} does not start with root ID {self.root_id}"
assert step_ids, "Step ID cannot be the same as the root ID."
# TODO: maybe this should be warning?
assert (
step_id not in self.tree
), f"Step ID {step_id} already exists in the tree."

self._add_node(step_id, transition=step)

parent_id = ":".join([root_id, *step_ids[:-1]])
if parent_id in self.tree:
self._add_edge(parent_id, step_id)

def get_trajectories(self) -> list[Trajectory]:
"""Return a list of trajectories.

Since each path from the root node to a terminal node defines
a unique trajectory, N(terminal node) trajectories will be returned.
The trajectory ID will be set to the ID of the terminal step.

Note that we include failed and truncated trajectories; it is up to the
caller to decide what to do them.

Returns:
All trajectories in this tree.
"""
trajs = []
step: Transition | None

for step_id, step in self.tree.nodes(data="transition"):
if not step:
# root node
continue

is_terminal = (
# check terminal conditions in increasing order of expense
step.done
or step.truncated
or step.failed
or self.tree.out_degree(step_id) == 0
)

if not is_terminal:
continue

# set the ID to the terminal node, which uniquely identifies the path
traj = Trajectory(traj_id=step_id)
# Build the trajectory up from a terminal node
current_step: Transition | None = step
current_step_id = step_id

# Walk backwards towards the root (current_step=None)
while current_step:
traj.steps.append(current_step)

parent_step_id, *extra = list(self.rev_tree.successors(current_step_id))
assert not extra, f"Expected a single parent, but got {len(extra) + 1}"

current_step_id = parent_step_id
current_step = self.tree.nodes[parent_step_id]["transition"]

# would've added things in reverse order, so fix that here
traj.steps.sort(key=lambda x: x.timestep)
trajs.append(traj)

return trajs

def assign_mc_value_estimates(self, discount_factor: float = 1.0) -> None:
"""Assign Monte Carlo state-action value estimates to each transition (in-place).

Args:
discount_factor: The discount factor to use when computing cumulative
future rewards.
"""
for step_id in nx.topological_sort(self.rev_tree):
step: Transition | None = self.tree.nodes[step_id]["transition"]
if step is None:
continue

if children := list(self.tree.successors(step_id)):
# V_{t+1}(s') = sum_{a'} p(a'|s') * Q_{t+1}(s', a')
# Here we assume p(a'|s') is uniform.
# TODO: don't make that assumption where a logprob is available
v_tp1 = sum(
self.get_transition(child_id).value for child_id in children
) / len(children)
else:
v_tp1 = 0.0

# Q_t(s_t, a_t) = r_{t+1} + gamma * V_{t+1}(s_{t+1})
# (we are assuming the environment is deterministic)
step.value = step.reward + discount_factor * v_tp1
Loading
Loading