feature(nyz): add evaluator more info viz support
PaParaZz1 committed Nov 3, 2022
1 parent 40c9bca commit ef02fe7
Showing 1 changed file with 66 additions and 9 deletions.
75 changes: 66 additions & 9 deletions ding/framework/middleware/functional/evaluator.py
@@ -4,14 +4,15 @@
from ditk import logging
import numpy as np
import torch
import treetensor.numpy as tnp
import treetensor.torch as ttorch
from easydict import EasyDict
from ding.envs import BaseEnvManager
from ding.framework.context import OfflineRLContext
from ding.policy import Policy
from ding.data import Dataset, DataLoader
from ding.framework import task
from ding.torch_utils import tensor_to_list, to_ndarray
from ding.torch_utils import to_list, to_ndarray
from ding.utils import lists_to_dicts

if TYPE_CHECKING:
@@ -70,6 +71,10 @@ def __init__(self, env_num: int, n_episode: int) -> None:
each_env_episode[i] += 1
self._reward = {env_id: deque(maxlen=maxlen) for env_id, maxlen in enumerate(each_env_episode)}
self._info = {env_id: deque(maxlen=maxlen) for env_id, maxlen in enumerate(each_env_episode)}
self._video = {
env_id: deque([[] for _ in range(maxlen)], maxlen=maxlen)
for env_id, maxlen in enumerate(each_env_episode)
}

def is_finished(self) -> bool:
"""
@@ -88,7 +93,6 @@ def update_info(self, env_id: int, info: Any) -> None:
- env_id (:obj:`int`): the id of the environment whose information should be updated
- info (:obj:`Any`): the information to record
"""
info = tensor_to_list(info)
self._info[env_id].append(info)

def update_reward(self, env_id: Union[int, np.ndarray], reward: Any) -> None:
@@ -136,24 +140,68 @@ def get_episode_info(self) -> dict:
if len(self._info[0]) == 0:
return None
else:
# concatenate episode info from all envs
total_info = sum([list(v) for v in self._info.values()], [])
if isinstance(total_info[0], tnp.ndarray):
total_info = [t.json() for t in total_info]
total_info = lists_to_dicts(total_info)
new_dict = {}
for k in total_info.keys():
if np.isscalar(total_info[k][0]):
new_dict[k + '_mean'] = np.mean(total_info[k])
total_info.update(new_dict)
return total_info


def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager) -> Callable:
try:
if np.isscalar(total_info[k][0].item()):
new_dict[k + '_mean'] = np.mean(total_info[k])
except: # noqa
pass
return new_dict
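For context, a minimal sketch of the scalar-mean aggregation that the reworked get_episode_info performs, using made-up field names and plain floats in place of tensor values (the real code regroups with ding.utils.lists_to_dicts and guards non-scalar fields with the try/except above):

import numpy as np

# Hypothetical per-episode info dicts from three finished evaluation episodes.
episode_infos = [
    {'success': 1.0, 'episode_length': 200},
    {'success': 0.0, 'episode_length': 120},
    {'success': 1.0, 'episode_length': 180},
]
# lists_to_dicts-style regrouping: list of dicts -> dict of lists.
total_info = {k: [info[k] for info in episode_infos] for k in episode_infos[0]}
# Report the mean of every scalar field under a '<key>_mean' name.
new_dict = {k + '_mean': np.mean(v) for k, v in total_info.items() if np.isscalar(v[0])}
# -> {'success_mean': 0.666..., 'episode_length_mean': 166.666...}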

def update_video(self, imgs):
for env_id, img in imgs.items():
if len(self._reward[env_id]) == self._reward[env_id].maxlen:
continue
self._video[env_id][len(self._reward[env_id])].append(img)
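The frame routing above relies on the fact that the number of finished episodes for an env is also the index of the episode it is currently playing. A small self-contained trace of that bookkeeping, with made-up env count, episode count, and frame shape:

import numpy as np
from collections import deque

# Mirror of the monitor's containers: 2 envs, 1 evaluation episode each (made-up sizes).
_reward = {0: deque(maxlen=1), 1: deque(maxlen=1)}
_video = {0: deque([[]], maxlen=1), 1: deque([[]], maxlen=1)}

frame = np.zeros((64, 64, 3), dtype=np.uint8)  # one rendered [H, W, C] frame
for env_id in (0, 1):
    if len(_reward[env_id]) < _reward[env_id].maxlen:  # env still has episodes to run
        # index = number of finished episodes = episode currently being played
        _video[env_id][len(_reward[env_id])].append(frame)

assert len(_video[0][0]) == 1 and len(_video[1][0]) == 1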

def get_episode_video(self):
"""
Overview:
Convert the recorded videos into an [N, T, C, H, W] tensor containing the
worst, median, and best evaluation trajectories for video logging.
"""
videos = sum([list(v) for v in self._video.values()], [])
videos = [np.transpose(np.stack(video, 0), [0, 3, 1, 2]) for video in videos]
reward = [t.item() for t in self.get_episode_reward()]
sortarg = np.argsort(reward)
# worst, median(s), best
if len(sortarg) == 1:
idxs = [sortarg[0]]
elif len(sortarg) == 2:
idxs = [sortarg[0], sortarg[-1]]
elif len(sortarg) == 3:
idxs = [sortarg[0], sortarg[len(sortarg) // 2], sortarg[-1]]
else:
# TensorboardX pads the number of videos to an even count with black frames,
# so providing an even number of videos prevents black frames from being rendered.
idxs = [sortarg[0], sortarg[len(sortarg) // 2 - 1], sortarg[len(sortarg) // 2], sortarg[-1]]
videos = [videos[idx] for idx in idxs]
# pad videos to the same length with last frames
max_length = max(video.shape[0] for video in videos)
for i in range(len(videos)):
if videos[i].shape[0] < max_length:
padding = np.tile([videos[i][-1]], (max_length - videos[i].shape[0], 1, 1, 1))
videos[i] = np.concatenate([videos[i], padding], 0)
videos = np.stack(videos, 0)
assert len(videos.shape) == 5, 'Need [N, T, C, H, W] input tensor for video logging!'
return videos
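To make the worst/median/best selection and the last-frame padding above concrete, a short numpy sketch with made-up returns and clip sizes:

import numpy as np

# Hypothetical returns of six finished evaluation episodes.
reward = [3., 1., 7., 5., 2., 6.]
sortarg = np.argsort(reward)  # -> [1, 4, 0, 3, 5, 2]
# Four picks (an even count, so TensorboardX adds no black-frame padding videos):
# worst, the two middle episodes, and best.
idxs = [sortarg[0], sortarg[len(sortarg) // 2 - 1], sortarg[len(sortarg) // 2], sortarg[-1]]
# -> episodes 1 (worst), 0 and 3 (medians), 2 (best)

# Last-frame padding: repeat the final frame until every clip reaches the longest length.
video = np.zeros((10, 3, 64, 64), dtype=np.uint8)  # a T=10 clip in [T, C, H, W]
max_length = 16
padding = np.tile([video[-1]], (max_length - video.shape[0], 1, 1, 1))
padded = np.concatenate([video, padding], 0)
assert padded.shape == (16, 3, 64, 64)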


def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager, render: bool = False) -> Callable:
"""
Overview:
The middleware that executes the evaluation.
Arguments:
- cfg (:obj:`EasyDict`): Config.
- policy (:obj:`Policy`): The policy to be evaluated.
- env (:obj:`BaseEnvManager`): The env for the evaluation.
- render (:obj:`bool`): Whether to render env images.
"""

env.seed(cfg.seed, dynamic_seed=False)
@@ -184,6 +232,8 @@ def _evaluate(ctx: "OnlineRLContext"):
while not eval_monitor.is_finished():
obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32)
obs = {i: obs[i] for i in range(obs.shape[0])} # TBD
if render:
eval_monitor.update_video(env.ready_imgs)
inference_output = policy.forward(obs)
output = [v for v in inference_output.values()]
action = [to_ndarray(v['action']) for v in output] # TBD
@@ -194,6 +244,8 @@ def _evaluate(ctx: "OnlineRLContext"):
policy.reset([env_id])
reward = timestep.info.final_eval_reward
eval_monitor.update_reward(env_id, reward)
if 'episode_info' in timestep.info:
eval_monitor.update_info(env_id, timestep.info.episode_info)
episode_reward = eval_monitor.get_episode_reward()
eval_reward = np.mean(episode_reward)
stop_flag = eval_reward >= cfg.env.stop_value and ctx.train_iter > 0
@@ -208,6 +260,11 @@ def _evaluate(ctx: "OnlineRLContext"):
ctx.last_eval_iter = ctx.train_iter
ctx.eval_value = eval_reward
ctx.eval_output = {'output': output, 'reward': episode_reward}
episode_info = eval_monitor.get_episode_info()
if episode_info is not None:
ctx.eval_output['episode_info'] = episode_info
if render:
ctx.eval_output['replay_video'] = eval_monitor.get_episode_video()

if stop_flag:
task.finish = True
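A hedged usage sketch of the new render flag: the wiring below assumes an already-built cfg, policy, and evaluation env manager, and that the env manager exposes ready_imgs when rendering is enabled (as the evaluation loop above expects).

from ding.framework import task
from ding.framework.middleware.functional.evaluator import interaction_evaluator

# Illustrative only: cfg, policy and eval_env are assumed to exist and to be
# registered inside the usual `with task.start(...)` pipeline.
task.use(interaction_evaluator(cfg, policy, eval_env, render=True))
# After an evaluation round, ctx.eval_output may additionally carry:
#   'episode_info'  - '<key>_mean' statistics aggregated by the monitor
#   'replay_video'  - an [N, T, C, H, W] array of worst/median/best episodes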
