Commit e28b1c5: polish code
zjowowen committed Feb 4, 2024 (1 parent: 21062c4)
Showing 6 changed files with 59 additions and 142 deletions.
19 changes: 9 additions & 10 deletions ding/example/qgpo.py
@@ -10,7 +10,7 @@
from ding.framework import task, ding_init
from ding.framework.context import OfflineRLContext
from ding.framework.middleware import trainer, CkptSaver, offline_logger, wandb_offline_logger, termination_checker
from ding.framework.middleware.functional.evaluator import qgpo_interaction_evaluator
from ding.framework.middleware.functional.evaluator import interaction_evaluator
from ding.framework.middleware.functional.data_processor import qgpo_support_data_generator, qgpo_offline_data_fetcher
from ding.utils import set_pkg_seed

@@ -131,19 +131,18 @@ def main():
policy_state_dict = torch.load(cfg.policy.load_path, map_location=torch.device("cpu"))
policy.learn_mode.load_state_dict(policy_state_dict)

evaluator_env = BaseEnvManagerV2(
env_fn=[
lambda: DingEnvWrapper(env=gym.make(cfg.env.env_id), cfg=cfg.env, caller="evaluator")
for _ in range(cfg.env.evaluator_env_num)
],
cfg=cfg.env.manager
)

task.use(qgpo_support_data_generator(cfg, dataset, policy))
task.use(qgpo_offline_data_fetcher(cfg, dataset, collate_fn=None))
task.use(trainer(cfg, policy.learn_mode))
for guidance_scale in cfg.policy.eval.guidance_scale:
task.use(qgpo_interaction_evaluator(cfg, guidance_scale, policy.eval_mode, evaluator_env))
evaluator_env = BaseEnvManagerV2(
env_fn=[
lambda: DingEnvWrapper(env=gym.make(cfg.env.env_id), cfg=cfg.env, caller="evaluator")
for _ in range(cfg.env.evaluator_env_num)
],
cfg=cfg.env.manager
)
task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env, guidance_scale=guidance_scale))
task.use(
wandb_offline_logger(
cfg=EasyDict(
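Taken together, the hunks above move the evaluator environment construction inside the guidance-scale loop and switch from the QGPO-specific evaluator to the generic interaction_evaluator, passing guidance_scale as a plain keyword. The condensed sketch below is a reading aid for that wiring, not a verbatim copy of the file: the with task.start(...) / task.run() scaffolding is assumed from typical DI-engine example scripts, the imports are taken to be those at the top of the file (plus gym, DingEnvWrapper and BaseEnvManagerV2), and the logger/checkpoint middlewares of the real example are omitted.

# Condensed sketch (assumed scaffolding) of the new per-guidance-scale evaluator wiring.
with task.start(async_mode=False, ctx=OfflineRLContext()):
    task.use(qgpo_support_data_generator(cfg, dataset, policy))
    task.use(qgpo_offline_data_fetcher(cfg, dataset, collate_fn=None))
    task.use(trainer(cfg, policy.learn_mode))
    for guidance_scale in cfg.policy.eval.guidance_scale:
        # One dedicated evaluator env manager per guidance scale, built inside the loop.
        evaluator_env = BaseEnvManagerV2(
            env_fn=[
                lambda: DingEnvWrapper(env=gym.make(cfg.env.env_id), cfg=cfg.env, caller="evaluator")
                for _ in range(cfg.env.evaluator_env_num)
            ],
            cfg=cfg.env.manager
        )
        # The extra keyword is forwarded by the evaluator to policy.forward on every call.
        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env, guidance_scale=guidance_scale))
    task.run()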
2 changes: 2 additions & 0 deletions ding/framework/context.py
@@ -68,6 +68,7 @@ class OnlineRLContext(Context):
last_eval_value: int = -np.inf
eval_output: List = dataclasses.field(default_factory=dict)
# wandb
info_for_logging: Dict = dataclasses.field(default_factory=dict)
wandb_url: str = ""

def __post_init__(self):
@@ -93,6 +94,7 @@ class OfflineRLContext(Context):
last_eval_value: int = -np.inf
eval_output: List = dataclasses.field(default_factory=dict)
# wandb
info_for_logging: Dict = dataclasses.field(default_factory=dict)
wandb_url: str = ""

def __post_init__(self):
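As a side note on the field added to both context classes, the minimal dataclass below (a stand-in, not the real OnlineRLContext/OfflineRLContext) illustrates why dataclasses.field(default_factory=dict) is used: each context instance gets its own empty info_for_logging dict, so evaluator and logger middlewares can update it without sharing state across instances.

import dataclasses
from typing import Dict


@dataclasses.dataclass
class _SketchContext:
    # Same pattern as the field added above: a fresh dict per instance.
    info_for_logging: Dict = dataclasses.field(default_factory=dict)
    wandb_url: str = ""


a, b = _SketchContext(), _SketchContext()
a.info_for_logging['guidance_scale(1.0)/eval_episode_return'] = 123.4
print(b.info_for_logging)  # {} -- instances do not share the logging dict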
155 changes: 30 additions & 125 deletions ding/framework/middleware/functional/evaluator.py
@@ -210,7 +210,9 @@ def get_episode_output(self):
return output


def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager, render: bool = False) -> Callable:
def interaction_evaluator(
cfg: EasyDict, policy: Policy, env: BaseEnvManager, render: bool = False, **kwargs
) -> Callable:
"""
Overview:
The middleware that executes the evaluation.
@@ -219,6 +221,7 @@ def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager, re
- policy (:obj:`Policy`): The policy to be evaluated.
- env (:obj:`BaseEnvManager`): The env for the evaluation.
- render (:obj:`bool`): Whether to render env images and policy logits.
- kwargs (:obj:`Any`): Other arguments for specific evaluation.
"""
if task.router.is_active and not task.has_role(task.role.EVALUATOR):
return task.void()
@@ -239,8 +242,13 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]):

# evaluation will be executed if the task begins or enough train_iter after last evaluation
if ctx.last_eval_iter != -1 and \
(ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq):
return
(ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq):
if ctx.train_iter != ctx.last_eval_iter:
return
if len(kwargs) > 0:
kwargs_str = '/'.join([f'{k}({v})' for k, v in kwargs.items()])
else:
kwargs_str = ''

if env.closed:
env.launch()
@@ -252,7 +260,10 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
while not eval_monitor.is_finished():
obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32)
obs = {i: obs[i] for i in range(get_shape0(obs))} # TBD
inference_output = policy.forward(obs)
if len(kwargs) > 0:
inference_output = policy.forward(obs, **kwargs)
else:
inference_output = policy.forward(obs)
if render:
eval_monitor.update_video(env.ready_imgs)
eval_monitor.update_output(inference_output)
@@ -275,12 +286,14 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
stop_flag = episode_return >= cfg.env.stop_value and ctx.train_iter > 0
if isinstance(ctx, OnlineRLContext):
logging.info(
'Evaluation: Train Iter({})\tEnv Step({})\tEpisode Return({:.3f})'.format(
ctx.train_iter, ctx.env_step, episode_return
'Evaluation: Train Iter({}) Env Step({}) Episode Return({:.3f}) {}'.format(
ctx.train_iter, ctx.env_step, episode_return, kwargs_str
)
)
elif isinstance(ctx, OfflineRLContext):
logging.info('Evaluation: Train Iter({})\tEval Reward({:.3f})'.format(ctx.train_iter, episode_return))
logging.info(
'Evaluation: Train Iter({}) Eval Return({:.3f}) {}'.format(ctx.train_iter, episode_return, kwargs_str)
)
else:
raise TypeError("not supported ctx type: {}".format(type(ctx)))
ctx.last_eval_iter = ctx.train_iter
@@ -299,6 +312,16 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
else:
ctx.eval_output['output'] = output # for compatibility

if len(kwargs) > 0:
ctx.info_for_logging.update(
{
f'{kwargs_str}/eval_episode_return': episode_return,
f'{kwargs_str}/eval_episode_return_min': episode_return_min,
f'{kwargs_str}/eval_episode_return_max': episode_return_max,
f'{kwargs_str}/eval_episode_return_std': episode_return_std,
}
)

if stop_flag:
task.finish = True

@@ -433,122 +456,4 @@ def _evaluate(ctx: "Context"):
return _evaluate


def qgpo_interaction_evaluator(
cfg: EasyDict, guidance_scale, policy: Policy, env: BaseEnvManager, render: bool = False
) -> Callable:
"""
Overview:
The middleware that executes the evaluation.
Arguments:
- cfg (:obj:`EasyDict`): Config.
- guidance_scale (:obj:`float`): The guidance scale for evaluation.
- policy (:obj:`Policy`): The policy to be evaluated.
- env (:obj:`BaseEnvManager`): The env for the evaluation.
- render (:obj:`bool`): Whether to render env images and policy logits.
"""
if task.router.is_active and not task.has_role(task.role.EVALUATOR):
return task.void()

env.seed(cfg.seed, dynamic_seed=False)

def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
"""
Overview:
- The evaluation will be executed if the task begins and enough train_iter passed \
since last evaluation.
Input of ctx:
- last_eval_iter (:obj:`int`): Last evaluation iteration.
- train_iter (:obj:`int`): Current train iteration.
Output of ctx:
- eval_value (:obj:`float`): The average reward in the current evaluation.
"""

# evaluation will be executed if the task begins or enough train_iter after last evaluation
if ctx.last_eval_iter != -1 and \
(ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq):
if ctx.train_iter != ctx.last_eval_iter:
return

ctx.info_for_logging = {}

if env.closed:
env.launch()
else:
env.reset()
policy.reset()
eval_monitor = VectorEvalMonitor(env.env_num, cfg.env.n_evaluator_episode)

while not eval_monitor.is_finished():
obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32)
obs = {i: obs[i] for i in range(get_shape0(obs))} # TBD
data = {'s': obs, 'guidance_scale': guidance_scale}
inference_output = policy.forward(data)
if render:
eval_monitor.update_video(env.ready_imgs)
eval_monitor.update_output(inference_output)
output = [v for v in inference_output.values()]
action = [to_ndarray(v['action']) for v in output] # TBD
timesteps = env.step(action)
for timestep in timesteps:
env_id = timestep.env_id.item()
if timestep.done:
policy.reset([env_id])
reward = timestep.info.eval_episode_return
eval_monitor.update_reward(env_id, reward)
if 'episode_info' in timestep.info:
eval_monitor.update_info(env_id, timestep.info.episode_info)
episode_return = eval_monitor.get_episode_return()

episode_return_min = np.min(episode_return)
episode_return_max = np.max(episode_return)
episode_return_std = np.std(episode_return)
episode_return = np.mean(episode_return)
stop_flag = episode_return >= cfg.env.stop_value and ctx.train_iter > 0
if isinstance(ctx, OnlineRLContext):
logging.info(
'Evaluation: Train Iter({})\tEnv Step({})\tEpisode Return({:.3f})\tguidance_scale({})'.format(
ctx.train_iter, ctx.env_step, episode_return, guidance_scale
)
)
elif isinstance(ctx, OfflineRLContext):
logging.info(
'Evaluation: Train Iter({})\tEval Reward({:.3f})\tguidance_scale({})'.format(
ctx.train_iter, episode_return, guidance_scale
)
)
else:
raise TypeError("not supported ctx type: {}".format(type(ctx)))
ctx.last_eval_iter = ctx.train_iter
ctx.eval_value = episode_return if not hasattr(ctx, 'eval_value') or ctx.eval_value < episode_return else 0
ctx.eval_value_min = min(episode_return_min,
ctx.eval_value_min) if hasattr(ctx, 'eval_value_min') else episode_return_min
ctx.eval_value_max = max(episode_return_max,
ctx.eval_value_max) if hasattr(ctx, 'eval_value_max') else episode_return_max
ctx.eval_value_std = max(episode_return_std,
ctx.eval_value_std) if hasattr(ctx, 'eval_value_std') else episode_return_std
ctx.last_eval_value = ctx.eval_value
ctx.eval_output = {'episode_return': episode_return}
episode_info = eval_monitor.get_episode_info()
if episode_info is not None:
ctx.eval_output['episode_info'] = episode_info
if render:
ctx.eval_output['replay_video'] = eval_monitor.get_episode_video()
ctx.eval_output['output'] = eval_monitor.get_episode_output()
else:
ctx.eval_output['output'] = output # for compatibility
ctx.info_for_logging.update(
{
f'guidance_scale[{guidance_scale}]/eval_episode_return': episode_return,
f'guidance_scale[{guidance_scale}]/eval_episode_return_min': episode_return_min,
f'guidance_scale[{guidance_scale}]/eval_episode_return_max': episode_return_max,
f'guidance_scale[{guidance_scale}]/eval_episode_return_std': episode_return_std,
}
)

if stop_flag:
task.finish = True

return _evaluate


# TODO battle evaluator
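The generalized interaction_evaluator above replaces the removed qgpo_interaction_evaluator: any extra keyword argument (the example passes guidance_scale) is forwarded to policy.forward, a kwargs_str prefix such as guidance_scale(1.5) is appended to the log line and used to namespace the ctx.info_for_logging keys, and the frequency check now lets evaluation proceed when ctx.train_iter == ctx.last_eval_iter, so several evaluators registered for the same iteration (one per guidance scale) all run. The sketch below isolates just the kwargs formatting and forwarding, using a hypothetical stub policy rather than the real Policy class.

from typing import Dict


def _format_kwargs(**kwargs) -> str:
    # Same expression as in the evaluator: 'guidance_scale(1.5)' for one kwarg,
    # 'a(1)/b(2)' for several, and '' when no extra kwargs are given.
    return '/'.join([f'{k}({v})' for k, v in kwargs.items()]) if kwargs else ''


class _StubEvalPolicy:
    """Hypothetical stand-in that only echoes the kwargs it would route to _forward_eval."""

    def forward(self, obs: Dict, **kwargs) -> Dict:
        return {env_id: {'action': [0.0], 'kwargs_seen': dict(kwargs)} for env_id in obs}


kwargs = {'guidance_scale': 1.5}
kwargs_str = _format_kwargs(**kwargs)
policy = _StubEvalPolicy()
obs = {0: [0.0], 1: [0.0]}
out = policy.forward(obs, **kwargs) if kwargs else policy.forward(obs)
print(kwargs_str)             # guidance_scale(1.5)
print(out[0]['kwargs_seen'])  # {'guidance_scale': 1.5}
# Metrics are then logged under keys like 'guidance_scale(1.5)/eval_episode_return'.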
6 changes: 6 additions & 0 deletions ding/framework/middleware/functional/logger.py
@@ -612,6 +612,12 @@ def _plot(ctx: "OfflineRLContext"):

if ctx.eval_value != -np.inf:
if hasattr(ctx, "info_for_logging"):
"""
.. note::
``info_for_logging`` is a dict holding the information to be logged.
Users can add their own entries to it; everything in the dict will be logged to wandb.
"""
info_for_logging.update(ctx.info_for_logging)

if hasattr(ctx, "eval_value_min"):
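Following the note added above, the sketch below shows one way a user-side middleware might route extra metrics through ctx.info_for_logging so the wandb logger picks them up. The middleware name and metric keys are made up for illustration; only the use of ctx.info_for_logging itself comes from the diff.

# Hypothetical user middleware: stash extra metrics for the wandb logger to upload.
def my_extra_metrics(cfg):

    def _record(ctx: "OfflineRLContext"):
        # Arbitrary example keys; the logger merges ctx.info_for_logging into the
        # dict it sends to wandb, so these appear as additional wandb metrics.
        ctx.info_for_logging.update({
            'custom/train_iter': ctx.train_iter,
            'custom/some_scalar': 0.0,  # placeholder value
        })

    return _record


# Registered before the wandb logger in the pipeline, e.g.:
#   task.use(my_extra_metrics(cfg))
#   task.use(wandb_offline_logger(...))  # configured as in the example diff above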
4 changes: 2 additions & 2 deletions ding/model/template/qgpo.py
@@ -133,7 +133,7 @@ def forward(self, action, t, condition=None):
return self.qt(ats)


class QGPO_Critic(nn.Module):
class QGPOCritic(nn.Module):
"""
Overview:
QGPO critic network.
@@ -303,7 +303,7 @@ def __init__(self, cfg: EasyDict) -> None:
output_dim=self.action_dim,
)

self.q = QGPO_Critic(self.device, cfg.qgpo_critic, action_dim=self.action_dim, state_dim=self.obs_dim)
self.q = QGPOCritic(self.device, cfg.qgpo_critic, action_dim=self.action_dim, state_dim=self.obs_dim)

def calculateQ(self, s, a):
"""
15 changes: 10 additions & 5 deletions ding/policy/qgpo.py
@@ -200,21 +200,20 @@ def _init_eval(self) -> None:

self.diffusion_steps = self._cfg.eval.diffusion_steps

def _forward_eval(self, data: dict) -> dict:
def _forward_eval(self, data: dict, guidance_scale: float) -> dict:
"""
Overview:
Forward function for eval mode. The eval process is based on the energy-guided policy, \
which is modeled as a diffusion model by parameterizing the score function.
Arguments:
- data (:obj:`dict`): Dict type data.
- guidance_scale (:obj:`float`): The scale of the energy guidance.
Returns:
- output (:obj:`dict`): Dict type data of algorithm output.
"""
guidance_scale = data['guidance_scale']
states = data['s']

data_id = list(states.keys())
states = default_collate(list(states.values()))
data_id = list(data.keys())
states = default_collate(list(data.values()))
actions = self._model.select_actions(
states, diffusion_steps=self.diffusion_steps, guidance_scale=guidance_scale
)
@@ -226,6 +225,10 @@ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str,
"""
Overview:
Get the train sample from the replay buffer, currently not supported for QGPO.
Arguments:
- transitions (:obj:`List[Dict[str, Any]]`): The data from replay buffer.
Returns:
- samples (:obj:`List[Dict[str, Any]]`): The data for training.
"""
pass

@@ -252,6 +255,8 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
"""
Overview:
Load the state dict.
Arguments:
- state_dict (:obj:`Dict[str, Any]`): Dict type data of state dict.
"""
self._model.load_state_dict(state_dict['model'])
self.behavior_model_optimizer.load_state_dict(state_dict['behavior_model_optimizer'])
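To make the changed _forward_eval signature concrete: the evaluator now calls policy.forward(obs, guidance_scale=...) in eval mode, which (by the usual DI-engine dispatch, assumed here) reaches _forward_eval(data, guidance_scale) with the raw per-env observation dict, collates it and hands the result to the model. The sketch below mimics only the collate step with dummy tensors, using PyTorch's default_collate as a stand-in for the collate helper the policy imports; the observation/action dimensions and the output rebuild are illustrative assumptions.

import torch
from torch.utils.data import default_collate

# Dummy per-env observations keyed by env id, mirroring what the evaluator builds.
obs = {0: torch.zeros(17), 1: torch.zeros(17)}   # 17 is an arbitrary obs dim

data_id = list(obs.keys())                       # [0, 1]
states = default_collate(list(obs.values()))     # stacked tensor of shape (2, 17)

# A real call would be:
#   actions = self._model.select_actions(states, diffusion_steps=..., guidance_scale=guidance_scale)
# Dummy per-env actions illustrate how a per-env output dict can be rebuilt:
actions = torch.zeros(len(data_id), 6)           # 6 is an arbitrary action dim
output = {i: {'action': a} for i, a in zip(data_id, actions)}
print(states.shape, output[0]['action'].shape)   # torch.Size([2, 17]) torch.Size([6])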
