Skip to content

Commit

Permalink
Polish code.
Browse files Browse the repository at this point in the history
  • Loading branch information
zjowowen committed Dec 8, 2022
1 parent 71d1014 commit 264a321
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 16 deletions.
6 changes: 3 additions & 3 deletions ding/framework/middleware/functional/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def update_reward(self, env_id: Union[int, np.ndarray], reward: Any) -> None:
def get_episode_return(self) -> list:
    """
    Overview:
        Flatten the per-env reward buffers stored in ``self._reward`` into one
        flat list of episode returns (one entry per finished episode, in the
        dict's insertion order).
    Returns:
        - episode_return (:obj:`list`): Concatenation of every value sequence \
          held in ``self._reward``. Element type is whatever was buffered \
          (callers treat entries as tensors via ``.item()`` — see ``_select_idx``).
    """
    # A single flat comprehension replaces the former
    # ``sum([list(v) for v in self._reward.values()], [])``, which builds a new
    # list on every `+` step and is quadratic in the total number of episodes.
    return [rew for env_rewards in self._reward.values() for rew in env_rewards]

Expand All @@ -139,7 +139,7 @@ def get_current_episode(self) -> int:
def get_episode_info(self) -> dict:
"""
Overview:
Get all episode information, such as total reward of one episode.
Get all episode information, such as total return of one episode.
"""
if len(self._info[0]) == 0:
return None
Expand All @@ -159,7 +159,7 @@ def get_episode_info(self) -> dict:
return new_dict

def _select_idx(self):
reward = [t.item() for t in self.get_episode_reward()]
reward = [t.item() for t in self.get_episode_return()]
sortarg = np.argsort(reward)
# worst, median(s), best
if len(sortarg) == 1:
Expand Down
14 changes: 7 additions & 7 deletions ding/worker/collector/base_serial_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ def should_eval(self, train_iter: int) -> bool:

@abstractmethod
def eval(
        self,
        save_ckpt_fn: Callable = None,
        train_iter: int = -1,
        envstep: int = -1,
        n_episode: Optional[int] = None
) -> Any:
    """
    Overview:
        Abstract evaluation entry point; concrete serial evaluators must override it.
    Arguments:
        - save_ckpt_fn (:obj:`Callable`): Optional callback for saving a checkpoint \
          during evaluation — presumably invoked when a new best result is reached; \
          confirm against concrete subclasses.
        - train_iter (:obj:`int`): Current training iteration, ``-1`` when not provided.
        - envstep (:obj:`int`): Current collected environment step count, ``-1`` when \
          not provided.
        - n_episode (:obj:`Optional[int]`): Number of episodes to evaluate; ``None`` \
          means the subclass should fall back to its configured default.
    Raises:
        - NotImplementedError: Always, in this abstract base implementation.
    """
    raise NotImplementedError

Expand Down Expand Up @@ -182,7 +182,7 @@ def get_video(self):
def get_episode_return(self) -> list:
    """
    Overview:
        Collect every recorded episode return by walking the value sequences
        stored in ``self._reward`` and appending them, in dict insertion order,
        to a single flat list.
    Returns:
        - episode_return (:obj:`list`): All buffered episode returns in one list.
    """
    flattened = []
    for env_rewards in self._reward.values():
        flattened.extend(env_rewards)
    return flattened

Expand All @@ -205,7 +205,7 @@ def get_current_episode(self) -> int:
def get_episode_info(self) -> dict:
"""
Overview:
Get all episode information, such as total reward of one episode.
Get all episode information, such as total return of one episode.
"""
if len(self._info[0]) == 0:
return None
Expand Down
2 changes: 1 addition & 1 deletion dizoo/beergame/config/beergame_onppo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
gae_lambda=0.95,
collector=dict(
get_train_sample=True,
reward_shaping=True, # whether use total reward to reshape reward
reward_shaping=True, # whether use total return to reshape reward
),
),
eval=dict(evaluator=dict(eval_freq=500, )),
Expand Down
8 changes: 4 additions & 4 deletions dizoo/slime_volley/envs/test_slime_volley_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class TestSlimeVolley:

@pytest.mark.parametrize('agent_vs_agent', [True, False])
def test_slime_volley(self, agent_vs_agent):
total_rew = 0
total_return = 0
env = SlimeVolleyEnv(EasyDict({'env_id': 'SlimeVolley-v0', 'agent_vs_agent': agent_vs_agent}))
# env.enable_save_replay('replay_video')
obs1 = env.reset()
Expand All @@ -21,13 +21,13 @@ def test_slime_volley(self, agent_vs_agent):
action = env.random_action()
observations, rewards, done, infos = env.step(action)
if agent_vs_agent:
total_rew += rewards[0]
total_return += rewards[0]
else:
total_rew += rewards
total_return += rewards
obs1, obs2 = observations[0], observations[1]
assert obs1.shape == obs2.shape, (obs1.shape, obs2.shape)
if agent_vs_agent:
agent_lives, opponent_lives = infos[0]['ale.lives'], infos[1]['ale.lives']
if agent_vs_agent:
assert agent_lives == 0 or opponent_lives == 0, (agent_lives, opponent_lives)
print("total reward is:", total_rew)
print("total return is:", total_return)
2 changes: 1 addition & 1 deletion dizoo/smac/envs/test_smac_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def main(policy, map_name="3m", two_player=False):
draw += int(infos["draw"])

print(
"Total reward in episode {} = {} (me), {} (opponent). Me win {}, Draw {}, Opponent win {}, total {}."
"Total return in episode {} = {} (me), {} (opponent). Me win {}, Draw {}, Opponent win {}, total {}."
"".format(e, episode_return_me, episode_return_op, me_win, draw, op_win, e + 1)
)

Expand Down

0 comments on commit 264a321

Please sign in to comment.