Skip to content

Commit

Permalink
feature(nyz): add eval policy output viz
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Nov 3, 2022
1 parent ef02fe7 commit 7f8c53e
Showing 1 changed file with 36 additions and 14 deletions.
50 changes: 36 additions & 14 deletions ding/framework/middleware/functional/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ def __init__(self, env_num: int, n_episode: int) -> None:
env_id: deque([[] for _ in range(maxlen)], maxlen=maxlen)
for env_id, maxlen in enumerate(each_env_episode)
}
self._output = {
env_id: deque([[] for _ in range(maxlen)], maxlen=maxlen)
for env_id, maxlen in enumerate(each_env_episode)
}

def is_finished(self) -> bool:
"""
Expand Down Expand Up @@ -154,6 +158,22 @@ def get_episode_info(self) -> dict:
pass
return new_dict

def _select_idx(self):
reward = [t.item() for t in self.get_episode_reward()]
sortarg = np.argsort(reward)
# worst, median(s), best
if len(sortarg) == 1:
idxs = [sortarg[0]]
elif len(sortarg) == 2:
idxs = [sortarg[0], sortarg[-1]]
elif len(sortarg) == 3:
idxs = [sortarg[0], sortarg[len(sortarg) // 2], sortarg[-1]]
else:
# TensorboardX pad the number of videos to even numbers with black frames,
# therefore providing even number of videos prevents black frames being rendered.
idxs = [sortarg[0], sortarg[len(sortarg) // 2 - 1], sortarg[len(sortarg) // 2], sortarg[-1]]
return idxs

def update_video(self, imgs):
for env_id, img in imgs.items():
if len(self._reward[env_id]) == self._reward[env_id].maxlen:
Expand All @@ -168,19 +188,7 @@ def get_episode_video(self):
"""
videos = sum([list(v) for v in self._video.values()], [])
videos = [np.transpose(np.stack(video, 0), [0, 3, 1, 2]) for video in videos]
reward = [t.item() for t in self.get_episode_reward()]
sortarg = np.argsort(reward)
# worst, median(s), best
if len(sortarg) == 1:
idxs = [sortarg[0]]
elif len(sortarg) == 2:
idxs = [sortarg[0], sortarg[-1]]
elif len(sortarg) == 3:
idxs = [sortarg[0], sortarg[len(sortarg) // 2], sortarg[-1]]
else:
# TensorboardX pad the number of videos to even numbers with black frames,
# therefore providing even number of videos prevents black frames being rendered.
idxs = [sortarg[0], sortarg[len(sortarg) // 2 - 1], sortarg[len(sortarg) // 2], sortarg[-1]]
idxs = self._select_idx()
videos = [videos[idx] for idx in idxs]
# pad videos to the same length with last frames
max_length = max(video.shape[0] for video in videos)
Expand All @@ -192,6 +200,18 @@ def get_episode_video(self):
assert len(videos.shape) == 5, 'Need [N, T, C, H, W] input tensor for video logging!'
return videos

def update_output(self, output):
for env_id, o in output.items():
if len(self._reward[env_id]) == self._reward[env_id].maxlen:
continue
self._output[env_id][len(self._reward[env_id])].append(to_ndarray(o))

def get_episode_output(self):
output = sum([list(v) for v in self._output.values()], [])
idxs = self._select_idx()
output = [output[idx] for idx in idxs]
return output


def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager, render: bool = False) -> Callable:
"""
Expand Down Expand Up @@ -232,9 +252,10 @@ def _evaluate(ctx: "OnlineRLContext"):
while not eval_monitor.is_finished():
obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32)
obs = {i: obs[i] for i in range(obs.shape[0])} # TBD
inference_output = policy.forward(obs)
if render:
eval_monitor.update_video(env.ready_imgs)
inference_output = policy.forward(obs)
eval_monitor.update_output(inference_output)
output = [v for v in inference_output.values()]
action = [to_ndarray(v['action']) for v in output] # TBD
timesteps = env.step(action)
Expand Down Expand Up @@ -265,6 +286,7 @@ def _evaluate(ctx: "OnlineRLContext"):
ctx.eval_output['episode_info'] = episode_info
if render:
ctx.eval_output['replay_video'] = eval_monitor.get_episode_video()
ctx.eval_output['output'] = eval_monitor.get_episode_output()

if stop_flag:
task.finish = True
Expand Down

0 comments on commit 7f8c53e

Please sign in to comment.