From c85ec23d6b1e703d1d53658e7c7007fb10da3dcb Mon Sep 17 00:00:00 2001 From: Siddharth Narayanan Date: Mon, 9 Sep 2024 15:43:44 -0700 Subject: [PATCH] pr comments --- ldp/data_structures.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/ldp/data_structures.py b/ldp/data_structures.py index d96a5670..8eb344d2 100644 --- a/ldp/data_structures.py +++ b/ldp/data_structures.py @@ -4,6 +4,7 @@ import logging import os from typing import Any, ClassVar, Self, cast +from uuid import UUID import networkx as nx from aviary.message import Message @@ -132,7 +133,7 @@ def compute_discounted_returns(self, discount: float = 1.0) -> list[float]: class TransitionTree: - def __init__(self, root_id: str): + def __init__(self, root_id: str | UUID): """A tree of transitions. If A->B is an edge in this tree, then A and B are consecutive @@ -144,18 +145,18 @@ def __init__(self, root_id: str): All IDs of transitions added to this tree must begin with the same root_id. """ - self.root_id = root_id + self.root_id = str(root_id) self.tree = nx.DiGraph() # the actual tree self.rev_tree = nx.DiGraph() # the same as self.tree, but with reversed edges - self._add_node(root_id, transition=None) + self._add_node(self.root_id, transition=None) - def _add_node(self, step_id: str, transition: Transition | None): + def _add_node(self, step_id: str, transition: Transition | None) -> None: self.tree.add_node(step_id, transition=transition) self.rev_tree.add_node(step_id) - def _add_edge(self, parent_step_id: str, child_step_id: str): + def _add_edge(self, parent_step_id: str, child_step_id: str) -> None: self.tree.add_edge(parent_step_id, child_step_id) self.rev_tree.add_edge(child_step_id, parent_step_id) @@ -165,7 +166,7 @@ def get_transition(self, step_id: str) -> Transition: return cast(Transition, self.tree.nodes[step_id]["transition"]) - def add_transition(self, step_id: str, step: Transition): + def add_transition(self, step_id: str, step: Transition) -> None: """Add a transition to the tree. Args: @@ -243,7 +244,7 @@ def get_trajectories(self) -> list[Trajectory]: return trajs - def assign_mc_value_estimates(self, discount_factor: float = 1.0): + def assign_mc_value_estimates(self, discount_factor: float = 1.0) -> None: """Assign Monte Carlo state-action value estimates to each transition (in-place). Args: