Skip to content

Commit

Permalink
fix env_sampler eval info list issue
Browse files Browse the repository at this point in the history
  • Loading branch information
Jinyu Wang authored and Jinyu Wang committed Oct 27, 2023
1 parent 607d3b6 commit c19038a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions maro/rl/rollout/env_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def sample(

return {
"experiences": [total_experiences],
"info": [deepcopy(self._info)], # TODO: may have overhead issues. Leave to future work.
"info": [deepcopy(self._info)],
}

def set_policy_state(self, policy_state_dict: Dict[str, dict]) -> None:
Expand Down Expand Up @@ -592,7 +592,7 @@ def eval(self, policy_state: Dict[str, Dict[str, Any]] = None, num_episodes: int
self._step(list(env_action_dict.values()))
cache_element.next_state = self._state

if self._reward_eval_delay is None: # TODO: necessary to calculate reward in eval()?
if self._reward_eval_delay is None:
self._calc_reward(cache_element)
self._post_eval_step(cache_element)

Expand All @@ -606,7 +606,7 @@ def eval(self, policy_state: Dict[str, Dict[str, Any]] = None, num_episodes: int
self._calc_reward(cache_element)
self._post_eval_step(cache_element)

info_list.append(self._info)
info_list.append(deepcopy(self._info))

return {"info": info_list}

Expand Down

0 comments on commit c19038a

Please sign in to comment.