feature(nyz): add evaluator more info viz support (#538)
* feature(nyz): add evaluator more info viz support

* feature(nyz): add eval policy output viz

* fix(wzl): fix bugs for dict obs in new pipeline (#541)

* fix(wzl): fix bugs for dict obs in new pipeline

* fix(wzl): use get_shape0 func in new pipeline

* feature(wzl): add newline for style check

* polish(nyz): polish evaluator render implementation

Co-authored-by: zerlinwang <80957609+zerlinwang@users.noreply.github.com>
PaParaZz1 and zerlinwang authored Nov 12, 2022
1 parent ce2e0f4 commit 721e671
Showing 5 changed files with 117 additions and 18 deletions.
3 changes: 1 addition & 2 deletions ding/envs/env_manager/subprocess_env_manager.py
@@ -940,8 +940,7 @@ def ready_obs(self) -> tnp.array:
)
time.sleep(0.001)
sleep_count += 1
obs = [self._ready_obs[i] for i in self.ready_env]
return tnp.stack(obs)
return tnp.stack([tnp.array(self._ready_obs[i]) for i in self.ready_env])

def step(self, actions: List[tnp.ndarray]) -> List[tnp.ndarray]:
"""
6 changes: 3 additions & 3 deletions ding/framework/middleware/functional/collector.py
@@ -4,7 +4,7 @@
import treetensor.torch as ttorch
from ding.envs import BaseEnvManager
from ding.policy import Policy
from ding.torch_utils import to_ndarray
from ding.torch_utils import to_ndarray, get_shape0

if TYPE_CHECKING:
from ding.framework import OnlineRLContext
@@ -74,8 +74,8 @@ def _inference(ctx: "OnlineRLContext"):
obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32)
ctx.obs = obs
# TODO mask necessary rollout

obs = {i: obs[i] for i in range(obs.shape[0])} # TBD
num_envs = get_shape0(obs)
obs = {i: obs[i] for i in range(num_envs)} # TBD
inference_output = policy.forward(obs, **ctx.collect_kwargs)
ctx.action = [to_ndarray(v['action']) for v in inference_output.values()] # TBD
ctx.inference_output = inference_output
106 changes: 94 additions & 12 deletions ding/framework/middleware/functional/evaluator.py
@@ -4,14 +4,15 @@
from ditk import logging
import numpy as np
import torch
import treetensor.numpy as tnp
import treetensor.torch as ttorch
from easydict import EasyDict
from ding.envs import BaseEnvManager
from ding.framework.context import OfflineRLContext
from ding.policy import Policy
from ding.data import Dataset, DataLoader
from ding.framework import task
from ding.torch_utils import tensor_to_list, to_ndarray
from ding.torch_utils import to_list, to_ndarray, get_shape0
from ding.utils import lists_to_dicts

if TYPE_CHECKING:
@@ -49,7 +50,7 @@ class VectorEvalMonitor(object):
our average reward will have a bias and may not be accurate. we use VectorEvalMonitor to solve the problem.
Interfaces:
__init__, is_finished, update_info, update_reward, get_episode_reward, get_latest_reward, get_current_episode,\
get_episode_info
get_episode_info, update_video, get_episode_video
"""

def __init__(self, env_num: int, n_episode: int) -> None:
@@ -70,6 +71,14 @@ def __init__(self, env_num: int, n_episode: int) -> None:
each_env_episode[i] += 1
self._reward = {env_id: deque(maxlen=maxlen) for env_id, maxlen in enumerate(each_env_episode)}
self._info = {env_id: deque(maxlen=maxlen) for env_id, maxlen in enumerate(each_env_episode)}
self._video = {
env_id: deque([[] for _ in range(maxlen)], maxlen=maxlen)
for env_id, maxlen in enumerate(each_env_episode)
}
self._output = {
env_id: deque([[] for _ in range(maxlen)], maxlen=maxlen)
for env_id, maxlen in enumerate(each_env_episode)
}

def is_finished(self) -> bool:
"""
@@ -88,7 +97,6 @@ def update_info(self, env_id: int, info: Any) -> None:
- env_id: (:obj:`int`): the id of the environment we need to update information
- info: (:obj:`Any`): the information we need to update
"""
info = tensor_to_list(info)
self._info[env_id].append(info)

def update_reward(self, env_id: Union[int, np.ndarray], reward: Any) -> None:
@@ -136,24 +144,84 @@ def get_episode_info(self) -> dict:
if len(self._info[0]) == 0:
return None
else:
# sum among all envs
total_info = sum([list(v) for v in self._info.values()], [])
if isinstance(total_info[0], tnp.ndarray):
total_info = [t.json() for t in total_info]
total_info = lists_to_dicts(total_info)
new_dict = {}
for k in total_info.keys():
if np.isscalar(total_info[k][0]):
new_dict[k + '_mean'] = np.mean(total_info[k])
total_info.update(new_dict)
return total_info


def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager) -> Callable:
try:
if np.isscalar(total_info[k][0].item()):
new_dict[k + '_mean'] = np.mean(total_info[k])
except: # noqa
pass
return new_dict

def _select_idx(self):
reward = [t.item() for t in self.get_episode_reward()]
sortarg = np.argsort(reward)
# worst, median(s), best
if len(sortarg) == 1:
idxs = [sortarg[0]]
elif len(sortarg) == 2:
idxs = [sortarg[0], sortarg[-1]]
elif len(sortarg) == 3:
idxs = [sortarg[0], sortarg[len(sortarg) // 2], sortarg[-1]]
else:
# TensorboardX pads the number of videos to an even count with black frames,
# so providing an even number of videos prevents black frames from being rendered.
idxs = [sortarg[0], sortarg[len(sortarg) // 2 - 1], sortarg[len(sortarg) // 2], sortarg[-1]]
return idxs

def update_video(self, imgs):
for env_id, img in imgs.items():
if len(self._reward[env_id]) == self._reward[env_id].maxlen:
continue
self._video[env_id][len(self._reward[env_id])].append(img)

def get_episode_video(self):
"""
Overview:
Convert list of videos into [N, T, C, H, W] tensor, containing
worst, median, best evaluation trajectories for video logging.
"""
videos = sum([list(v) for v in self._video.values()], [])
videos = [np.transpose(np.stack(video, 0), [0, 3, 1, 2]) for video in videos]
idxs = self._select_idx()
videos = [videos[idx] for idx in idxs]
# pad videos to the same length with last frames
max_length = max(video.shape[0] for video in videos)
for i in range(len(videos)):
if videos[i].shape[0] < max_length:
padding = np.tile([videos[i][-1]], (max_length - videos[i].shape[0], 1, 1, 1))
videos[i] = np.concatenate([videos[i], padding], 0)
videos = np.stack(videos, 0)
assert len(videos.shape) == 5, 'Need [N, T, C, H, W] input tensor for video logging!'
return videos

def update_output(self, output):
for env_id, o in output.items():
if len(self._reward[env_id]) == self._reward[env_id].maxlen:
continue
self._output[env_id][len(self._reward[env_id])].append(to_ndarray(o))

def get_episode_output(self):
output = sum([list(v) for v in self._output.values()], [])
idxs = self._select_idx()
output = [output[idx] for idx in idxs]
return output


def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager, render: bool = False) -> Callable:
"""
Overview:
The middleware that executes the evaluation.
Arguments:
- cfg (:obj:`EasyDict`): Config.
- policy (:obj:`Policy`): The policy to be evaluated.
- env (:obj:`BaseEnvManager`): The env for the evaluation.
- render (:obj:`bool`): Whether to render env images and policy logits.
"""

env.seed(cfg.seed, dynamic_seed=False)
@@ -183,8 +251,12 @@ def _evaluate(ctx: "OnlineRLContext"):

while not eval_monitor.is_finished():
obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32)
obs = {i: obs[i] for i in range(obs.shape[0])} # TBD
num_envs = get_shape0(obs)
obs = {i: obs[i] for i in range(num_envs)} # TBD
inference_output = policy.forward(obs)
if render:
eval_monitor.update_video(env.ready_imgs)
eval_monitor.update_output(inference_output)
output = [v for v in inference_output.values()]
action = [to_ndarray(v['action']) for v in output] # TBD
timesteps = env.step(action)
@@ -194,6 +266,8 @@
policy.reset([env_id])
reward = timestep.info.final_eval_reward
eval_monitor.update_reward(env_id, reward)
if 'episode_info' in timestep.info:
eval_monitor.update_info(env_id, timestep.info.episode_info)
episode_reward = eval_monitor.get_episode_reward()
eval_reward = np.mean(episode_reward)
stop_flag = eval_reward >= cfg.env.stop_value and ctx.train_iter > 0
@@ -207,7 +281,15 @@
)
ctx.last_eval_iter = ctx.train_iter
ctx.eval_value = eval_reward
ctx.eval_output = {'output': output, 'reward': episode_reward}
ctx.eval_output = {'reward': episode_reward}
episode_info = eval_monitor.get_episode_info()
if episode_info is not None:
ctx.eval_output['episode_info'] = episode_info
if render:
ctx.eval_output['replay_video'] = eval_monitor.get_episode_video()
ctx.eval_output['output'] = eval_monitor.get_episode_output()
else:
ctx.eval_output['output'] = output # for compatibility

if stop_flag:
task.finish = True
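The render switch above only collects the extra data; a downstream logger decides what to do with it. As a hedged sketch (not part of this commit; the middleware name and writer wiring are illustrative), ctx.eval_output['replay_video'] is an [N, T, C, H, W] array of the worst, median, and best evaluation episodes, which matches the layout tensorboardX's add_video expects:

import torch
from tensorboardX import SummaryWriter

def eval_video_logger(writer: SummaryWriter):
    # Illustrative middleware: consume the extra info produced by
    # interaction_evaluator(..., render=True) after each evaluation.
    def _logger(ctx):
        if 'replay_video' in ctx.eval_output:
            video = torch.from_numpy(ctx.eval_output['replay_video'])  # [N, T, C, H, W]
            writer.add_video('eval/replay_video', video, global_step=ctx.env_step, fps=20)
        if 'episode_info' in ctx.eval_output:
            for k, v in ctx.eval_output['episode_info'].items():  # per-key episode means
                writer.add_scalar('eval/' + k, v, global_step=ctx.env_step)
    return _logger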
2 changes: 1 addition & 1 deletion ding/torch_utils/__init__.py
@@ -1,6 +1,6 @@
from .checkpoint_helper import build_checkpoint_helper, CountVar, auto_checkpoint
from .data_helper import to_device, to_tensor, to_ndarray, to_list, to_dtype, same_shape, tensor_to_list, \
build_log_buffer, CudaFetcher, get_tensor_data, unsqueeze, squeeze, get_null_data
build_log_buffer, CudaFetcher, get_tensor_data, unsqueeze, squeeze, get_null_data, get_shape0
from .distribution import CategoricalPd, CategoricalPdPytorch
from .metric import levenshtein_distance, hamming_distance
from .network import *
18 changes: 18 additions & 0 deletions ding/torch_utils/data_helper.py
@@ -8,6 +8,7 @@

import numpy as np
import torch
import treetensor.torch as ttorch


def to_device(item: Any, device: str, ignore_keys: list = []) -> Any:
@@ -408,3 +409,20 @@ def get_null_data(template: Any, num: int) -> List[Any]:
data['reward'].zero_()
ret.append(data)
return ret


def get_shape0(data):
if isinstance(data, torch.Tensor):
return data.shape[0]
elif isinstance(data, ttorch.Tensor):

def fn(t):
item = list(t.values())[0]
if np.isscalar(item[0]):
return item[0]
else:
return fn(item)

return fn(data.shape)
else:
raise TypeError("not support type: {}".format(data))

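A minimal usage sketch of the new helper (the observation keys below are invented for illustration): get_shape0 reads the batch dimension from either a plain tensor or a treetensor of tensors, which is what the collector and evaluator changes above rely on for dict observations.

import numpy as np
import torch
import treetensor.torch as ttorch
from ding.torch_utils import get_shape0

plain_obs = torch.randn(4, 8)  # 4 envs, flat observation
dict_obs = ttorch.tensor({
    'agent_state': np.random.randn(4, 8),
    'global_state': np.random.randn(4, 16),
})  # 4 envs, dict observation; its .shape is a tree, so obs.shape[0] cannot be used

assert get_shape0(plain_obs) == 4
assert get_shape0(dict_obs) == 4  # read from the first leaf tensor's shape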