Skip to content

Commit

Permalink
pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sidnarayanan committed Sep 9, 2024
1 parent f73fbf5 commit c85ec23
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions ldp/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
from typing import Any, ClassVar, Self, cast
from uuid import UUID

import networkx as nx
from aviary.message import Message
Expand Down Expand Up @@ -132,7 +133,7 @@ def compute_discounted_returns(self, discount: float = 1.0) -> list[float]:


class TransitionTree:
def __init__(self, root_id: str):
def __init__(self, root_id: str | UUID):
"""A tree of transitions.
If A->B is an edge in this tree, then A and B are consecutive
Expand All @@ -144,18 +145,18 @@ def __init__(self, root_id: str):
All IDs of transitions added to this tree must begin with
the same root_id.
"""
self.root_id = root_id
self.root_id = str(root_id)

self.tree = nx.DiGraph() # the actual tree
self.rev_tree = nx.DiGraph() # the same as self.tree, but with reversed edges

self._add_node(root_id, transition=None)
self._add_node(self.root_id, transition=None)

def _add_node(self, step_id: str, transition: Transition | None):
def _add_node(self, step_id: str, transition: Transition | None) -> None:
self.tree.add_node(step_id, transition=transition)
self.rev_tree.add_node(step_id)

def _add_edge(self, parent_step_id: str, child_step_id: str):
def _add_edge(self, parent_step_id: str, child_step_id: str) -> None:
self.tree.add_edge(parent_step_id, child_step_id)
self.rev_tree.add_edge(child_step_id, parent_step_id)

Expand All @@ -165,7 +166,7 @@ def get_transition(self, step_id: str) -> Transition:

return cast(Transition, self.tree.nodes[step_id]["transition"])

def add_transition(self, step_id: str, step: Transition):
def add_transition(self, step_id: str, step: Transition) -> None:
"""Add a transition to the tree.
Args:
Expand Down Expand Up @@ -243,7 +244,7 @@ def get_trajectories(self) -> list[Trajectory]:

return trajs

def assign_mc_value_estimates(self, discount_factor: float = 1.0):
def assign_mc_value_estimates(self, discount_factor: float = 1.0) -> None:
"""Assign Monte Carlo state-action value estimates to each transition (in-place).
Args:
Expand Down

0 comments on commit c85ec23

Please sign in to comment.