feature(nyz): add evaluator more info viz support #538

Merged (4 commits) on Nov 12, 2022
3 changes: 1 addition & 2 deletions ding/envs/env_manager/subprocess_env_manager.py
@@ -940,8 +940,7 @@ def ready_obs(self) -> tnp.array:
)
time.sleep(0.001)
sleep_count += 1
obs = [self._ready_obs[i] for i in self.ready_env]
return tnp.stack(obs)
return tnp.stack([tnp.array(self._ready_obs[i]) for i in self.ready_env])

def step(self, actions: List[tnp.ndarray]) -> List[tnp.ndarray]:
"""
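For reference, a minimal sketch (not from this PR) of what the revised ready_obs return does: each ready observation is wrapped in tnp.array before tnp.stack, so plain ndarray and tree-structured observations are stacked the same way. The env ids and shapes below are invented for illustration.

    import numpy as np
    import treetensor.numpy as tnp

    # hypothetical ready observations for two ready envs (ids 0 and 2)
    _ready_obs = {0: np.zeros((4,), dtype=np.float32), 2: np.ones((4,), dtype=np.float32)}
    ready_env = [0, 2]

    # wrap each per-env observation in tnp.array, then stack along a new env axis
    obs = tnp.stack([tnp.array(_ready_obs[i]) for i in ready_env])
    print(obs.shape)  # (2, 4)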
6 changes: 3 additions & 3 deletions ding/framework/middleware/functional/collector.py
@@ -4,7 +4,7 @@
import treetensor.torch as ttorch
from ding.envs import BaseEnvManager
from ding.policy import Policy
from ding.torch_utils import to_ndarray
from ding.torch_utils import to_ndarray, get_shape0

if TYPE_CHECKING:
from ding.framework import OnlineRLContext
@@ -74,8 +74,8 @@ def _inference(ctx: "OnlineRLContext"):
obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32)
ctx.obs = obs
# TODO mask necessary rollout

obs = {i: obs[i] for i in range(obs.shape[0])} # TBD
num_envs = get_shape0(obs)
obs = {i: obs[i] for i in range(num_envs)} # TBD
inference_output = policy.forward(obs, **ctx.collect_kwargs)
ctx.action = [to_ndarray(v['action']) for v in inference_output.values()] # TBD
ctx.inference_output = inference_output
106 changes: 94 additions & 12 deletions ding/framework/middleware/functional/evaluator.py
@@ -4,14 +4,15 @@
from ditk import logging
import numpy as np
import torch
import treetensor.numpy as tnp
import treetensor.torch as ttorch
from easydict import EasyDict
from ding.envs import BaseEnvManager
from ding.framework.context import OfflineRLContext
from ding.policy import Policy
from ding.data import Dataset, DataLoader
from ding.framework import task
from ding.torch_utils import tensor_to_list, to_ndarray
from ding.torch_utils import to_list, to_ndarray, get_shape0
from ding.utils import lists_to_dicts

if TYPE_CHECKING:
Expand Down Expand Up @@ -49,7 +50,7 @@ class VectorEvalMonitor(object):
our average reward will have a bias and may not be accurate. we use VectorEvalMonitor to solve the problem.
Interfaces:
__init__, is_finished, update_info, update_reward, get_episode_reward, get_latest_reward, get_current_episode,\
get_episode_info
get_episode_info, update_video, get_episode_video
"""

def __init__(self, env_num: int, n_episode: int) -> None:
@@ -70,6 +71,14 @@ def __init__(self, env_num: int, n_episode: int) -> None:
each_env_episode[i] += 1
self._reward = {env_id: deque(maxlen=maxlen) for env_id, maxlen in enumerate(each_env_episode)}
self._info = {env_id: deque(maxlen=maxlen) for env_id, maxlen in enumerate(each_env_episode)}
self._video = {
env_id: deque([[] for _ in range(maxlen)], maxlen=maxlen)
for env_id, maxlen in enumerate(each_env_episode)
}
self._output = {
env_id: deque([[] for _ in range(maxlen)], maxlen=maxlen)
for env_id, maxlen in enumerate(each_env_episode)
}

def is_finished(self) -> bool:
"""
@@ -88,7 +97,6 @@ def update_info(self, env_id: int, info: Any) -> None:
- env_id: (:obj:`int`): the id of the environment we need to update information
- info: (:obj:`Any`): the information we need to update
"""
info = tensor_to_list(info)
self._info[env_id].append(info)

def update_reward(self, env_id: Union[int, np.ndarray], reward: Any) -> None:
@@ -136,24 +144,84 @@ def get_episode_info(self) -> dict:
if len(self._info[0]) == 0:
return None
else:
# sum among all envs
total_info = sum([list(v) for v in self._info.values()], [])
if isinstance(total_info[0], tnp.ndarray):
total_info = [t.json() for t in total_info]
total_info = lists_to_dicts(total_info)
new_dict = {}
for k in total_info.keys():
if np.isscalar(total_info[k][0]):
new_dict[k + '_mean'] = np.mean(total_info[k])
total_info.update(new_dict)
return total_info


def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager) -> Callable:
try:
if np.isscalar(total_info[k][0].item()):
new_dict[k + '_mean'] = np.mean(total_info[k])
except: # noqa
pass
return new_dict

def _select_idx(self):
reward = [t.item() for t in self.get_episode_reward()]
sortarg = np.argsort(reward)
# worst, median(s), best
if len(sortarg) == 1:
idxs = [sortarg[0]]
elif len(sortarg) == 2:
idxs = [sortarg[0], sortarg[-1]]
elif len(sortarg) == 3:
idxs = [sortarg[0], sortarg[len(sortarg) // 2], sortarg[-1]]
else:
# TensorboardX pad the number of videos to even numbers with black frames,
# therefore providing even number of videos prevents black frames being rendered.
idxs = [sortarg[0], sortarg[len(sortarg) // 2 - 1], sortarg[len(sortarg) // 2], sortarg[-1]]
return idxs

def update_video(self, imgs):
for env_id, img in imgs.items():
if len(self._reward[env_id]) == self._reward[env_id].maxlen:
continue
self._video[env_id][len(self._reward[env_id])].append(img)

def get_episode_video(self):
"""
Overview:
Convert list of videos into [N, T, C, H, W] tensor, containing
worst, median, best evaluation trajectories for video logging.
"""
videos = sum([list(v) for v in self._video.values()], [])
videos = [np.transpose(np.stack(video, 0), [0, 3, 1, 2]) for video in videos]
idxs = self._select_idx()
videos = [videos[idx] for idx in idxs]
# pad videos to the same length with last frames
max_length = max(video.shape[0] for video in videos)
for i in range(len(videos)):
if videos[i].shape[0] < max_length:
padding = np.tile([videos[i][-1]], (max_length - videos[i].shape[0], 1, 1, 1))
videos[i] = np.concatenate([videos[i], padding], 0)
videos = np.stack(videos, 0)
assert len(videos.shape) == 5, 'Need [N, T, C, H, W] input tensor for video logging!'
return videos

def update_output(self, output):
for env_id, o in output.items():
if len(self._reward[env_id]) == self._reward[env_id].maxlen:
continue
self._output[env_id][len(self._reward[env_id])].append(to_ndarray(o))

def get_episode_output(self):
output = sum([list(v) for v in self._output.values()], [])
idxs = self._select_idx()
output = [output[idx] for idx in idxs]
return output


def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager, render: bool = False) -> Callable:
"""
Overview:
The middleware that executes the evaluation.
Arguments:
- cfg (:obj:`EasyDict`): Config.
- policy (:obj:`Policy`): The policy to be evaluated.
- env (:obj:`BaseEnvManager`): The env for the evaluation.
- render (:obj:`bool`): Whether to render env images and policy logits.
"""

env.seed(cfg.seed, dynamic_seed=False)
@@ -183,8 +251,12 @@ def _evaluate(ctx: "OnlineRLContext"):

while not eval_monitor.is_finished():
obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32)
obs = {i: obs[i] for i in range(obs.shape[0])} # TBD
num_envs = get_shape0(obs)
obs = {i: obs[i] for i in range(num_envs)} # TBD
inference_output = policy.forward(obs)
if render:
eval_monitor.update_video(env.ready_imgs)
eval_monitor.update_output(inference_output)
output = [v for v in inference_output.values()]
action = [to_ndarray(v['action']) for v in output] # TBD
timesteps = env.step(action)
@@ -194,6 +266,8 @@ def _evaluate(ctx: "OnlineRLContext"):
policy.reset([env_id])
reward = timestep.info.final_eval_reward
eval_monitor.update_reward(env_id, reward)
if 'episode_info' in timestep.info:
eval_monitor.update_info(env_id, timestep.info.episode_info)
episode_reward = eval_monitor.get_episode_reward()
eval_reward = np.mean(episode_reward)
stop_flag = eval_reward >= cfg.env.stop_value and ctx.train_iter > 0
@@ -207,7 +281,15 @@ def _evaluate(ctx: "OnlineRLContext"):
)
ctx.last_eval_iter = ctx.train_iter
ctx.eval_value = eval_reward
ctx.eval_output = {'output': output, 'reward': episode_reward}
ctx.eval_output = {'reward': episode_reward}
episode_info = eval_monitor.get_episode_info()
if episode_info is not None:
ctx.eval_output['episode_info'] = episode_info
if render:
ctx.eval_output['replay_video'] = eval_monitor.get_episode_video()
ctx.eval_output['output'] = eval_monitor.get_episode_output()
else:
ctx.eval_output['output'] = output # for compatibility

if stop_flag:
task.finish = True
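To make the new render path concrete, here is a hypothetical downstream sketch (not part of this PR). It assumes interaction_evaluator(cfg, policy, env, render=True) has already run in the same middleware pipeline, and that a torch SummaryWriter (with moviepy installed) is available for video logging; the name video_logger and the log directory are invented.

    import torch
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter('./log/evaluator')  # hypothetical log dir

    def video_logger(ctx):
        # runs after the evaluator middleware has filled ctx.eval_output
        replay = ctx.eval_output.get('replay_video')
        if replay is None:
            return  # render=False: only 'reward', 'output' (and maybe 'episode_info') are set
        # replay is a [N, T, C, H, W] array holding the worst / median / best episodes
        writer.add_video('eval/replay', torch.as_tensor(replay), ctx.train_iter, fps=20)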
2 changes: 1 addition & 1 deletion ding/torch_utils/__init__.py
@@ -1,6 +1,6 @@
from .checkpoint_helper import build_checkpoint_helper, CountVar, auto_checkpoint
from .data_helper import to_device, to_tensor, to_ndarray, to_list, to_dtype, same_shape, tensor_to_list, \
build_log_buffer, CudaFetcher, get_tensor_data, unsqueeze, squeeze, get_null_data
build_log_buffer, CudaFetcher, get_tensor_data, unsqueeze, squeeze, get_null_data, get_shape0
from .distribution import CategoricalPd, CategoricalPdPytorch
from .metric import levenshtein_distance, hamming_distance
from .network import *
18 changes: 18 additions & 0 deletions ding/torch_utils/data_helper.py
@@ -8,6 +8,7 @@

import numpy as np
import torch
import treetensor.torch as ttorch


def to_device(item: Any, device: str, ignore_keys: list = []) -> Any:
@@ -408,3 +409,20 @@ def get_null_data(template: Any, num: int) -> List[Any]:
data['reward'].zero_()
ret.append(data)
return ret


def get_shape0(data):
if isinstance(data, torch.Tensor):
return data.shape[0]
elif isinstance(data, ttorch.Tensor):

def fn(t):
item = list(t.values())[0]
if np.isscalar(item[0]):
return item[0]
else:
return fn(item)

return fn(data.shape)
else:
raise TypeError("not support type: {}".format(data))
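
A small usage sketch for get_shape0, assuming it is imported from ding.torch_utils as wired in the __init__.py change above. It returns the leading batch dimension for both plain tensors and tree-structured observations, which is how the collector and evaluator middleware above compute num_envs; the example keys and shapes are illustrative only.

    import torch
    import treetensor.torch as ttorch
    from ding.torch_utils import get_shape0

    # plain tensor: the batch size is simply shape[0]
    plain_obs = torch.randn(4, 8)
    assert get_shape0(plain_obs) == 4

    # tree-structured observation: every leaf shares the same leading env dimension
    tree_obs = ttorch.randn({'agent_state': (4, 8), 'global_state': (4, 16)})
    assert get_shape0(tree_obs) == 4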