Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: sven1977 <svenmika1977@gmail.com>
  • Loading branch information
sven1977 committed Aug 28, 2024
1 parent f92a620 commit d931a86
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion rllib/utils/metrics/metrics_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,7 +898,7 @@ def set_state(self, state: Dict[str, Any]) -> None:
for flat_key, stats_state in state["stats"].items():
self._set_key(flat_key, Stats.from_state(stats_state))

def _check_tensor(self, key, value) -> None:
def _check_tensor(self, key: Tuple[str], value) -> None:
# `value` is a tensor -> Log it in our keys set.
if self.tensor_mode and (
(torch and torch.is_tensor(value)) or (tf and tf.is_tensor(value))
Expand Down Expand Up @@ -938,6 +938,12 @@ def _set_key(self, flat_key, stats):

def _del_key(self, flat_key, key_error=False):
flat_key = force_tuple(tree.flatten(flat_key))

# Erase the tensor key as well, if applicable.
if flat_key in self._tensor_keys:
self._tensor_keys.discard(flat_key)

# Erase the key from the (nested) `self.stats` dict.
_dict = self.stats
try:
for i, key in enumerate(flat_key):
Expand Down

0 comments on commit d931a86

Please sign in to comment.