From 12c9dd5c932d97fb35f44b516ef4dd774f1cb69c Mon Sep 17 00:00:00 2001 From: luyudong Date: Fri, 14 Jul 2023 10:18:25 +0800 Subject: [PATCH 01/25] Revise old version dt pipline --- ding/policy/decision_transformer.py | 64 +++++++++++++++++-- ding/utils/data/dataset.py | 23 ++++--- .../lunarlander_decision_transformer.py | 4 +- ...t_LunarLander-v2_log_23-07-13-08-50-45.csv | 1 + ...t_LunarLander-v2_log_23-07-13-08-57-28.csv | 2 + ...t_LunarLander-v2_log_23-07-13-09-09-38.csv | 16 +++++ .../offline_data/collect_dqn_data_config.py | 12 ++-- 7 files changed, 95 insertions(+), 27 deletions(-) create mode 100644 dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-08-50-45.csv create mode 100644 dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-08-57-28.csv create mode 100644 dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-09-09-38.csv diff --git a/ding/policy/decision_transformer.py b/ding/policy/decision_transformer.py index 44b6731f16..aeb08e1dbf 100644 --- a/ding/policy/decision_transformer.py +++ b/ding/policy/decision_transformer.py @@ -21,11 +21,11 @@ import copy import os import csv -from .dqn import DQNPolicy +from .base_policy import Policy @POLICY_REGISTRY.register('dt') -class DTPolicy(DQNPolicy): +class DTPolicy(Policy): r""" Overview: Policy class of DT algorithm in discrete environments. @@ -356,11 +356,12 @@ def evaluate(self, total_update_times, state_mean=None, state_std=None, render=F return self.max_env_score >= self.stop_value def get_d4rl_normalized_score(self, score, env_name): - env_key = env_name.split('-')[0].lower() - assert env_key in D4RLTrajectoryDataset.REF_MAX_SCORE, \ - f'no reference score for {env_key} env to calculate d4rl score' - d4rl_max_score, d4rl_min_score = D4RLTrajectoryDataset.REF_MAX_SCORE, D4RLTrajectoryDataset.REF_MIN_SCORE - return (score - d4rl_min_score[env_key]) / (d4rl_max_score[env_key] - d4rl_min_score[env_key]) + # env_key = env_name.split('-')[0].lower() + # assert env_key in D4RLTrajectoryDataset.REF_MAX_SCORE, \ + # f'no reference score for {env_key} env to calculate d4rl score' + # d4rl_max_score, d4rl_min_score = D4RLTrajectoryDataset.REF_MAX_SCORE, D4RLTrajectoryDataset.REF_MIN_SCORE + # return (score - d4rl_min_score[env_key]) / (d4rl_max_score[env_key] - d4rl_min_score[env_key]) + return 0 def _state_dict_learn(self) -> Dict[str, Any]: return { @@ -376,3 +377,52 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: def _monitor_vars_learn(self) -> List[str]: return ['cur_lr', 'action_loss'] + + + def _init_eval(self) -> None: + pass + + + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + pass + + + def _init_collect(self) -> None: + pass + + def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: + pass + + def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Overview: + For a given trajectory(transitions, a list of transition) data, process it into a list of sample that \ + can be used for training directly. A train sample can be a processed transition(DQN with nstep TD) \ + or some continuous transitions(DRQN). + Arguments: + - data (:obj:`List[Dict[str, Any]`): The trajectory data(a list of transition), each element is the same \ + format as the return value of ``self._process_transition`` method. + Returns: + - samples (:obj:`dict`): The list of training samples. + + .. 
note:: + We will vectorize ``process_transition`` and ``get_train_sample`` method in the following release version. \ + And the user can customize the this data processing procecure by overriding this two methods and collector \ + itself. + """ + pass + + def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: namedtuple) -> Dict[str, Any]: + """ + Overview: + Generate a transition(e.g.: ) for this algorithm training. + Arguments: + - obs (:obj:`Any`): Env observation. + - policy_output (:obj:`Dict[str, Any]`): The output of policy collect mode(``self._forward_collect``),\ + including at least ``action``. + - timestep (:obj:`namedtuple`): The output after env step(execute policy output action), including at \ + least ``obs``, ``reward``, ``done``, (here obs indicates obs after env step). + Returns: + - transition (:obj:`dict`): Dict type transition data. + """ + pass \ No newline at end of file diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index c4905521d1..5d24b071d6 100644 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -342,18 +342,17 @@ def __init__(self, dataset_path: str, context_len: int, rtg_scale: float) -> Non original_keys = ['obs', 'next_obs', 'action', 'reward'] keys = ['observations', 'next_observations', 'actions', 'rewards'] - for key, o_key in zip(keys, original_keys): - trajectories_tmp = [ - { - key: np.stack( - [ - self.trjectories[eps_index][transition_index][o_key] - for transition_index in range(len(self.trajectories[eps_index])) - ], - axis=0 - ) - } for eps_index in range(len(self.trajectories)) - ] + trajectories_tmp = [ + { + key: np.stack( + [ + self.trajectories[eps_index][transition_index][o_key] + for transition_index in range(len(self.trajectories[eps_index])) + ], + axis=0 + ) for key, o_key in zip(keys, original_keys) + } for eps_index in range(len(self.trajectories)) + ] self.trajectories = trajectories_tmp states = [] diff --git a/dizoo/box2d/lunarlander/config/lunarlander_decision_transformer.py b/dizoo/box2d/lunarlander/config/lunarlander_decision_transformer.py index 01d5fac4ea..3d9a4c5bf5 100644 --- a/dizoo/box2d/lunarlander/config/lunarlander_decision_transformer.py +++ b/dizoo/box2d/lunarlander/config/lunarlander_decision_transformer.py @@ -27,7 +27,7 @@ embed_dim=128, n_heads=1, dropout_p=0.1, - log_dir='/home/puyuan/DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps', + log_dir='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps', model=dict( state_dim=8, act_dim=4, @@ -41,7 +41,7 @@ discount_factor=0.999, nstep=3, learn=dict( - dataset_path='/home/puyuan/DI-engine/dizoo/box2d/lunarlander/dt_data/data/dqn_data_1000eps.pkl', # TODO + dataset_path='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl', # TODO learning_rate=1e-4, target_update_freq=100, kappa=1.0, diff --git a/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-08-50-45.csv b/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-08-50-45.csv new file mode 100644 index 0000000000..a03d532276 --- /dev/null +++ b/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-08-50-45.csv @@ -0,0 +1 @@ +duration,num_updates,eval_avg_reward,eval_avg_ep_len,eval_d4rl_score diff --git a/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-08-57-28.csv b/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-08-57-28.csv new file mode 100644 index 0000000000..291f905b5d --- /dev/null +++ 
b/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-08-57-28.csv @@ -0,0 +1,2 @@ +duration,num_updates,eval_avg_reward,eval_avg_ep_len,eval_d4rl_score +0:02:06,1000,-210.51245402009764,213.5,0 diff --git a/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-09-09-38.csv b/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-09-09-38.csv new file mode 100644 index 0000000000..2ebd2b172d --- /dev/null +++ b/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-09-09-38.csv @@ -0,0 +1,16 @@ +duration,num_updates,eval_avg_reward,eval_avg_ep_len,eval_d4rl_score +0:02:01,1000,-318.0265498211703,196.6,0 +0:04:20,2000,41.14780414978246,718.4,0 +0:06:42,3000,142.82828392440672,762.4,0 +0:09:02,4000,138.89283069755604,724.3,0 +0:11:23,5000,107.84010045187748,767.2,0 +0:13:44,6000,160.2343642862316,701.2,0 +0:16:00,7000,121.66243934822947,654.6,0 +0:18:22,8000,77.69263487376318,720.2,0 +0:20:37,9000,198.9464703222405,633.8,0 +0:22:59,10000,81.862677775086,728.9,0 +0:25:17,11000,167.60164353189074,671.1,0 +0:27:33,12000,180.91905256798407,634.4,0 +0:29:47,13000,185.41585978196196,563.7,0 +0:32:06,14000,190.83281151114906,653.6,0 +0:34:20,15000,205.16962772579026,561.3,0 diff --git a/dizoo/box2d/lunarlander/offline_data/collect_dqn_data_config.py b/dizoo/box2d/lunarlander/offline_data/collect_dqn_data_config.py index 0beb38fbd5..20050b7340 100644 --- a/dizoo/box2d/lunarlander/offline_data/collect_dqn_data_config.py +++ b/dizoo/box2d/lunarlander/offline_data/collect_dqn_data_config.py @@ -34,15 +34,15 @@ dataloader=dict(num_workers=0, ), log_policy=True, hook=dict( - # load_ckpt_before_run='./lunarlander/ckpt/ckpt_best.pth.tar', - load_ckpt_before_run='/home/puyuan/DI-engine/dizoo/box2d/lunarlander/dt_data/ckpt/ckpt_best.pth.tar', + load_ckpt_before_run='./ckpt_best.pth.tar', # TODO: syspath modeified in other place, have to use abs path. May be fix in next version. + # load_ckpt_before_run='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/dt_data/ckpt/ckpt_best.pth.tar', log_show_after_iter=100, save_ckpt_after_iter=10000, save_ckpt_after_run=False, ), cfg_type='BaseLearnerDict', - # load_path='./cartpole/ckpt/ckpt_best.pth.tar', - load_path='/home/puyuan/DI-engine/dizoo/box2d/lunarlander/dt_data/ckpt/ckpt_best.pth.tar', + load_path='./ckpt_best.pth.tar', # TODO: same like last path. 
+ # load_path='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/dt_data/ckpt/ckpt_best.pth.tar', ), update_per_collect=10, batch_size=64, @@ -61,9 +61,9 @@ # save # data_type='hdf5', data_type='naive', - save_path='/home/puyuan/DI-engine/dizoo/box2d/lunarlander/dt_data/data/dqn_data_1000eps.pkl', # TODO(pu) + save_path='./dt_data/dqn_data_1000eps.pkl', # TODO(pu) # load - data_path='/home/puyuan/DI-engine/dizoo/box2d/lunarlander/dt_data/data/dqn_data_10eps.pkl', # TODO(pu) + data_path='./dt_data/dqn_data_10eps.pkl', # TODO(pu) ), # command_mode config other=dict( From 887b587df955eb515c4dcb0c955ff1538941e366 Mon Sep 17 00:00:00 2001 From: luyudong Date: Fri, 14 Jul 2023 11:07:56 +0800 Subject: [PATCH 02/25] Add new dt pipline --- ding/example/dt.py | 42 ++ ding/policy/dt.py | 396 ++++++++++++++++++ .../config/lunarlander_dt_config.py | 0 3 files changed, 438 insertions(+) create mode 100644 ding/example/dt.py create mode 100644 ding/policy/dt.py create mode 100644 dizoo/box2d/lunarlander/config/lunarlander_dt_config.py diff --git a/ding/example/dt.py b/ding/example/dt.py new file mode 100644 index 0000000000..5af78dabd3 --- /dev/null +++ b/ding/example/dt.py @@ -0,0 +1,42 @@ +import gym +from ditk import logging +from ding.model import QAC +from ding.policy import CQLPolicy +from ding.envs import DingEnvWrapper, BaseEnvManagerV2 +from ding.data import create_dataset +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OfflineRLContext +from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger +from ding.utils import set_pkg_seed +from dizoo.classic_control.pendulum.envs.pendulum_env import PendulumEnv +from dizoo.classic_control.pendulum.config.pendulum_cql_config import main_config, create_config + + +def main(): + # If you don't have offline data, you need to prepare if first and set the data_path in config + # For demostration, we also can train a RL policy (e.g. 
SAC) and collect some data + logging.getLogger().setLevel(logging.INFO) + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + with task.start(async_mode=False, ctx=OfflineRLContext()): + evaluator_env = BaseEnvManagerV2( + env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + dataset = create_dataset(cfg) + model = QAC(**cfg.policy.model) + policy = CQLPolicy(cfg.policy, model=model) + + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(offline_data_fetcher(cfg, dataset)) + task.use(trainer(cfg, policy.learn_mode)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) + task.use(offline_logger()) + task.run() + + +if __name__ == "__main__": + main() diff --git a/ding/policy/dt.py b/ding/policy/dt.py new file mode 100644 index 0000000000..28c4eeb2f6 --- /dev/null +++ b/ding/policy/dt.py @@ -0,0 +1,396 @@ +"""The code is adapted from https://github.com/nikhilbarhate99/min-decision-transformer +""" + +from typing import List, Dict, Any, Tuple, Union +from collections import namedtuple +from torch.distributions import Normal, Independent +from ding.torch_utils import Adam, to_device +from ditk import logging +from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \ + qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data +from ding.model import model_wrap +from ding.utils.data.dataset import D4RLTrajectoryDataset +from ding.utils import POLICY_REGISTRY +from ding.utils.data import default_collate, default_decollate +from datetime import datetime +from ding.torch_utils import one_hot +import numpy as np +import torch.nn.functional as F +import torch +import gym +import copy +import os +import csv +from .base_policy import Policy + + +@POLICY_REGISTRY.register('dt') +class DTPolicy(Policy): + r""" + Overview: + Policy class of DT algorithm in discrete environments. + """ + config = dict( + # (str) RL policy register name (refer to function "POLICY_REGISTRY"). + type='dt', + # (bool) Whether to use cuda for network. + cuda=False, + # (bool) Whether the RL algorithm is on-policy or off-policy. + on_policy=False, + # (bool) Whether use priority(priority sample, IS weight, update priority) + priority=False, + # (float) Reward's future discount factor, aka. gamma. + discount_factor=0.97, + # (int) N-step reward for target q_value estimation + nstep=1, + obs_shape=4, + action_shape=2, + # encoder_hidden_size_list=[128, 128, 64], + dataset='medium', # medium / medium-replay / medium-expert + rtg_scale=1000, # normalize returns to go + max_eval_ep_len=1000, # max len of one episode + num_eval_ep=10, # num of evaluation episodes + batch_size=64, # training batch size + wt_decay=1e-4, + warmup_steps=10000, + max_train_iters=200, + context_len=20, + n_blocks=3, + embed_dim=128, + dropout_p=0.1, + learn=dict( + + # batch_size=64, + learning_rate=1e-4, + # ============================================================== + # The following configs are algorithm-specific + # ============================================================== + ), + # collect_mode config + collect=dict(), + eval=dict(), + # other config + other=dict(), + ) + + def default_model(self) -> Tuple[str, List[str]]: + return 'dt', ['ding.model.template.decision_transformer'] + + def _init_learn(self) -> None: + r""" + Overview: + Learn mode init method. Called by ``self.__init__``. 
+ Init the optimizer, algorithm config, main and target models. + """ + + self.stop_value = self._cfg.stop_value + self.env_name = self._cfg.env_name + dataset = self._cfg.dataset # medium / medium-replay / medium-expert + # rtg_scale: scale of `return to go` + # rtg_target: max target of `return to go` + # Our goal is normalize `return to go` to (0, 1), which will favour the covergence. + # As a result, we usually set rtg_scale == rtg_target. + self.rtg_scale = self._cfg.rtg_target # normalize returns to go + self.rtg_target = self._cfg.rtg_target # max target reward_to_go + self.max_eval_ep_len = self._cfg.max_eval_ep_len # max len of one episode + self.num_eval_ep = self._cfg.num_eval_ep # num of evaluation episodes + + lr = self._cfg.learn.learning_rate # learning rate + wt_decay = self._cfg.wt_decay # weight decay + warmup_steps = self._cfg.warmup_steps # warmup steps for lr scheduler + + max_train_iters = self._cfg.max_train_iters + + self.context_len = self._cfg.context_len # K in decision transformer + n_blocks = self._cfg.n_blocks # num of transformer blocks + embed_dim = self._cfg.embed_dim # embedding (hidden) dim of transformer + dropout_p = self._cfg.dropout_p # dropout probability + + # # load data from this file + # dataset_path = f'{self._cfg.dataset_dir}/{env_d4rl_name}.pkl' + + # saves model and csv in this directory + self.log_dir = self._cfg.log_dir + if not os.path.exists(self.log_dir): + os.makedirs(self.log_dir) + + # training and evaluation device + self.device = torch.device(self._device) + + self.start_time = datetime.now().replace(microsecond=0) + self.start_time_str = self.start_time.strftime("%y-%m-%d-%H-%M-%S") + + # prefix = "dt_" + env_d4rl_name + self.prefix = "dt_" + self.env_name + + save_model_name = self.prefix + "_model_" + self.start_time_str + ".pt" + self.save_model_path = os.path.join(self.log_dir, save_model_name) + self.save_best_model_path = self.save_model_path[:-3] + "_best.pt" + + log_csv_name = self.prefix + "_log_" + self.start_time_str + ".csv" + log_csv_path = os.path.join(self.log_dir, log_csv_name) + + self.csv_writer = csv.writer(open(log_csv_path, 'a', 1)) + csv_header = (["duration", "num_updates", "eval_avg_reward", "eval_avg_ep_len", "eval_d4rl_score"]) + + self.csv_writer.writerow(csv_header) + + dataset_path = self._cfg.learn.dataset_path + logging.info("=" * 60) + logging.info("start time: " + self.start_time_str) + logging.info("=" * 60) + + logging.info("device set to: " + str(self.device)) + logging.info("dataset path: " + dataset_path) + logging.info("model save path: " + self.save_model_path) + logging.info("log csv save path: " + log_csv_path) + + self.state_dim = self._cfg.model.state_dim + self.act_dim = self._cfg.model.act_dim + + self._learn_model = self._model + self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay) + + self._scheduler = torch.optim.lr_scheduler.LambdaLR( + self._optimizer, lambda steps: min((steps + 1) / warmup_steps, 1) + ) + + self.max_env_score = -1.0 + + def _forward_learn(self, data: list) -> Dict[str, Any]: + r""" + Overview: + Forward and backward function of learn mode. + Arguments: + - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs'] + Returns: + - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. 
+ """ + + self._learn_model.train() + + timesteps, states, actions, returns_to_go, traj_mask = data + + timesteps = timesteps.to(self.device) # B x T + states = states.to(self.device) # B x T x state_dim + actions = actions.to(self.device) # B x T x act_dim + returns_to_go = returns_to_go.to(self.device) # B x T x 1 + traj_mask = traj_mask.to(self.device) # B x T + action_target = torch.clone(actions).detach().to(self.device) + + # The shape of `returns_to_go` may differ with different dataset (B x T or B x T x 1), + # and we need a 3-dim tensor + if len(returns_to_go.shape) == 2: + returns_to_go = returns_to_go.unsqueeze(-1) + + # if discrete + if not self._cfg.model.continuous: + actions = one_hot(actions.squeeze(-1), num=self.act_dim) + + state_preds, action_preds, return_preds = self._learn_model.forward( + timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go + ) + + traj_mask = traj_mask.view(-1, ) + + # only consider non padded elements + action_preds = action_preds.view(-1, self.act_dim)[traj_mask > 0] + + if self._cfg.model.continuous: + action_target = action_target.view(-1, self.act_dim)[traj_mask > 0] + else: + action_target = action_target.view(-1)[traj_mask > 0] + + if self._cfg.model.continuous: + action_loss = F.mse_loss(action_preds, action_target) + else: + action_loss = F.cross_entropy(action_preds, action_target) + + self._optimizer.zero_grad() + action_loss.backward() + torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), 0.25) + self._optimizer.step() + self._scheduler.step() + + return { + 'cur_lr': self._optimizer.state_dict()['param_groups'][0]['lr'], + 'action_loss': action_loss.detach().cpu().item(), + } + + def evaluate_on_env(self, state_mean=None, state_std=None, render=False): + + eval_batch_size = 1 # required for forward pass + + results = {} + total_reward = 0 + total_timesteps = 0 + + # state_dim = env.observation_space.shape[0] + # act_dim = env.action_space.shape[0] + + if state_mean is None: + self.state_mean = torch.zeros((self.state_dim, )).to(self.device) + else: + self.state_mean = torch.from_numpy(state_mean).to(self.device) + + if state_std is None: + self.state_std = torch.ones((self.state_dim, )).to(self.device) + else: + self.state_std = torch.from_numpy(state_std).to(self.device) + + # same as timesteps used for training the transformer + # also, crashes if device is passed to arange() + timesteps = torch.arange(start=0, end=self.max_eval_ep_len, step=1) + timesteps = timesteps.repeat(eval_batch_size, 1).to(self.device) + + self._learn_model.eval() + + with torch.no_grad(): + + for _ in range(self.num_eval_ep): + + # zeros place holders + # continuous action + actions = torch.zeros( + (eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device + ) + + # discrete action # TODO + # actions = torch.randint(0,self.act_dim,[eval_batch_size, self.max_eval_ep_len, 1], + # dtype=torch.long, device=self.device) + + states = torch.zeros( + (eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device + ) + rewards_to_go = torch.zeros( + (eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self.device + ) + + # init episode + running_state = self._env.reset() + running_reward = 0 + running_rtg = self.rtg_target / self.rtg_scale + + for t in range(self.max_eval_ep_len): + + total_timesteps += 1 + + # add state in placeholder and normalize + states[0, t] = torch.from_numpy(running_state).to(self.device) + # states[0, t] = 
(states[0, t].cpu() - self.state_mean.cpu().numpy()) / self.state_std.cpu().numpy() + states[0, t] = (states[0, t] - self.state_mean) / self.state_std + + # calcualate running rtg and add it in placeholder + running_rtg = running_rtg - (running_reward / self.rtg_scale) + rewards_to_go[0, t] = running_rtg + + if t < self.context_len: + _, act_preds, _ = self._learn_model.forward( + timesteps[:, :self.context_len], states[:, :self.context_len], + actions[:, :self.context_len], rewards_to_go[:, :self.context_len] + ) + act = act_preds[0, t].detach() + else: + _, act_preds, _ = self._learn_model.forward( + timesteps[:, t - self.context_len + 1:t + 1], states[:, t - self.context_len + 1:t + 1], + actions[:, t - self.context_len + 1:t + 1], rewards_to_go[:, t - self.context_len + 1:t + 1] + ) + act = act_preds[0, -1].detach() + + # if discrete + if not self._cfg.model.continuous: + act = torch.argmax(act) + running_state, running_reward, done, _ = self._env.step(act.cpu().numpy()) + + # add action in placeholder + actions[0, t] = act + + total_reward += running_reward + + if render: + self._env.render() + if done: + break + + results['eval/avg_reward'] = total_reward / self.num_eval_ep + results['eval/avg_ep_len'] = total_timesteps / self.num_eval_ep + + return results + + def evaluate(self, total_update_times, state_mean=None, state_std=None, render=False): + results = self.evaluate_on_env(state_mean, state_std, render) + + eval_avg_reward = results['eval/avg_reward'] + eval_avg_ep_len = results['eval/avg_ep_len'] + eval_d4rl_score = self.get_d4rl_normalized_score(results['eval/avg_reward'], self.env_name) * 100 + + time_elapsed = str(datetime.now().replace(microsecond=0) - self.start_time) + + log_str = ( + "=" * 60 + '\n' + "time elapsed: " + time_elapsed + '\n' + "num of updates: " + str(total_update_times) + + '\n' + '\n' + "eval avg reward: " + format(eval_avg_reward, ".5f") + '\n' + "eval avg ep len: " + + format(eval_avg_ep_len, ".5f") + '\n' + "eval d4rl score: " + format(eval_d4rl_score, ".5f") + ) + + logging.info(log_str) + + log_data = [time_elapsed, total_update_times, eval_avg_reward, eval_avg_ep_len, eval_d4rl_score] + log_csv_name = self.prefix + "_log_" + self.start_time_str + ".csv" + log_csv_path = os.path.join(self.log_dir, log_csv_name) + + self.csv_writer.writerow(log_data) + + # save model + logging.info("eval_avg_reward: " + format(eval_avg_reward, ".5f")) + eval_env_score = eval_avg_reward + if eval_env_score >= self.max_env_score: + logging.info("saving max env score model at: " + self.save_best_model_path) + torch.save(self._learn_model.state_dict(), self.save_best_model_path) + self.max_env_score = eval_env_score + + logging.info("saving current model at: " + self.save_model_path) + torch.save(self._learn_model.state_dict(), self.save_model_path) + + return self.max_env_score >= self.stop_value + + def get_d4rl_normalized_score(self, score, env_name): + # env_key = env_name.split('-')[0].lower() + # assert env_key in D4RLTrajectoryDataset.REF_MAX_SCORE, \ + # f'no reference score for {env_key} env to calculate d4rl score' + # d4rl_max_score, d4rl_min_score = D4RLTrajectoryDataset.REF_MAX_SCORE, D4RLTrajectoryDataset.REF_MIN_SCORE + # return (score - d4rl_min_score[env_key]) / (d4rl_max_score[env_key] - d4rl_min_score[env_key]) + return 0 + + def _state_dict_learn(self) -> Dict[str, Any]: + return { + 'model': self._learn_model.state_dict(), + # 'target_model': self._target_model.state_dict(), + 'optimizer': self._optimizer.state_dict(), + } + + def 
_load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + self._learn_model.load_state_dict(state_dict['model']) + # self._target_model.load_state_dict(state_dict['target_model']) + self._optimizer.load_state_dict(state_dict['optimizer']) + + def _monitor_vars_learn(self) -> List[str]: + return ['cur_lr', 'action_loss'] + + def _init_eval(self) -> None: + self._env = gym.make(self.env_name) + + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + pass + + + def _init_collect(self) -> None: + pass + + def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: + pass + + def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + pass + + def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: namedtuple) -> Dict[str, Any]: + pass \ No newline at end of file diff --git a/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py b/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py new file mode 100644 index 0000000000..e69de29bb2 From 737b7b6fd468968a3d137504c4930659bcd41a67 Mon Sep 17 00:00:00 2001 From: luyudong Date: Tue, 18 Jul 2023 15:25:33 +0800 Subject: [PATCH 03/25] Add DT in new pipeline --- ding/envs/env_wrappers/env_wrappers.py | 17 ++ ding/example/dt.py | 17 +- ding/policy/__init__.py | 1 + ding/policy/command_mode_policy_instance.py | 3 +- ding/policy/dt.py | 266 ++++++------------ ding/utils/data/dataset.py | 7 +- .../config/lunarlander_dt_config.py | 73 +++++ 7 files changed, 195 insertions(+), 189 deletions(-) diff --git a/ding/envs/env_wrappers/env_wrappers.py b/ding/envs/env_wrappers/env_wrappers.py index 1a75b88179..9a60cb95d5 100644 --- a/ding/envs/env_wrappers/env_wrappers.py +++ b/ding/envs/env_wrappers/env_wrappers.py @@ -1174,6 +1174,23 @@ def reset(self): return self.env.reset() +class AllinObsWrapper(gym.Wrapper): + + def __init__(self, env): + super().__init__(env) + + def reset(self): + return {'obs':self.env.reset(), 'reward': [0]} + + def step(self, action): + obs, reward, done, info = self.env.step(action) + obs = {'obs':obs, 'reward': reward} + from ding.envs import BaseEnvTimestep + return BaseEnvTimestep(obs, reward, done, info) + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + self.env.seed(seed, dynamic_seed) + def update_shape(obs_shape, act_shape, rew_shape, wrapper_names): """ Overview: diff --git a/ding/example/dt.py b/ding/example/dt.py index 5af78dabd3..1007ca741e 100644 --- a/ding/example/dt.py +++ b/ding/example/dt.py @@ -1,16 +1,17 @@ import gym from ditk import logging -from ding.model import QAC -from ding.policy import CQLPolicy -from ding.envs import DingEnvWrapper, BaseEnvManagerV2 +from ding.model.template.decision_transformer import DecisionTransformer +from ding.policy import DTPolicy +from ding.envs import DingEnvWrapper, BaseEnvManager, BaseEnvManagerV2 +from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper from ding.data import create_dataset from ding.config import compile_config from ding.framework import task, ding_init from ding.framework.context import OfflineRLContext from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger from ding.utils import set_pkg_seed -from dizoo.classic_control.pendulum.envs.pendulum_env import PendulumEnv -from dizoo.classic_control.pendulum.config.pendulum_cql_config import main_config, create_config +from dizoo.box2d.lunarlander.envs.lunarlander_env import LunarLanderEnv +from 
dizoo.box2d.lunarlander.config.lunarlander_dt_config import main_config, create_config def main(): @@ -21,14 +22,14 @@ def main(): ding_init(cfg) with task.start(async_mode=False, ctx=OfflineRLContext()): evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + env_fn=[lambda: AllinObsWrapper(LunarLanderEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) dataset = create_dataset(cfg) - model = QAC(**cfg.policy.model) - policy = CQLPolicy(cfg.policy, model=model) + model = DecisionTransformer(**cfg.policy.model) + policy = DTPolicy(cfg.policy, model=model) task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) task.use(offline_data_fetcher(cfg, dataset)) diff --git a/ding/policy/__init__.py b/ding/policy/__init__.py index 65f3f2757e..9789de56f6 100755 --- a/ding/policy/__init__.py +++ b/ding/policy/__init__.py @@ -12,6 +12,7 @@ from .td3 import TD3Policy from .td3_vae import TD3VAEPolicy from .td3_bc import TD3BCPolicy +from .dt import DTPolicy from .pg import PGPolicy from .a2c import A2CPolicy diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py index 8b6123c063..05b2423082 100755 --- a/ding/policy/command_mode_policy_instance.py +++ b/ding/policy/command_mode_policy_instance.py @@ -42,7 +42,8 @@ from .d4pg import D4PGPolicy from .cql import CQLPolicy, CQLDiscretePolicy -from .decision_transformer import DTPolicy +# from .decision_transformer import DTPolicy +from .dt import DTPolicy from .pdqn import PDQNPolicy from .sac import SQILSACPolicy from .madqn import MADQNPolicy diff --git a/ding/policy/dt.py b/ding/policy/dt.py index 28c4eeb2f6..8920fe8514 100644 --- a/ding/policy/dt.py +++ b/ding/policy/dt.py @@ -109,42 +109,9 @@ def _init_learn(self) -> None: # # load data from this file # dataset_path = f'{self._cfg.dataset_dir}/{env_d4rl_name}.pkl' - # saves model and csv in this directory - self.log_dir = self._cfg.log_dir - if not os.path.exists(self.log_dir): - os.makedirs(self.log_dir) - # training and evaluation device self.device = torch.device(self._device) - self.start_time = datetime.now().replace(microsecond=0) - self.start_time_str = self.start_time.strftime("%y-%m-%d-%H-%M-%S") - - # prefix = "dt_" + env_d4rl_name - self.prefix = "dt_" + self.env_name - - save_model_name = self.prefix + "_model_" + self.start_time_str + ".pt" - self.save_model_path = os.path.join(self.log_dir, save_model_name) - self.save_best_model_path = self.save_model_path[:-3] + "_best.pt" - - log_csv_name = self.prefix + "_log_" + self.start_time_str + ".csv" - log_csv_path = os.path.join(self.log_dir, log_csv_name) - - self.csv_writer = csv.writer(open(log_csv_path, 'a', 1)) - csv_header = (["duration", "num_updates", "eval_avg_reward", "eval_avg_ep_len", "eval_d4rl_score"]) - - self.csv_writer.writerow(csv_header) - - dataset_path = self._cfg.learn.dataset_path - logging.info("=" * 60) - logging.info("start time: " + self.start_time_str) - logging.info("=" * 60) - - logging.info("device set to: " + str(self.device)) - logging.info("dataset path: " + dataset_path) - logging.info("model save path: " + self.save_model_path) - logging.info("log csv save path: " + log_csv_path) - self.state_dim = self._cfg.model.state_dim self.act_dim = self._cfg.model.act_dim @@ -169,13 +136,14 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: self._learn_model.train() + data = [[i[j] for i in 
data] for j in range(len(data[0]))] timesteps, states, actions, returns_to_go, traj_mask = data - timesteps = timesteps.to(self.device) # B x T - states = states.to(self.device) # B x T x state_dim - actions = actions.to(self.device) # B x T x act_dim - returns_to_go = returns_to_go.to(self.device) # B x T x 1 - traj_mask = traj_mask.to(self.device) # B x T + timesteps = torch.stack(timesteps).to(self.device) # B x T + states = torch.stack(states).to(self.device) # B x T x state_dim + actions = torch.stack(actions).to(self.device) # B x T x act_dim + returns_to_go = torch.stack(returns_to_go).to(self.device) # B x T x 1 + traj_mask = torch.stack(traj_mask).to(self.device) # B x T action_target = torch.clone(actions).detach().to(self.device) # The shape of `returns_to_go` may differ with different dataset (B x T or B x T x 1), @@ -215,143 +183,94 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: return { 'cur_lr': self._optimizer.state_dict()['param_groups'][0]['lr'], 'action_loss': action_loss.detach().cpu().item(), + 'total_loss': action_loss.detach().cpu().item(), } - def evaluate_on_env(self, state_mean=None, state_std=None, render=False): - - eval_batch_size = 1 # required for forward pass - - results = {} - total_reward = 0 - total_timesteps = 0 - - # state_dim = env.observation_space.shape[0] - # act_dim = env.action_space.shape[0] - - if state_mean is None: - self.state_mean = torch.zeros((self.state_dim, )).to(self.device) - else: - self.state_mean = torch.from_numpy(state_mean).to(self.device) - - if state_std is None: - self.state_std = torch.ones((self.state_dim, )).to(self.device) - else: - self.state_std = torch.from_numpy(state_std).to(self.device) - - # same as timesteps used for training the transformer - # also, crashes if device is passed to arange() - timesteps = torch.arange(start=0, end=self.max_eval_ep_len, step=1) - timesteps = timesteps.repeat(eval_batch_size, 1).to(self.device) + def _init_eval(self) -> None: + r""" + Overview: + Evaluate mode init method. Called by ``self.__init__``, initialize eval_model. 
+ """ + self._eval_model = self._model + # self._eval_model.reset() + # init data + self.device = torch.device(self._device) + self.state_dim = self._cfg.model.state_dim + self.act_dim = self._cfg.model.act_dim + self.eval_batch_size = self._cfg.evaluator_env_num + self.max_eval_ep_len = self._cfg.max_eval_ep_len + self.context_len = self._cfg.context_len # K in decision transformer + self.rtg_scale = self._cfg.rtg_target # normalize returns to go + self.rtg_target = self._cfg.rtg_target # max target reward_to_go - self._learn_model.eval() + self.running_rtg = [self.rtg_target / self.rtg_scale] * self.eval_batch_size + self.t = [0] * self.eval_batch_size + self.timesteps = torch.arange(start=0, end=self.max_eval_ep_len, step=1).repeat(self.eval_batch_size, 1).to(self.device) + #self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device) + self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self.device) + self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device) + self.rewards_to_go = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self.device) + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + # save and forward + data_id = list(data.keys()) + + self._eval_model.eval() with torch.no_grad(): - - for _ in range(self.num_eval_ep): - - # zeros place holders - # continuous action - actions = torch.zeros( - (eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device - ) - - # discrete action # TODO - # actions = torch.randint(0,self.act_dim,[eval_batch_size, self.max_eval_ep_len, 1], - # dtype=torch.long, device=self.device) - - states = torch.zeros( - (eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device - ) - rewards_to_go = torch.zeros( - (eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self.device - ) - - # init episode - running_state = self._env.reset() - running_reward = 0 - running_rtg = self.rtg_target / self.rtg_scale - - for t in range(self.max_eval_ep_len): - - total_timesteps += 1 - - # add state in placeholder and normalize - states[0, t] = torch.from_numpy(running_state).to(self.device) - # states[0, t] = (states[0, t].cpu() - self.state_mean.cpu().numpy()) / self.state_std.cpu().numpy() - states[0, t] = (states[0, t] - self.state_mean) / self.state_std - - # calcualate running rtg and add it in placeholder - running_rtg = running_rtg - (running_reward / self.rtg_scale) - rewards_to_go[0, t] = running_rtg - - if t < self.context_len: - _, act_preds, _ = self._learn_model.forward( - timesteps[:, :self.context_len], states[:, :self.context_len], - actions[:, :self.context_len], rewards_to_go[:, :self.context_len] - ) - act = act_preds[0, t].detach() - else: - _, act_preds, _ = self._learn_model.forward( - timesteps[:, t - self.context_len + 1:t + 1], states[:, t - self.context_len + 1:t + 1], - actions[:, t - self.context_len + 1:t + 1], rewards_to_go[:, t - self.context_len + 1:t + 1] - ) - act = act_preds[0, -1].detach() - - # if discrete - if not self._cfg.model.continuous: - act = torch.argmax(act) - running_state, running_reward, done, _ = self._env.step(act.cpu().numpy()) - - # add action in placeholder - actions[0, t] = act - - total_reward += running_reward - - if render: - self._env.render() - if done: - break - - 
results['eval/avg_reward'] = total_reward / self.num_eval_ep - results['eval/avg_ep_len'] = total_timesteps / self.num_eval_ep - - return results - - def evaluate(self, total_update_times, state_mean=None, state_std=None, render=False): - results = self.evaluate_on_env(state_mean, state_std, render) - - eval_avg_reward = results['eval/avg_reward'] - eval_avg_ep_len = results['eval/avg_ep_len'] - eval_d4rl_score = self.get_d4rl_normalized_score(results['eval/avg_reward'], self.env_name) * 100 - - time_elapsed = str(datetime.now().replace(microsecond=0) - self.start_time) - - log_str = ( - "=" * 60 + '\n' + "time elapsed: " + time_elapsed + '\n' + "num of updates: " + str(total_update_times) + - '\n' + '\n' + "eval avg reward: " + format(eval_avg_reward, ".5f") + '\n' + "eval avg ep len: " + - format(eval_avg_ep_len, ".5f") + '\n' + "eval d4rl score: " + format(eval_d4rl_score, ".5f") - ) - - logging.info(log_str) - - log_data = [time_elapsed, total_update_times, eval_avg_reward, eval_avg_ep_len, eval_d4rl_score] - log_csv_name = self.prefix + "_log_" + self.start_time_str + ".csv" - log_csv_path = os.path.join(self.log_dir, log_csv_name) - - self.csv_writer.writerow(log_data) - - # save model - logging.info("eval_avg_reward: " + format(eval_avg_reward, ".5f")) - eval_env_score = eval_avg_reward - if eval_env_score >= self.max_env_score: - logging.info("saving max env score model at: " + self.save_best_model_path) - torch.save(self._learn_model.state_dict(), self.save_best_model_path) - self.max_env_score = eval_env_score - - logging.info("saving current model at: " + self.save_model_path) - torch.save(self._learn_model.state_dict(), self.save_model_path) - - return self.max_env_score >= self.stop_value + timesteps = torch.zeros((self.eval_batch_size, self.context_len), dtype=torch.long, device=self.device) + actions = torch.zeros((self.eval_batch_size, self.context_len, 1), dtype=torch.long, device=self.device) + states = torch.zeros((self.eval_batch_size, self.context_len, self.state_dim), dtype=torch.float32, device=self.device) + rewards_to_go = torch.zeros((self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self.device) + for i in data_id: + self.states[i, self.t] = data[i]['obs'] + self.running_rtg[i] = self.running_rtg[i] - (data[i]['reward'] / self.rtg_scale) + self.rewards_to_go[i, self.t] = self.running_rtg[i] + + if self.t[i] < self.context_len: + timesteps[i] = self.timesteps[i, :self.context_len] + states[i] = self.states[i, :self.context_len] + actions[i] = self.actions[i, :self.context_len] + rewards_to_go[i] = self.rewards_to_go[i, :self.context_len] + else: + timesteps[i] = [self.timesteps[i, self.t[i] - self.context_len + 1:self.t[i] + 1]] + states[i] = [self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1]] + actions[i] = [self.actions[i, self.t[i] - self.context_len + 1:self.t[i] + 1]] + rewards_to_go[i] = [self.rewards_to_go[i, self.t[i] - self.context_len + 1:self.t[i] + 1]] + if not self._cfg.model.continuous: + actions = one_hot(actions.squeeze(-1), num=self.act_dim) + _, act_preds, _ = self._eval_model.forward(timesteps, states, actions, rewards_to_go) + del timesteps, states, actions, rewards_to_go + act = torch.zeros((self.eval_batch_size, self.act_dim), dtype=torch.long, device=self.device) + for i in data_id: + act[i] = act_preds[i, self.t[i]].detach() if self.t[i] < self.context_len else act_preds[i, -1].detach() + if not self._cfg.model.continuous: + act = torch.argmax(act, axis=1) + self.actions[:, self.t] = act.unsqueeze(1) + if 
self._cuda: + act = to_device(act, 'cpu') + output = {'action': act} + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + def _reset_eval(self, data_id: List[int] = None) -> None: + # clean data + if data_id is None: + self.running_rtg = [self.rtg_target / self.rtg_scale] * self.eval_batch_size + self.t = [0] * self.eval_batch_size + self.timesteps = torch.arange(start=0, end=self.max_eval_ep_len, step=1).repeat(self.eval_batch_size, 1).to(self.device) + #self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device) + self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self.device) + self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device) + self.rewards_to_go = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self.device) + else: + for i in data_id: + self.running_rtg[i] = self.rtg_target / self.rtg_scale + self.t[i] = 0 + self.timesteps[i] = torch.arange(start=0, end=self.max_eval_ep_len, step=1).to(self.device) + #self.actions[i] = torch.zeros((self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device) + self.actions[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.long, device=self.device) + self.states[i] = torch.zeros((self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device) + self.rewards_to_go[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.float32, device=self.device) def get_d4rl_normalized_score(self, score, env_name): # env_key = env_name.split('-')[0].lower() @@ -376,13 +295,6 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: def _monitor_vars_learn(self) -> List[str]: return ['cur_lr', 'action_loss'] - def _init_eval(self) -> None: - self._env = gym.make(self.env_name) - - def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: - pass - - def _init_collect(self) -> None: pass diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index 5d24b071d6..16c45f36a5 100644 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -328,9 +328,10 @@ class D4RLTrajectoryDataset(Dataset): }, } - def __init__(self, dataset_path: str, context_len: int, rtg_scale: float) -> None: - - self.context_len = context_len + def __init__(self, cfg: dict) -> None: + dataset_path = cfg.policy.collect.get('data_path', None) + rtg_scale = cfg.policy.rtg_scale + self.context_len = cfg.policy.context_len # load dataset with open(dataset_path, 'rb') as f: diff --git a/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py b/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py index e69de29bb2..d225b299ec 100644 --- a/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py +++ b/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py @@ -0,0 +1,73 @@ +from easydict import EasyDict +import torch +from copy import deepcopy + +lunarlander_dt_config = dict( + exp_name='data_dt/lunarlander_dt_1000eps_rtgt300_meel1000_seed0_debug', + env=dict( + env_id='LunarLander-v2', + collector_env_num=8, + evaluator_env_num=8, + n_evaluator_episode=8, + stop_value=200, + ), + policy=dict( + stop_value=200, + device='cuda', + env_name='LunarLander-v2', + rtg_target=300, # max target reward_to_go + max_eval_ep_len=1000, # max len of one episode # TODO + num_eval_ep=10, # num of evaluation episodes + wt_decay=1e-4, + warmup_steps=10000, + 
num_updates_per_iter=100, + context_len=20, # TODO + evaluator_env_num=8, + n_blocks=3, + embed_dim=128, + n_heads=1, + dropout_p=0.1, + model=dict( + state_dim=8, + act_dim=4, + n_blocks=3, + h_dim=128, + context_len=20, + n_heads=1, + drop_p=0.1, + continuous=False, # TODO + ), + discount_factor=0.999, + nstep=3, + learn=dict( + learning_rate=1e-4, + batch_size=64, # training batch size + target_update_freq=100, + kappa=1.0, + min_q_weight=4.0, + ), + collect=dict( + data_type='d4rl_trajectory', + data_path='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl', + unroll_len=1, + ), + eval=dict(evaluator=dict(eval_freq=100, )), + ), +) +lunarlander_dt_config = EasyDict(lunarlander_dt_config) +main_config = lunarlander_dt_config +lunarlander_dt_create_config = dict( + env=dict( + type='lunarlander', + import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='dt'), +) +lunarlander_dt_create_config = EasyDict(lunarlander_dt_create_config) +create_config = lunarlander_dt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_dt, collect_demo_data, eval, serial_pipeline + config = deepcopy([main_config, create_config]) + serial_pipeline_dt(config, seed=0, max_train_iter=1000) From fce01fc666d26e9c2175d44bbd15bc25e66c87c7 Mon Sep 17 00:00:00 2001 From: luyudong Date: Tue, 25 Jul 2023 17:09:23 +0800 Subject: [PATCH 04/25] Add img input to support atari --- ding/example/dt.py | 4 +- ding/example/dt_atari.py | 49 ++ ding/example/dt_mujoco.py | 49 ++ ding/framework/context.py | 1 + .../framework/middleware/functional/logger.py | 7 +- ding/model/template/decision_transformer.py | 24 +- ding/model/template/dt.py | 229 +++++++++ ding/policy/dt.py | 162 ++++--- ding/torch_utils/network/transformer.py | 50 ++ ding/utils/data/dataset.py | 457 ++++++++++++++---- dizoo/atari/config/pong_dt_config.py | 103 ++++ .../config/lunarlander_dt_config.py | 8 +- ...t_LunarLander-v2_log_23-07-13-08-50-45.csv | 1 - ...t_LunarLander-v2_log_23-07-13-08-57-28.csv | 2 - ...t_LunarLander-v2_log_23-07-13-09-09-38.csv | 16 - dizoo/d4rl/config/__init__.py | 6 +- dizoo/d4rl/config/hopper_expert_dt_config.py | 15 +- dizoo/d4rl/config/hopper_medium_dt_config.py | 28 +- .../config/hopper_medium_expert_dt_config.py | 15 +- 19 files changed, 1031 insertions(+), 195 deletions(-) create mode 100644 ding/example/dt_atari.py create mode 100644 ding/example/dt_mujoco.py create mode 100644 ding/model/template/dt.py create mode 100644 dizoo/atari/config/pong_dt_config.py delete mode 100644 dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-08-50-45.csv delete mode 100644 dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-08-57-28.csv delete mode 100644 dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-09-09-38.csv diff --git a/ding/example/dt.py b/ding/example/dt.py index 1007ca741e..30969ce5da 100644 --- a/ding/example/dt.py +++ b/ding/example/dt.py @@ -8,7 +8,7 @@ from ding.config import compile_config from ding.framework import task, ding_init from ding.framework.context import OfflineRLContext -from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger +from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger, termination_checker, final_ctx_saver from ding.utils import set_pkg_seed from dizoo.box2d.lunarlander.envs.lunarlander_env 
import LunarLanderEnv from dizoo.box2d.lunarlander.config.lunarlander_dt_config import main_config, create_config @@ -28,12 +28,14 @@ def main(): set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) dataset = create_dataset(cfg) + cfg.policy.state_mean, cfg.policy.state_std = dataset.get_state_stats() model = DecisionTransformer(**cfg.policy.model) policy = DTPolicy(cfg.policy, model=model) task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) task.use(offline_data_fetcher(cfg, dataset)) task.use(trainer(cfg, policy.learn_mode)) + task.use(termination_checker(max_train_iter=1e5)) task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.use(offline_logger()) task.run() diff --git a/ding/example/dt_atari.py b/ding/example/dt_atari.py new file mode 100644 index 0000000000..88e66f9eb1 --- /dev/null +++ b/ding/example/dt_atari.py @@ -0,0 +1,49 @@ +import gym +import torch +import numpy as np +from ditk import logging +from ding.model.template.dt import DecisionTransformer, DecisionTransformerA +from ding.policy import DTPolicy +from ding.envs import BaseEnvManagerV2 +from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper +from ding.data import create_dataset +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OfflineRLContext +from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger, termination_checker +from ding.utils import set_pkg_seed +from dizoo.atari.envs import AtariEnv +from dizoo.atari.config.pong_dt_config import main_config, create_config + + +def main(): + # If you don't have offline data, you need to prepare if first and set the data_path in config + # For demostration, we also can train a RL policy (e.g. 
SAC) and collect some data + logging.getLogger().setLevel(logging.INFO) + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + # ding_init(cfg) + with task.start(async_mode=False, ctx=OfflineRLContext()): + evaluator_env = BaseEnvManagerV2( + env_fn=[lambda: AllinObsWrapper(AtariEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + dataset = create_dataset(cfg) + cfg.policy.model.max_timestep = dataset.get_max_timestep() + # model = DecisionTransformer(**cfg.policy.model) + model = DecisionTransformerA(cfg.policy.model) + policy = DTPolicy(cfg.policy, model=model) + + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(offline_data_fetcher(cfg, dataset)) + task.use(trainer(cfg, policy.learn_mode)) + task.use(termination_checker(max_train_iter=1e5)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) + task.use(offline_logger(cfg.exp_name)) + task.run() + + +if __name__ == "__main__": + main() + diff --git a/ding/example/dt_mujoco.py b/ding/example/dt_mujoco.py new file mode 100644 index 0000000000..9f176b353b --- /dev/null +++ b/ding/example/dt_mujoco.py @@ -0,0 +1,49 @@ +import gym +import torch +import numpy as np +from ditk import logging +from ding.model.template.decision_transformer import DecisionTransformer +from ding.policy import DTPolicy +from ding.envs import BaseEnvManagerV2 +from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper +from ding.data import create_dataset +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OfflineRLContext +from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger, termination_checker +from ding.utils import set_pkg_seed +from dizoo.d4rl.envs import D4RLEnv +from dizoo.d4rl.config.hopper_medium_dt_config import main_config, create_config + + +def main(): + # If you don't have offline data, you need to prepare if first and set the data_path in config + # For demostration, we also can train a RL policy (e.g. 
SAC) and collect some data + logging.getLogger().setLevel(logging.INFO) + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + # ding_init(cfg) + with task.start(async_mode=False, ctx=OfflineRLContext()): + evaluator_env = BaseEnvManagerV2( + env_fn=[lambda: AllinObsWrapper(D4RLEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + dataset = create_dataset(cfg) + # env_data_stats = dataset.get_d4rl_dataset_stats(cfg.policy.dataset_name) + env_data_stats = dataset.get_state_stats() + cfg.policy.state_mean, cfg.policy.state_std = np.array(env_data_stats['state_mean']), np.array(env_data_stats['state_std']) + model = DecisionTransformer(**cfg.policy.model) + policy = DTPolicy(cfg.policy, model=model) + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(offline_data_fetcher(cfg, dataset)) + task.use(trainer(cfg, policy.learn_mode)) + task.use(termination_checker(max_train_iter=1e5)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) + task.use(offline_logger(cfg.exp_name)) + task.run() + + +if __name__ == "__main__": + main() + diff --git a/ding/framework/context.py b/ding/framework/context.py index 886d4933f6..8ef8f764fd 100644 --- a/ding/framework/context.py +++ b/ding/framework/context.py @@ -82,6 +82,7 @@ class OfflineRLContext(Context): # common total_step: int = 0 + env_step: int = 0 train_epoch: int = 0 train_iter: int = 0 train_data: Union[Dict, List] = None diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index add0e26492..dfecfb7d22 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -11,6 +11,7 @@ import wandb import pickle import treetensor.numpy as tnp +from tensorboardX import SummaryWriter from ding.framework import task from ding.envs import BaseEnvManagerV2 from ding.utils import DistributedWriter @@ -92,10 +93,12 @@ def _logger(ctx: "OnlineRLContext"): return _logger -def offline_logger() -> Callable: +def offline_logger( + exp_name: str = None +) -> Callable: if task.router.is_active and not task.has_role(task.role.LEARNER): return task.void() - writer = DistributedWriter.get_instance() + writer = SummaryWriter(logdir = exp_name) def _logger(ctx: "OfflineRLContext"): if task.finish: diff --git a/ding/model/template/decision_transformer.py b/ding/model/template/decision_transformer.py index ef86657f86..7efbf79890 100644 --- a/ding/model/template/decision_transformer.py +++ b/ding/model/template/decision_transformer.py @@ -4,11 +4,33 @@ from ding.utils import MODEL_REGISTRY from typing import Tuple -from ding.torch_utils.network.transformer import Attention +from ding.torch_utils.network.transformer import Attention, MaskedCausalAttention import torch import torch.nn as nn +class BlockM(nn.Module): + def __init__(self, h_dim, max_T, n_heads, drop_p): + super().__init__() + self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p) + self.mlp = nn.Sequential( + nn.Linear(h_dim, 4*h_dim), + nn.GELU(), + nn.Linear(4*h_dim, h_dim), + nn.Dropout(drop_p), + ) + self.ln1 = nn.LayerNorm(h_dim) + self.ln2 = nn.LayerNorm(h_dim) + + def forward(self, x): + # Attention -> LayerNorm -> MLP -> LayerNorm + x = x + self.attention(x) # residual + x = self.ln1(x) + x = x + self.mlp(x) # residual + x = self.ln2(x) + return x + + class Block(nn.Module): def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) 
-> None: diff --git a/ding/model/template/dt.py b/ding/model/template/dt.py new file mode 100644 index 0000000000..d6d39e74fd --- /dev/null +++ b/ding/model/template/dt.py @@ -0,0 +1,229 @@ +""" +this extremely minimal Decision Transformer model is based on +the following causal transformer (GPT) implementation: + +Misha Laskin's tweet: +https://twitter.com/MishaLaskin/status/1481767788775628801?cxt=HHwWgoCzmYD9pZApAAAA + +and its corresponding notebook: +https://colab.research.google.com/drive/1NUBqyboDcGte5qAJKOl8gaJC28V_73Iv?usp=sharing + +** the above colab notebook has a bug while applying masked_fill +which is fixed in the following code +""" + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MaskedCausalAttention(nn.Module): + def __init__(self, h_dim, max_T, n_heads, drop_p): + super().__init__() + + self.n_heads = n_heads + self.max_T = max_T + + self.q_net = nn.Linear(h_dim, h_dim) + self.k_net = nn.Linear(h_dim, h_dim) + self.v_net = nn.Linear(h_dim, h_dim) + + self.proj_net = nn.Linear(h_dim, h_dim) + + self.att_drop = nn.Dropout(drop_p) + self.proj_drop = nn.Dropout(drop_p) + + ones = torch.ones((max_T, max_T)) + mask = torch.tril(ones).view(1, 1, max_T, max_T) + + # register buffer makes sure mask does not get updated + # during backpropagation + self.register_buffer('mask',mask) + + def forward(self, x): + B, T, C = x.shape # batch size, seq length, h_dim * n_heads + + N, D = self.n_heads, C // self.n_heads # N = num heads, D = attention dim + + # rearrange q, k, v as (B, N, T, D) + q = self.q_net(x).view(B, T, N, D).transpose(1,2) + k = self.k_net(x).view(B, T, N, D).transpose(1,2) + v = self.v_net(x).view(B, T, N, D).transpose(1,2) + + # weights (B, N, T, T) + weights = q @ k.transpose(2,3) / math.sqrt(D) + # causal mask applied to weights + weights = weights.masked_fill(self.mask[...,:T,:T] == 0, float('-inf')) + # normalize weights, all -inf -> 0 after softmax + normalized_weights = F.softmax(weights, dim=-1) + + # attention (B, N, T, D) + attention = self.att_drop(normalized_weights @ v) + + # gather heads and project (B, N, T, D) -> (B, T, N*D) + attention = attention.transpose(1, 2).contiguous().view(B,T,N*D) + + out = self.proj_drop(self.proj_net(attention)) + return out + + +class Block(nn.Module): + def __init__(self, h_dim, max_T, n_heads, drop_p): + super().__init__() + self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p) + self.mlp = nn.Sequential( + nn.Linear(h_dim, 4*h_dim), + nn.GELU(), + nn.Linear(4*h_dim, h_dim), + nn.Dropout(drop_p), + ) + self.ln1 = nn.LayerNorm(h_dim) + self.ln2 = nn.LayerNorm(h_dim) + + def forward(self, x): + # Attention -> LayerNorm -> MLP -> LayerNorm + x = x + self.attention(x) # residual + x = self.ln1(x) + x = x + self.mlp(x) # residual + x = self.ln2(x) + return x + + +class DecisionTransformer(nn.Module): + def __init__(self, state_dim, act_dim, n_blocks, h_dim, context_len, + n_heads, drop_p, max_timestep=4096): + super().__init__() + + self.state_dim = state_dim + self.act_dim = act_dim + self.h_dim = h_dim + + ### transformer blocks + input_seq_len = 3 * context_len + blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)] + self.transformer = nn.Sequential(*blocks) + + ### projection heads (project to embedding) + self.embed_ln = nn.LayerNorm(h_dim) + self.embed_timestep = nn.Embedding(max_timestep, h_dim) + self.embed_rtg = torch.nn.Linear(1, h_dim) + self.embed_state = torch.nn.Linear(state_dim, h_dim) + + # # discrete actions + # 
self.embed_action = torch.nn.Embedding(act_dim, h_dim) + # use_action_tanh = False # False for discrete actions + + # continuous actions + self.embed_action = torch.nn.Linear(act_dim, h_dim) + use_action_tanh = True # True for continuous actions + + ### prediction heads + self.predict_rtg = torch.nn.Linear(h_dim, 1) + self.predict_state = torch.nn.Linear(h_dim, state_dim) + self.predict_action = nn.Sequential( + *([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else [])) + ) + + + def forward(self, timesteps, states, actions, returns_to_go): + + B, T, _ = states.shape + + time_embeddings = self.embed_timestep(timesteps) + + # time embeddings are treated similar to positional embeddings + state_embeddings = self.embed_state(states) + time_embeddings + action_embeddings = self.embed_action(actions) + time_embeddings + returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings + + # stack rtg, states and actions and reshape sequence as + # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2 ...) + h = torch.stack( + (returns_embeddings, state_embeddings, action_embeddings), dim=1 + ).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim) + + h = self.embed_ln(h) + + # transformer and prediction + h = self.transformer(h) + + # get h reshaped such that its size = (B x 3 x T x h_dim) and + # h[:, 0, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t + # h[:, 1, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t + # h[:, 2, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t, a_t + # that is, for each timestep (t) we have 3 output embeddings from the transformer, + # each conditioned on all previous timesteps plus + # the 3 input variables at that timestep (r_t, s_t, a_t) in sequence. + h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3) + + # get predictions + return_preds = self.predict_rtg(h[:,2]) # predict next rtg given r, s, a + state_preds = self.predict_state(h[:,2]) # predict next state given r, s, a + action_preds = self.predict_action(h[:,1]) # predict action given r, s + + return state_preds, action_preds, return_preds + + +class DecisionTransformerA(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.n_embd = config.n_embd + self.block_size = config.context_len * 3 + h_dim = config.h_dim + + # input embedding stem + self.tok_emb = nn.Embedding(config.act_dim, config.n_embd) + # self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) + self.pos_emb = nn.Parameter(torch.zeros(1, self.block_size + 1, config.n_embd)) + self.global_pos_emb = nn.Parameter(torch.zeros(1, config.max_timestep+1, config.n_embd)) + self.drop = nn.Dropout(config.embd_pdrop) + + # transformer + self.blocks = nn.Sequential(*[Block(h_dim, self.block_size, config.n_heads, config.drop_p) for _ in range(config.n_layer)]) + # decoder head + self.ln_f = nn.LayerNorm(config.n_embd) + self.head = nn.Linear(config.n_embd, config.act_dim, bias=False) + + self.state_encoder = nn.Sequential(nn.Conv2d(4, 32, 8, stride=4, padding=0), nn.ReLU(), + nn.Conv2d(32, 64, 4, stride=2, padding=0), nn.ReLU(), + nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU(), + nn.Flatten(), nn.Linear(3136, config.n_embd), nn.Tanh()) + + self.ret_emb = nn.Sequential(nn.Linear(1, config.n_embd), nn.Tanh()) + + self.action_embeddings = nn.Sequential(nn.Embedding(config.act_dim, config.n_embd), nn.Tanh()) + + + # state, action, and return + def forward(self, timesteps, states, actions, returns_to_go): + # states: (batch, block_size, 
4*84*84) + # actions: (batch, block_size, 1) + # rtgs: (batch, block_size, 1) + # timesteps: (batch, 1, 1) + rtgs = returns_to_go + state_embeddings = self.state_encoder(states.reshape(-1, 4, 84, 84).type(torch.float32).contiguous()) # (batch * block_size, n_embd) + state_embeddings = state_embeddings.reshape(states.shape[0], states.shape[1], self.config.n_embd) # (batch, block_size, n_embd) + + rtg_embeddings = self.ret_emb(rtgs.type(torch.float32)) + action_embeddings = self.action_embeddings(actions.type(torch.long).squeeze(-1)) # (batch, block_size, n_embd) + + token_embeddings = torch.zeros((states.shape[0], states.shape[1]*3 - 1, self.config.n_embd), dtype=torch.float32, device=state_embeddings.device) + token_embeddings[:,::3,:] = rtg_embeddings + token_embeddings[:,1::3,:] = state_embeddings + token_embeddings[:,2::3,:] = action_embeddings[:,-states.shape[1] + 1:,:] + + batch_size = states.shape[0] + all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, batch_size, dim=0) # batch_size, traj_length, n_embd + + position_embeddings = torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.config.n_embd, dim=-1)) + self.pos_emb[:, :token_embeddings.shape[1], :] + + x = self.drop(token_embeddings + position_embeddings) + x = self.blocks(x) + x = self.ln_f(x) + logits = self.head(x) + + logits = logits[:, 1::3, :] # only keep predictions from state_embeddings + + return None, logits, None \ No newline at end of file diff --git a/ding/policy/dt.py b/ding/policy/dt.py index 8920fe8514..2944e9b93f 100644 --- a/ding/policy/dt.py +++ b/ding/policy/dt.py @@ -58,8 +58,10 @@ class DTPolicy(Policy): n_blocks=3, embed_dim=128, dropout_p=0.1, + log_dir='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps', learn=dict( + dataset_path='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl', # TODO # batch_size=64, learning_rate=1e-4, # ============================================================== @@ -82,15 +84,12 @@ def _init_learn(self) -> None: Learn mode init method. Called by ``self.__init__``. Init the optimizer, algorithm config, main and target models. """ - - self.stop_value = self._cfg.stop_value self.env_name = self._cfg.env_name - dataset = self._cfg.dataset # medium / medium-replay / medium-expert # rtg_scale: scale of `return to go` # rtg_target: max target of `return to go` # Our goal is normalize `return to go` to (0, 1), which will favour the covergence. # As a result, we usually set rtg_scale == rtg_target. 
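For intuition about the rtg_scale / rtg_target pair discussed above: the returns-to-go fed to the model are plain reward suffix sums divided by rtg_scale, so setting rtg_scale == rtg_target makes the initial conditioning value start near 1 and decay toward 0 over an episode. A minimal sketch with toy numbers, assuming only the undiscounted discount_cumsum semantics used in ding/utils/data/dataset.py:

import numpy as np

def discount_cumsum(rewards: np.ndarray, gamma: float) -> np.ndarray:
    # suffix sum: rtg[t] = rewards[t] + gamma * rtg[t + 1]
    rtg = np.zeros_like(rewards, dtype=np.float64)
    rtg[-1] = rewards[-1]
    for t in reversed(range(len(rewards) - 1)):
        rtg[t] = rewards[t] + gamma * rtg[t + 1]
    return rtg

rewards = np.array([1.0, 0.0, 2.0, 1.0])            # toy trajectory
rtg_scale = 4.0                                     # pretend rtg_target == rtg_scale == 4
print(discount_cumsum(rewards, 1.0) / rtg_scale)    # [1.   0.75 0.75 0.25]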
- self.rtg_scale = self._cfg.rtg_target # normalize returns to go + self.rtg_scale = self._cfg.rtg_scale # normalize returns to go self.rtg_target = self._cfg.rtg_target # max target reward_to_go self.max_eval_ep_len = self._cfg.max_eval_ep_len # max len of one episode self.num_eval_ep = self._cfg.num_eval_ep # num of evaluation episodes @@ -99,12 +98,7 @@ def _init_learn(self) -> None: wt_decay = self._cfg.wt_decay # weight decay warmup_steps = self._cfg.warmup_steps # warmup steps for lr scheduler - max_train_iters = self._cfg.max_train_iters - self.context_len = self._cfg.context_len # K in decision transformer - n_blocks = self._cfg.n_blocks # num of transformer blocks - embed_dim = self._cfg.embed_dim # embedding (hidden) dim of transformer - dropout_p = self._cfg.dropout_p # dropout probability # # load data from this file # dataset_path = f'{self._cfg.dataset_dir}/{env_d4rl_name}.pkl' @@ -138,7 +132,6 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: data = [[i[j] for i in data] for j in range(len(data[0]))] timesteps, states, actions, returns_to_go, traj_mask = data - timesteps = torch.stack(timesteps).to(self.device) # B x T states = torch.stack(states).to(self.device) # B x T x state_dim actions = torch.stack(actions).to(self.device) # B x T x act_dim @@ -152,27 +145,27 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: returns_to_go = returns_to_go.unsqueeze(-1) # if discrete - if not self._cfg.model.continuous: + if not self._cfg.model.continuous and self.cfg.env_type != 'atari': actions = one_hot(actions.squeeze(-1), num=self.act_dim) state_preds, action_preds, return_preds = self._learn_model.forward( timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go ) - - traj_mask = traj_mask.view(-1, ) - - # only consider non padded elements - action_preds = action_preds.view(-1, self.act_dim)[traj_mask > 0] - - if self._cfg.model.continuous: - action_target = action_target.view(-1, self.act_dim)[traj_mask > 0] + + if self.cfg.env_type == 'atari': + action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1)) else: - action_target = action_target.view(-1)[traj_mask > 0] + traj_mask = traj_mask.view(-1, ) - if self._cfg.model.continuous: - action_loss = F.mse_loss(action_preds, action_target) - else: - action_loss = F.cross_entropy(action_preds, action_target) + # only consider non padded elements + action_preds = action_preds.view(-1, self.act_dim)[traj_mask > 0] + + if self._cfg.model.continuous: + action_target = action_target.view(-1, self.act_dim)[traj_mask > 0] + action_loss = F.mse_loss(action_preds, action_target) + else: + action_target = action_target.view(-1)[traj_mask > 0] + action_loss = F.cross_entropy(action_preds, action_target) self._optimizer.zero_grad() action_loss.backward() @@ -195,57 +188,84 @@ def _init_eval(self) -> None: # self._eval_model.reset() # init data self.device = torch.device(self._device) + self.rtg_scale = self._cfg.rtg_scale # normalize returns to go + self.rtg_target = self._cfg.rtg_target # max target reward_to_go self.state_dim = self._cfg.model.state_dim self.act_dim = self._cfg.model.act_dim self.eval_batch_size = self._cfg.evaluator_env_num self.max_eval_ep_len = self._cfg.max_eval_ep_len self.context_len = self._cfg.context_len # K in decision transformer - self.rtg_scale = self._cfg.rtg_target # normalize returns to go - self.rtg_target = self._cfg.rtg_target # max target reward_to_go - - self.running_rtg = [self.rtg_target / self.rtg_scale] * 
self.eval_batch_size - self.t = [0] * self.eval_batch_size + + self.t = [0 for _ in range(self.eval_batch_size)] + self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] self.timesteps = torch.arange(start=0, end=self.max_eval_ep_len, step=1).repeat(self.eval_batch_size, 1).to(self.device) - #self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device) - self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self.device) - self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device) + if not self._cfg.model.continuous: + self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self.device) + else: + self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device) + if self.cfg.env_type == 'atari': + self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len,) + tuple(self.state_dim), dtype=torch.float32, device=self.device) + else: + self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device) + self.state_mean = torch.from_numpy(self._cfg.state_mean).to(self.device) + self.state_std = torch.from_numpy(self._cfg.state_std).to(self.device) self.rewards_to_go = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self.device) def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: # save and forward data_id = list(data.keys()) + data_len = len(data_id) self._eval_model.eval() with torch.no_grad(): - timesteps = torch.zeros((self.eval_batch_size, self.context_len), dtype=torch.long, device=self.device) - actions = torch.zeros((self.eval_batch_size, self.context_len, 1), dtype=torch.long, device=self.device) - states = torch.zeros((self.eval_batch_size, self.context_len, self.state_dim), dtype=torch.float32, device=self.device) - rewards_to_go = torch.zeros((self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self.device) + timesteps = torch.zeros((data_len, 1, 1), dtype=torch.long, device=self.device) + if not self._cfg.model.continuous: + actions = torch.zeros((data_len, self.context_len, 1), dtype=torch.long, device=self.device) + else: + actions = torch.zeros((data_len, self.context_len, self.act_dim), dtype=torch.float32, device=self.device) + if self.cfg.env_type == 'atari': + states = torch.zeros((data_len, self.context_len,) + tuple(self.state_dim), dtype=torch.float32, device=self.device) + else: + states = torch.zeros((data_len, self.context_len, self.state_dim), dtype=torch.float32, device=self.device) + rewards_to_go = torch.zeros((data_len, self.context_len, 1), dtype=torch.float32, device=self.device) for i in data_id: - self.states[i, self.t] = data[i]['obs'] - self.running_rtg[i] = self.running_rtg[i] - (data[i]['reward'] / self.rtg_scale) - self.rewards_to_go[i, self.t] = self.running_rtg[i] + if self.cfg.env_type == 'atari': + self.states[i, self.t[i]] = data[i]['obs'].to(self.device) / 255 + else: + self.states[i, self.t[i]] = (data[i]['obs'].to(self.device) - self.state_mean) / self.state_std + # self.states[i, self.t[i]] = torch.tensor(data[i]['obs']) + self.running_rtg[i] = self.running_rtg[i] - (data[i]['reward'].to(self.device) / self.rtg_scale) + # self.running_rtg[i] = self.running_rtg[i] - 
(data[i]['reward'][0] / self.rtg_scale) + self.rewards_to_go[i, self.t[i]] = self.running_rtg[i] - if self.t[i] < self.context_len: - timesteps[i] = self.timesteps[i, :self.context_len] + if self.t[i] <= self.context_len: + if self.cfg.env_type == 'atari': + timesteps[i] = self.t[i] * torch.ones((1), dtype=torch.int64).to(self.device) + else: + timesteps[i] = self.timesteps[i, :self.context_len] states[i] = self.states[i, :self.context_len] actions[i] = self.actions[i, :self.context_len] rewards_to_go[i] = self.rewards_to_go[i, :self.context_len] else: - timesteps[i] = [self.timesteps[i, self.t[i] - self.context_len + 1:self.t[i] + 1]] - states[i] = [self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1]] - actions[i] = [self.actions[i, self.t[i] - self.context_len + 1:self.t[i] + 1]] - rewards_to_go[i] = [self.rewards_to_go[i, self.t[i] - self.context_len + 1:self.t[i] + 1]] - if not self._cfg.model.continuous: - actions = one_hot(actions.squeeze(-1), num=self.act_dim) + if self.cfg.env_type == 'atari': + timesteps[i] = self.t[i] * torch.ones((1), dtype=torch.int64).to(self.device) + else: + timesteps[i] = self.timesteps[i, self.t[i] - self.context_len + 1:self.t[i] + 1] + states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1] + actions[i] = self.actions[i, self.t[i] - self.context_len + 1:self.t[i] + 1] + rewards_to_go[i] = self.rewards_to_go[i, self.t[i] - self.context_len + 1:self.t[i] + 1] + # if not self._cfg.model.continuous: + # actions = one_hot(actions.squeeze(-1), num=self.act_dim) _, act_preds, _ = self._eval_model.forward(timesteps, states, actions, rewards_to_go) del timesteps, states, actions, rewards_to_go - act = torch.zeros((self.eval_batch_size, self.act_dim), dtype=torch.long, device=self.device) + act = torch.zeros((self.eval_batch_size, self.act_dim), dtype=torch.float32, device=self.device) for i in data_id: act[i] = act_preds[i, self.t[i]].detach() if self.t[i] < self.context_len else act_preds[i, -1].detach() if not self._cfg.model.continuous: - act = torch.argmax(act, axis=1) - self.actions[:, self.t] = act.unsqueeze(1) + act = torch.argmax(act, axis=1).unsqueeze(1) + for i in data_id: + self.actions[i, self.t[i]] = act[i] + self.t[i] += 1 if self._cuda: act = to_device(act, 'cpu') output = {'action': act} @@ -255,30 +275,39 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: def _reset_eval(self, data_id: List[int] = None) -> None: # clean data if data_id is None: - self.running_rtg = [self.rtg_target / self.rtg_scale] * self.eval_batch_size - self.t = [0] * self.eval_batch_size + self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] + self.t = [0 for _ in range(self.eval_batch_size)] self.timesteps = torch.arange(start=0, end=self.max_eval_ep_len, step=1).repeat(self.eval_batch_size, 1).to(self.device) - #self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device) - self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self.device) - self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device) + if not self._cfg.model.continuous: + self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self.device) + else: + self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device) + if self.cfg.env_type 
== 'atari': + self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len,) + tuple(self.state_dim), dtype=torch.float32, device=self.device) + else: + self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device) self.rewards_to_go = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self.device) else: for i in data_id: self.running_rtg[i] = self.rtg_target / self.rtg_scale self.t[i] = 0 self.timesteps[i] = torch.arange(start=0, end=self.max_eval_ep_len, step=1).to(self.device) - #self.actions[i] = torch.zeros((self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device) - self.actions[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.long, device=self.device) - self.states[i] = torch.zeros((self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device) + if not self._cfg.model.continuous: + self.actions[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.long, device=self.device) + else: + self.actions[i] = torch.zeros((self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device) + if self.cfg.env_type == 'atari': + self.states[i] = torch.zeros((self.max_eval_ep_len,) + tuple(self.state_dim), dtype=torch.float32, device=self.device) + else: + self.states[i] = torch.zeros((self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device) self.rewards_to_go[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.float32, device=self.device) def get_d4rl_normalized_score(self, score, env_name): - # env_key = env_name.split('-')[0].lower() - # assert env_key in D4RLTrajectoryDataset.REF_MAX_SCORE, \ - # f'no reference score for {env_key} env to calculate d4rl score' - # d4rl_max_score, d4rl_min_score = D4RLTrajectoryDataset.REF_MAX_SCORE, D4RLTrajectoryDataset.REF_MIN_SCORE - # return (score - d4rl_min_score[env_key]) / (d4rl_max_score[env_key] - d4rl_min_score[env_key]) - return 0 + env_key = env_name.split('-')[0].lower() + assert env_key in D4RLTrajectoryDataset.REF_MAX_SCORE, \ + f'no reference score for {env_key} env to calculate d4rl score' + d4rl_max_score, d4rl_min_score = D4RLTrajectoryDataset.REF_MAX_SCORE, D4RLTrajectoryDataset.REF_MIN_SCORE + return (score - d4rl_min_score[env_key]) / (d4rl_max_score[env_key] - d4rl_min_score[env_key]) def _state_dict_learn(self) -> Dict[str, Any]: return { @@ -291,6 +320,11 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: self._learn_model.load_state_dict(state_dict['model']) # self._target_model.load_state_dict(state_dict['target_model']) self._optimizer.load_state_dict(state_dict['optimizer']) + + def _load_state_dict_eval(self, state_dict: Dict[str, Any]) -> None: + self._eval_model.load_state_dict(state_dict) + # self._target_model.load_state_dict(state_dict['target_model']) + # self._optimizer.load_state_dict(state_dict['optimizer']) def _monitor_vars_learn(self) -> List[str]: return ['cur_lr', 'action_loss'] diff --git a/ding/torch_utils/network/transformer.py b/ding/torch_utils/network/transformer.py index e707134a3f..d93457ba9a 100644 --- a/ding/torch_utils/network/transformer.py +++ b/ding/torch_utils/network/transformer.py @@ -81,6 +81,56 @@ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch return attention +class MaskedCausalAttention(nn.Module): + def __init__(self, h_dim, max_T, n_heads, drop_p): + super().__init__() + + self.n_heads = n_heads + self.max_T = max_T + + self.q_net = 
nn.Linear(h_dim, h_dim) + self.k_net = nn.Linear(h_dim, h_dim) + self.v_net = nn.Linear(h_dim, h_dim) + + self.proj_net = nn.Linear(h_dim, h_dim) + + self.att_drop = nn.Dropout(drop_p) + self.proj_drop = nn.Dropout(drop_p) + + ones = torch.ones((max_T, max_T)) + mask = torch.tril(ones).view(1, 1, max_T, max_T) + + # register buffer makes sure mask does not get updated + # during backpropagation + self.register_buffer('mask',mask) + + def forward(self, x): + B, T, C = x.shape # batch size, seq length, h_dim * n_heads + + N, D = self.n_heads, C // self.n_heads # N = num heads, D = attention dim + + # rearrange q, k, v as (B, N, T, D) + q = self.q_net(x).view(B, T, N, D).transpose(1,2) + k = self.k_net(x).view(B, T, N, D).transpose(1,2) + v = self.v_net(x).view(B, T, N, D).transpose(1,2) + + # weights (B, N, T, T) + weights = q @ k.transpose(2,3) / math.sqrt(D) + # causal mask applied to weights + weights = weights.masked_fill(self.mask[...,:T,:T] == 0, float('-inf')) + # normalize weights, all -inf -> 0 after softmax + normalized_weights = F.softmax(weights, dim=-1) + + # attention (B, N, T, D) + attention = self.att_drop(normalized_weights @ v) + + # gather heads and project (B, N, T, D) -> (B, T, N*D) + attention = attention.transpose(1, 2).contiguous().view(B,T,N*D) + + out = self.proj_drop(self.proj_net(attention)) + return out + + class TransformerLayer(nn.Module): r""" Overview: diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index 16c45f36a5..208c386703 100644 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -329,47 +329,257 @@ class D4RLTrajectoryDataset(Dataset): } def __init__(self, cfg: dict) -> None: - dataset_path = cfg.policy.collect.get('data_path', None) + dataset_path = cfg.policy.collect.data_path rtg_scale = cfg.policy.rtg_scale self.context_len = cfg.policy.context_len - - # load dataset - with open(dataset_path, 'rb') as f: - self.trajectories = pickle.load(f) - - if isinstance(self.trajectories[0], list): - # for our collected dataset, e.g. 
cartpole/lunarlander case - trajectories_tmp = [] - - original_keys = ['obs', 'next_obs', 'action', 'reward'] - keys = ['observations', 'next_observations', 'actions', 'rewards'] - trajectories_tmp = [ - { - key: np.stack( - [ - self.trajectories[eps_index][transition_index][o_key] - for transition_index in range(len(self.trajectories[eps_index])) - ], - axis=0 - ) for key, o_key in zip(keys, original_keys) - } for eps_index in range(len(self.trajectories)) - ] - self.trajectories = trajectories_tmp - - states = [] - for traj in self.trajectories: - traj_len = traj['observations'].shape[0] - states.append(traj['observations']) - # calculate returns to go and rescale them - traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale - - # used for input normalization - states = np.concatenate(states, axis=0) - self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 - - # normalize states - for traj in self.trajectories: - traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std + self.env_type = cfg.policy.env_type + + if 'hdf5' in dataset_path: + try: + import h5py + import collections + except ImportError: + import sys + logging.warning("not found h5py package, please install it trough `pip install h5py ") + sys.exit(1) + dataset = h5py.File(dataset_path, 'r') + + N = dataset['rewards'].shape[0] + data_ = collections.defaultdict(list) + + use_timeouts = False + if 'timeouts' in dataset: + use_timeouts = True + + episode_step = 0 + paths = [] + for i in range(N): + done_bool = bool(dataset['terminals'][i]) + if use_timeouts: + final_timestep = dataset['timeouts'][i] + else: + final_timestep = (episode_step == 1000-1) + for k in ['observations', 'actions', 'rewards', 'terminals']: + data_[k].append(dataset[k][i]) + if done_bool or final_timestep: + episode_step = 0 + episode_data = {} + for k in data_: + episode_data[k] = np.array(data_[k]) + paths.append(episode_data) + data_ = collections.defaultdict(list) + episode_step += 1 + + self.trajectories = paths + + # calculate min len of traj, state mean and variance + # and returns_to_go for all traj + min_len = 10**6 + states = [] + for traj in self.trajectories: + traj_len = traj['observations'].shape[0] + min_len = min(min_len, traj_len) + states.append(traj['observations']) + # calculate returns to go and rescale them + traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale + + # used for input normalization + states = np.concatenate(states, axis=0) + self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 + + # normalize states + for traj in self.trajectories: + traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std + + # self.trajectories = {} + # exp_key = ['rewards', 'terminals', 'timeouts'] + # for k in dataset.keys(): + # logging.info(f'Load {k} data.') + # if k in exp_key: + # self.trajectories[k] = np.expand_dims(dataset[k][:], axis=1) + # else: + # self.trajectories[k] = dataset[k][:] + + # # used for input normalization + # states = np.concatenate(self.trajectories['observations'], axis=0) + # self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 + + # # normalize states + # self.trajectories['observations'] = (self.trajectories['observations'] - self.state_mean) / self.state_std + # self.trajectories['returns_to_go'] = discount_cumsum(self.trajectories['rewards'], 1.0) / rtg_scale + + # datalen = self.trajectories['rewards'].shape[0] + + # 
use_timeouts = False + # if 'timeouts' in dataset: + # use_timeouts = True + + # data_ = collections.defaultdict(list) + # episode_step = 0 + # trajectories_tmp = [] + # for i in range(datalen): + # done_bool = bool(self.trajectories['terminals'][i]) + # final_timestep = (episode_step == 1000-1) + # for k in ['observations', 'actions', 'returns_to_go']: + # data_[k].append(self.trajectories[k][i]) + # if done_bool or final_timestep: + # episode_step = 0 + # episode_data = {} + # for k in data_: + # episode_data[k] = np.array(data_[k]) + # trajectories_tmp.append(episode_data) + # data_ = collections.defaultdict(list) + # episode_step += 1 + # self.trajectories = trajectories_tmp + elif 'pkl' in dataset_path: + if 'dqn' in dataset_path: + # load dataset + with open(dataset_path, 'rb') as f: + self.trajectories = pickle.load(f) + + if isinstance(self.trajectories[0], list): + # for our collected dataset, e.g. cartpole/lunarlander case + trajectories_tmp = [] + + original_keys = ['obs', 'next_obs', 'action', 'reward'] + keys = ['observations', 'next_observations', 'actions', 'rewards'] + trajectories_tmp = [ + { + key: np.stack( + [ + self.trajectories[eps_index][transition_index][o_key] + for transition_index in range(len(self.trajectories[eps_index])) + ], + axis=0 + ) for key, o_key in zip(keys, original_keys) + } for eps_index in range(len(self.trajectories)) + ] + self.trajectories = trajectories_tmp + + states = [] + for traj in self.trajectories: + # traj_len = traj['observations'].shape[0] + states.append(traj['observations']) + # calculate returns to go and rescale them + traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale + + # used for input normalization + states = np.concatenate(states, axis=0) + self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 + + # normalize states + for traj in self.trajectories: + traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std + else: + # load dataset + with open(dataset_path, 'rb') as f: + self.trajectories = pickle.load(f) + + min_len = 10**6 + states = [] + for traj in self.trajectories: + traj_len = traj['observations'].shape[0] + min_len = min(min_len, traj_len) + states.append(traj['observations']) + # calculate returns to go and rescale them + traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale + + # used for input normalization + states = np.concatenate(states, axis=0) + self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 + + # normalize states + for traj in self.trajectories: + traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std + else: + # -- load data from memory (make more efficient) + obss = [] + actions = [] + returns = [0] + done_idxs = [] + stepwise_returns = [] + + transitions_per_buffer = np.zeros(50, dtype=int) + num_trajectories = 0 + while len(obss) < cfg.policy.num_steps: + buffer_num = np.random.choice(np.arange(50 - cfg.policy.num_buffers, 50), 1)[0] + i = transitions_per_buffer[buffer_num] + print('loading from buffer %d which has %d already loaded' % (buffer_num, i)) + frb = FixedReplayBuffer( + data_dir=cfg.policy.data_dir_prefix + '/1/replay_logs', + replay_suffix=buffer_num, + observation_shape=(84, 84), + stack_size=4, + update_horizon=1, + gamma=0.99, + observation_dtype=np.uint8, + batch_size=32, + replay_capacity=100000) + if frb._loaded_buffers: + done = False + curr_num_transitions = len(obss) + trajectories_to_load = 
cfg.policy.trajectories_per_buffer + while not done: + states, ac, ret, next_states, next_action, next_reward, terminal, indices = frb.sample_transition_batch(batch_size=1, indices=[i]) + states = states.transpose((0, 3, 1, 2))[0] # (1, 84, 84, 4) --> (4, 84, 84) + obss += [states] + actions += [ac[0]] + stepwise_returns += [ret[0]] + if terminal[0]: + done_idxs += [len(obss)] + returns += [0] + if trajectories_to_load == 0: + done = True + else: + trajectories_to_load -= 1 + returns[-1] += ret[0] + i += 1 + if i >= 100000: + obss = obss[:curr_num_transitions] + actions = actions[:curr_num_transitions] + stepwise_returns = stepwise_returns[:curr_num_transitions] + returns[-1] = 0 + i = transitions_per_buffer[buffer_num] + done = True + num_trajectories += (cfg.policy.trajectories_per_buffer - trajectories_to_load) + transitions_per_buffer[buffer_num] = i + print('this buffer has %d loaded transitions and there are now %d transitions total divided into %d trajectories' % (i, len(obss), num_trajectories)) + + actions = np.array(actions) + returns = np.array(returns) + stepwise_returns = np.array(stepwise_returns) + done_idxs = np.array(done_idxs) + + # -- create reward-to-go dataset + start_index = 0 + rtg = np.zeros_like(stepwise_returns) + for i in done_idxs: + i = int(i) + curr_traj_returns = stepwise_returns[start_index:i] + for j in range(i-1, start_index-1, -1): # start from i-1 + rtg_j = curr_traj_returns[j-start_index:i-start_index] + rtg[j] = sum(rtg_j) + start_index = i + print('max rtg is %d' % max(rtg)) + + # -- create timestep dataset + start_index = 0 + timesteps = np.zeros(len(actions)+1, dtype=int) + for i in done_idxs: + i = int(i) + timesteps[start_index:i+1] = np.arange(i+1 - start_index) + start_index = i+1 + print('max timestep is %d' % max(timesteps)) + + self.obss = obss + self.actions = actions + self.done_idxs = done_idxs + self.rtgs = rtg + self.timesteps = timesteps + # return obss, actions, returns, done_idxs, rtg, timesteps + + def get_max_timestep(self) -> int: + return max(self.timesteps) def get_state_stats(self) -> Tuple[np.ndarray, np.ndarray]: return deepcopy(self.state_mean), deepcopy(self.state_std) @@ -378,56 +588,135 @@ def get_d4rl_dataset_stats(self, env_d4rl_name: str) -> Dict[str, list]: return self.D4RL_DATASET_STATS[env_d4rl_name] def __len__(self) -> int: - return len(self.trajectories) + if self.env_type != 'atari': + return len(self.trajectories) + else: + return len(self.obss) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - traj = self.trajectories[idx] - traj_len = traj['observations'].shape[0] - - if traj_len >= self.context_len: - # sample random index to slice trajectory - si = np.random.randint(0, traj_len - self.context_len) - - states = torch.from_numpy(traj['observations'][si:si + self.context_len]) - actions = torch.from_numpy(traj['actions'][si:si + self.context_len]) - returns_to_go = torch.from_numpy(traj['returns_to_go'][si:si + self.context_len]) - timesteps = torch.arange(start=si, end=si + self.context_len, step=1) + if self.env_type != 'atari': + traj = self.trajectories[idx] + traj_len = traj['observations'].shape[0] - # all ones since no padding + if traj_len > self.context_len: + # sample random index to slice trajectory + si = np.random.randint(0, traj_len - self.context_len) + + states = torch.from_numpy(traj['observations'][si:si + self.context_len]) + actions = torch.from_numpy(traj['actions'][si:si + self.context_len]) + returns_to_go = 
torch.from_numpy(traj['returns_to_go'][si:si + self.context_len]) + timesteps = torch.arange(start=si, end=si + self.context_len, step=1) + + # all ones since no padding + traj_mask = torch.ones(self.context_len, dtype=torch.long) + + else: + padding_len = self.context_len - traj_len + + # padding with zeros + states = torch.from_numpy(traj['observations']) + states = torch.cat( + [states, torch.zeros(([padding_len] + list(states.shape[1:])), dtype=states.dtype)], dim=0 + ) + + actions = torch.from_numpy(traj['actions']) + actions = torch.cat( + [actions, torch.zeros(([padding_len] + list(actions.shape[1:])), dtype=actions.dtype)], dim=0 + ) + + returns_to_go = torch.from_numpy(traj['returns_to_go']) + returns_to_go = torch.cat( + [ + returns_to_go, + torch.zeros(([padding_len] + list(returns_to_go.shape[1:])), dtype=returns_to_go.dtype) + ], + dim=0 + ) + + timesteps = torch.arange(start=0, end=self.context_len, step=1) + + traj_mask = torch.cat( + [torch.ones(traj_len, dtype=torch.long), + torch.zeros(padding_len, dtype=torch.long)], dim=0 + ) + return timesteps, states, actions, returns_to_go, traj_mask + else: + block_size = self.context_len + done_idx = idx + block_size + for i in self.done_idxs: + if i > idx: # first done_idx greater than idx + done_idx = min(int(i), done_idx) + break + idx = done_idx - block_size + states = torch.tensor(np.array(self.obss[idx:done_idx]), dtype=torch.float32).reshape(block_size, -1) # (block_size, 4*84*84) + states = states / 255. + actions = torch.tensor(self.actions[idx:done_idx], dtype=torch.long).unsqueeze(1) # (block_size, 1) + rtgs = torch.tensor(self.rtgs[idx:done_idx], dtype=torch.float32).unsqueeze(1) + timesteps = torch.tensor(self.timesteps[idx:idx+1], dtype=torch.int64).unsqueeze(1) traj_mask = torch.ones(self.context_len, dtype=torch.long) - + + return timesteps, states, actions, rtgs, traj_mask + + +class FixedReplayBuffer(object): + """Object composed of a list of OutofGraphReplayBuffers.""" + + def __init__(self, data_dir, replay_suffix, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg + """Initialize the FixedReplayBuffer class. + Args: + data_dir: str, log Directory from which to load the replay buffer. + replay_suffix: int, If not None, then only load the replay buffer + corresponding to the specific suffix in data directory. + *args: Arbitrary extra arguments. + **kwargs: Arbitrary keyword arguments. 
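A side note on the sampling logic above: when a trajectory is shorter than context_len, the item is zero-padded on the right and traj_mask records which positions are real, so _forward_learn can drop the padded targets before computing the loss. A standalone sketch of that padding rule (pad_to_context_len is a hypothetical helper with toy shapes, not part of this patch):

import torch

def pad_to_context_len(states: torch.Tensor, context_len: int):
    # states: (traj_len, state_dim) -> (context_len, state_dim) plus a 0/1 mask over real steps
    traj_len = states.shape[0]
    padding_len = context_len - traj_len
    if padding_len <= 0:
        return states[:context_len], torch.ones(context_len, dtype=torch.long)
    pad = torch.zeros((padding_len, ) + tuple(states.shape[1:]), dtype=states.dtype)
    mask = torch.cat([torch.ones(traj_len, dtype=torch.long), torch.zeros(padding_len, dtype=torch.long)])
    return torch.cat([states, pad], dim=0), mask

states, mask = pad_to_context_len(torch.randn(3, 8), context_len=5)
print(states.shape, mask.tolist())   # torch.Size([5, 8]) [1, 1, 1, 0, 0]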
+ """ + self._args = args + self._kwargs = kwargs + self._data_dir = data_dir + self._loaded_buffers = False + self.add_count = np.array(0) + self._replay_suffix = replay_suffix + if not self._loaded_buffers: + if replay_suffix is not None: + assert replay_suffix >= 0, 'Please pass a non-negative replay suffix' + self.load_single_buffer(replay_suffix) else: - padding_len = self.context_len - traj_len - - # padding with zeros - states = torch.from_numpy(traj['observations']) - states = torch.cat( - [states, torch.zeros(([padding_len] + list(states.shape[1:])), dtype=states.dtype)], dim=0 - ) - - actions = torch.from_numpy(traj['actions']) - actions = torch.cat( - [actions, torch.zeros(([padding_len] + list(actions.shape[1:])), dtype=actions.dtype)], dim=0 - ) - - returns_to_go = torch.from_numpy(traj['returns_to_go']) - returns_to_go = torch.cat( - [ - returns_to_go, - torch.zeros(([padding_len] + list(returns_to_go.shape[1:])), dtype=returns_to_go.dtype) - ], - dim=0 - ) - - timesteps = torch.arange(start=0, end=self.context_len, step=1) - - traj_mask = torch.cat( - [torch.ones(traj_len, dtype=torch.long), - torch.zeros(padding_len, dtype=torch.long)], dim=0 - ) - - return timesteps, states, actions, returns_to_go, traj_mask - + pass + # self._load_replay_buffers(num_buffers=50) + + def load_single_buffer(self, suffix): + """Load a single replay buffer.""" + replay_buffer = self._load_buffer(suffix) + if replay_buffer is not None: + self._replay_buffers = [replay_buffer] + self.add_count = replay_buffer.add_count + self._num_replay_buffers = 1 + self._loaded_buffers = True + + def _load_buffer(self, suffix): + """Loads a OutOfGraphReplayBuffer replay buffer.""" + try: + from dopamine.replay_memory import circular_replay_buffer + STORE_FILENAME_PREFIX = circular_replay_buffer.STORE_FILENAME_PREFIX + # pytype: disable=attribute-error + replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer( + *self._args, **self._kwargs) + replay_buffer.load(self._data_dir, suffix) + print('Loaded replay buffer ckpt {} from {}'.format( + suffix, self._data_dir)) + # pytype: enable=attribute-error + return replay_buffer + # except tf.errors.NotFoundError: + except: + raise('can not load') + + def get_transition_elements(self): + return self._replay_buffers[0].get_transition_elements() + + def sample_transition_batch(self, batch_size=None, indices=None): + buffer_index = np.random.randint(self._num_replay_buffers) + return self._replay_buffers[buffer_index].sample_transition_batch( + batch_size=batch_size, indices=indices) class PCDataset(Dataset): diff --git a/dizoo/atari/config/pong_dt_config.py b/dizoo/atari/config/pong_dt_config.py new file mode 100644 index 0000000000..2bf3de48c9 --- /dev/null +++ b/dizoo/atari/config/pong_dt_config.py @@ -0,0 +1,103 @@ +from easydict import EasyDict +from copy import deepcopy + +hopper_dt_config = dict( + exp_name='dt_log/atari/Pong/Pong_dt_seed0', + env=dict( + env_id='PongNoFrameskip-v4', + norm_obs=dict(use_norm=False, ), + norm_reward=dict(use_norm=False, ), + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=20, + frame_stack=4, + is_train=False, + ), + policy=dict( + num_buffers=50, + num_steps=500000, + data_dir_prefix='/mnt/nfs/luyd/d4rl_atari/Pong/', + trajectories_per_buffer=10, + env_type='atari', + stop_value=20, + state_mean=None, + state_std=None, + evaluator_env_num=8, + cuda=True, + env_name='PongNoFrameskip-v4', + dataset_name='Pong', + rtg_target=20, # max target return to go + rtg_scale=10, + 
max_eval_ep_len=10000, # max lenght of one episode + num_eval_ep=10, # num of evaluation episode + wt_decay=1e-4, + # warmup_steps=100000, + warmup_steps=10000, + num_updates_per_iter=100, + context_len=30, + n_blocks=6, + embed_dim=128, + n_heads=8, + dropout_p=0.1, + model=dict( + state_dim=(4, 84, 84), + act_dim=6, + n_blocks=3, + h_dim=128, + n_embd=128, + context_len=30, + n_heads=8, + n_layer=6, + drop_p=0.1, + embd_pdrop=0.1, + resid_pdrop = 0.1, + attn_pdrop = 0.1, + continuous=False, + ), + discount_factor=0.999, + nstep=3, + learn=dict( + batch_size=128, + learning_rate=6e-4, + target_update_freq=100, + kappa=1.0, + min_q_weight=4.0, + ), + collect=dict( + data_type='d4rl_trajectory', + # data_path='/mnt/nfs/luyd/hopper_medium.hdf5', + data_path='/mnt/nfs/luyd/d4rl_atari/Pong', + unroll_len=1, + ), + eval=dict(evaluator=dict(evalu_freq=100, ), ), + other=dict( + eps=dict( + type='exp', + start=0.95, + end=0.1, + decay=10000, + ), + replay_buffer=dict(replay_buffer_size=1000, ), + ), + ), +) + +hopper_dt_config = EasyDict(hopper_dt_config) +main_config = hopper_dt_config +hopper_dt_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='dt'), +) +hopper_dt_create_config = EasyDict(hopper_dt_create_config) +create_config = hopper_dt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_dt + config = deepcopy([main_config, create_config]) + serial_pipeline_dt(config, seed=0, max_train_iter=1000) diff --git a/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py b/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py index d225b299ec..1847f7b7d4 100644 --- a/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py +++ b/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py @@ -6,16 +6,18 @@ exp_name='data_dt/lunarlander_dt_1000eps_rtgt300_meel1000_seed0_debug', env=dict( env_id='LunarLander-v2', - collector_env_num=8, evaluator_env_num=8, n_evaluator_episode=8, stop_value=200, ), policy=dict( stop_value=200, + state_mean=None, + state_std=None, device='cuda', env_name='LunarLander-v2', rtg_target=300, # max target reward_to_go + rtg_scale=150, max_eval_ep_len=1000, # max len of one episode # TODO num_eval_ep=10, # num of evaluation episodes wt_decay=1e-4, @@ -27,6 +29,7 @@ embed_dim=128, n_heads=1, dropout_p=0.1, + log_dir='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps', model=dict( state_dim=8, act_dim=4, @@ -40,7 +43,8 @@ discount_factor=0.999, nstep=3, learn=dict( - learning_rate=1e-4, + dataset_path='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl', # TODO + learning_rate=3e-4, batch_size=64, # training batch size target_update_freq=100, kappa=1.0, diff --git a/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-08-50-45.csv b/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-08-50-45.csv deleted file mode 100644 index a03d532276..0000000000 --- a/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-08-50-45.csv +++ /dev/null @@ -1 +0,0 @@ -duration,num_updates,eval_avg_reward,eval_avg_ep_len,eval_d4rl_score diff --git a/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-08-57-28.csv b/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-08-57-28.csv deleted file mode 100644 index 291f905b5d..0000000000 --- a/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-08-57-28.csv +++ 
/dev/null @@ -1,2 +0,0 @@ -duration,num_updates,eval_avg_reward,eval_avg_ep_len,eval_d4rl_score -0:02:06,1000,-210.51245402009764,213.5,0 diff --git a/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-09-09-38.csv b/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-09-09-38.csv deleted file mode 100644 index 2ebd2b172d..0000000000 --- a/dizoo/box2d/lunarlander/dt_log_1000eps/dt_LunarLander-v2_log_23-07-13-09-09-38.csv +++ /dev/null @@ -1,16 +0,0 @@ -duration,num_updates,eval_avg_reward,eval_avg_ep_len,eval_d4rl_score -0:02:01,1000,-318.0265498211703,196.6,0 -0:04:20,2000,41.14780414978246,718.4,0 -0:06:42,3000,142.82828392440672,762.4,0 -0:09:02,4000,138.89283069755604,724.3,0 -0:11:23,5000,107.84010045187748,767.2,0 -0:13:44,6000,160.2343642862316,701.2,0 -0:16:00,7000,121.66243934822947,654.6,0 -0:18:22,8000,77.69263487376318,720.2,0 -0:20:37,9000,198.9464703222405,633.8,0 -0:22:59,10000,81.862677775086,728.9,0 -0:25:17,11000,167.60164353189074,671.1,0 -0:27:33,12000,180.91905256798407,634.4,0 -0:29:47,13000,185.41585978196196,563.7,0 -0:32:06,14000,190.83281151114906,653.6,0 -0:34:20,15000,205.16962772579026,561.3,0 diff --git a/dizoo/d4rl/config/__init__.py b/dizoo/d4rl/config/__init__.py index 450f069392..92bca79cc2 100644 --- a/dizoo/d4rl/config/__init__.py +++ b/dizoo/d4rl/config/__init__.py @@ -1,3 +1,3 @@ -from .hopper_cql_config import hopper_cql_config -from .hopper_expert_cql_config import hopper_expert_cql_config -from .hopper_medium_cql_config import hopper_medium_cql_config +# from .hopper_cql_config import hopper_cql_config +# from .hopper_expert_cql_config import hopper_expert_cql_config +# from .hopper_medium_cql_config import hopper_medium_cql_config diff --git a/dizoo/d4rl/config/hopper_expert_dt_config.py b/dizoo/d4rl/config/hopper_expert_dt_config.py index 592d9411df..7180ddc717 100644 --- a/dizoo/d4rl/config/hopper_expert_dt_config.py +++ b/dizoo/d4rl/config/hopper_expert_dt_config.py @@ -2,7 +2,7 @@ from copy import deepcopy hopper_dt_config = dict( - exp_name='hopper_expert_dt_seed0', + exp_name='dt_log/d4rl/hopper/hopper_expert_dt_seed0', env=dict( env_id='Hopper-v3', norm_obs=dict(use_norm=False, ), @@ -15,12 +15,14 @@ ), policy=dict( stop_value=6000, + state_mean=None, + state_std=None, + evaluator_env_num=8, cuda=True, env_name='Hopper-v3', rtg_target=6000, # max target return to go max_eval_ep_len=1000, # max lenght of one episode num_eval_ep=10, # num of evaluation episode - batch_size=64, wt_decay=1e-4, warmup_steps=10000, num_updates_per_iter=100, @@ -29,7 +31,6 @@ embed_dim=128, n_heads=1, dropout_p=0.1, - log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/d4rl/dt_data/hopper_expert_dt_log', model=dict( state_dim=11, act_dim=3, @@ -43,13 +44,17 @@ discount_factor=0.999, nstep=3, learn=dict( - dataset_path='/mnt/lustre/wangzilin/d4rl_data/hopper-expert-v2.pkl', + batch_size=64, learning_rate=0.0001, target_update_freq=100, kappa=1.0, min_q_weight=4.0, ), - collect=dict(unroll_len=1, ), + collect=dict( + data_type='d4rl_trajectory', + data_path='/mnt/nfs/luyd/hopper_expert.hdf5', + unroll_len=1, + ), eval=dict(evaluator=dict(evalu_freq=100, ), ), other=dict( eps=dict( diff --git a/dizoo/d4rl/config/hopper_medium_dt_config.py b/dizoo/d4rl/config/hopper_medium_dt_config.py index ae3778a1d8..37e804da04 100644 --- a/dizoo/d4rl/config/hopper_medium_dt_config.py +++ b/dizoo/d4rl/config/hopper_medium_dt_config.py @@ -2,7 +2,7 @@ from copy import deepcopy hopper_dt_config = dict( - exp_name='hopper_medium_dt_seed0', + 
exp_name='dt_log/d4rl/hopper/hopper_medium_dt_seed0', env=dict( env_id='Hopper-v3', norm_obs=dict(use_norm=False, ), @@ -11,17 +11,22 @@ evaluator_env_num=8, use_act_scale=True, n_evaluator_episode=8, - stop_value=6000, + stop_value=3600, ), policy=dict( - stop_value=6000, + stop_value=3600, + state_mean=None, + state_std=None, + evaluator_env_num=8, cuda=True, env_name='Hopper-v3', - rtg_target=6000, # max target return to go + dataset_name='hopper-medium-v2', + rtg_target=3600, # max target return to go + rtg_scale=1000, max_eval_ep_len=1000, # max lenght of one episode num_eval_ep=10, # num of evaluation episode - batch_size=64, wt_decay=1e-4, + # warmup_steps=100000, warmup_steps=10000, num_updates_per_iter=100, context_len=20, @@ -29,7 +34,6 @@ embed_dim=128, n_heads=1, dropout_p=0.1, - log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/d4rl/dt_data/hopper_medium_dt_log', model=dict( state_dim=11, act_dim=3, @@ -38,18 +42,24 @@ context_len=20, n_heads=1, drop_p=0.1, + max_timestep=0, continuous=True, ), discount_factor=0.999, nstep=3, learn=dict( - dataset_path='/mnt/lustre/wangzilin/d4rl_data/hopper-medium-v2.pkl', - learning_rate=0.0001, + batch_size=64, + learning_rate=1e-4, target_update_freq=100, kappa=1.0, min_q_weight=4.0, ), - collect=dict(unroll_len=1, ), + collect=dict( + data_type='d4rl_trajectory', + # data_path='/mnt/nfs/luyd/hopper_medium.hdf5', + data_path='/mnt/nfs/luyd/d4rl/hopper_medium-v2.pkl', + unroll_len=1, + ), eval=dict(evaluator=dict(evalu_freq=100, ), ), other=dict( eps=dict( diff --git a/dizoo/d4rl/config/hopper_medium_expert_dt_config.py b/dizoo/d4rl/config/hopper_medium_expert_dt_config.py index cab3b10353..869ea10d84 100644 --- a/dizoo/d4rl/config/hopper_medium_expert_dt_config.py +++ b/dizoo/d4rl/config/hopper_medium_expert_dt_config.py @@ -2,7 +2,7 @@ from copy import deepcopy hopper_dt_config = dict( - exp_name='hopper_medium_expert_dt_seed0', + exp_name='dt_log/d4rl/hopper/hopper_medium_expert_dt_seed0', env=dict( env_id='Hopper-v3', norm_obs=dict(use_norm=False, ), @@ -15,12 +15,14 @@ ), policy=dict( stop_value=6000, + state_mean=None, + state_std=None, + evaluator_env_num=8, cuda=True, env_name='Hopper-v3', rtg_target=6000, # max target return to go max_eval_ep_len=1000, # max lenght of one episode num_eval_ep=10, # num of evaluation episode - batch_size=64, wt_decay=1e-4, warmup_steps=10000, num_updates_per_iter=100, @@ -29,7 +31,6 @@ embed_dim=128, n_heads=1, dropout_p=0.1, - log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/d4rl/dt_data/hopper_medium_expert_dt_log', model=dict( state_dim=11, act_dim=3, @@ -43,13 +44,17 @@ discount_factor=0.999, nstep=3, learn=dict( - dataset_path='/mnt/lustre/wangzilin/d4rl_data/hopper-medium-expert-v2.pkl', + batch_size=64, learning_rate=0.0001, target_update_freq=100, kappa=1.0, min_q_weight=4.0, ), - collect=dict(unroll_len=1, ), + collect=dict( + data_type='d4rl_trajectory', + data_path='/mnt/nfs/luyd/d4rl/hopper_medium_expert.hdf5', + unroll_len=1, + ), eval=dict(evaluator=dict(evalu_freq=100, ), ), other=dict( eps=dict( From 8b330a6078d2d044ad6f74b2390535d0ed749cdc Mon Sep 17 00:00:00 2001 From: luyudong Date: Tue, 25 Jul 2023 17:23:21 +0800 Subject: [PATCH 05/25] Fix according to comment --- ding/envs/env_wrappers/env_wrappers.py | 11 +++++++++++ ding/policy/command_mode_policy_instance.py | 1 - ding/policy/dt.py | 13 +++++++------ .../config/lunarlander_decision_transformer.py | 4 ++-- 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/ding/envs/env_wrappers/env_wrappers.py 
b/ding/envs/env_wrappers/env_wrappers.py index 9a60cb95d5..3b715a0689 100644 --- a/ding/envs/env_wrappers/env_wrappers.py +++ b/ding/envs/env_wrappers/env_wrappers.py @@ -1175,6 +1175,17 @@ def reset(self): class AllinObsWrapper(gym.Wrapper): + """ + Overview: + This wrapper is used in policy DT. + Set a dict {'obs': obs, 'reward': reward} + as the new wrapped observation, + which including the current obs, previous reward. + Interface: + ``__init__``, ``reset``, ``step``, ``seed`` + Properties: + - env (:obj:`gym.Env`): the environment to wrap. + """ def __init__(self, env): super().__init__(env) diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py index 05b2423082..8abc905324 100755 --- a/ding/policy/command_mode_policy_instance.py +++ b/ding/policy/command_mode_policy_instance.py @@ -42,7 +42,6 @@ from .d4pg import D4PGPolicy from .cql import CQLPolicy, CQLDiscretePolicy -# from .decision_transformer import DTPolicy from .dt import DTPolicy from .pdqn import PDQNPolicy from .sac import SQILSACPolicy diff --git a/ding/policy/dt.py b/ding/policy/dt.py index 2944e9b93f..f5142053ca 100644 --- a/ding/policy/dt.py +++ b/ding/policy/dt.py @@ -28,7 +28,8 @@ class DTPolicy(Policy): r""" Overview: - Policy class of DT algorithm in discrete environments. + Policy class of Decision Transformer algorithm in discrete environments. + Paper link: https://arxiv.org/abs/2106.01345 """ config = dict( # (str) RL policy register name (refer to function "POLICY_REGISTRY"). @@ -80,10 +81,10 @@ def default_model(self) -> Tuple[str, List[str]]: def _init_learn(self) -> None: r""" - Overview: - Learn mode init method. Called by ``self.__init__``. - Init the optimizer, algorithm config, main and target models. - """ + Overview: + Learn mode init method. Called by ``self.__init__``. + Init the optimizer, algorithm config, main and target models. 
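To make the AllinObsWrapper contract documented above concrete: the evaluator reads data[i]['reward'] each step to decrement its running return-to-go, so the wrapper only needs to return the previous reward alongside the observation. A minimal sketch of that idea (not the DI-engine implementation; it assumes the classic gym API where reset returns obs and step returns a 4-tuple):

import gym

class RewardInObsSketch(gym.Wrapper):
    # wrap observations as {'obs': obs, 'reward': previous_reward}
    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        return {'obs': obs, 'reward': 0.0}   # no reward received yet

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return {'obs': obs, 'reward': reward}, reward, done, info

env = RewardInObsSketch(gym.make('CartPole-v1'))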
+ """ self.env_name = self._cfg.env_name # rtg_scale: scale of `return to go` # rtg_target: max target of `return to go` @@ -151,7 +152,7 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: state_preds, action_preds, return_preds = self._learn_model.forward( timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go ) - + if self.cfg.env_type == 'atari': action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1)) else: diff --git a/dizoo/box2d/lunarlander/config/lunarlander_decision_transformer.py b/dizoo/box2d/lunarlander/config/lunarlander_decision_transformer.py index 3d9a4c5bf5..1cd9ed2018 100644 --- a/dizoo/box2d/lunarlander/config/lunarlander_decision_transformer.py +++ b/dizoo/box2d/lunarlander/config/lunarlander_decision_transformer.py @@ -27,7 +27,7 @@ embed_dim=128, n_heads=1, dropout_p=0.1, - log_dir='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps', + log_dir='DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps', model=dict( state_dim=8, act_dim=4, @@ -41,7 +41,7 @@ discount_factor=0.999, nstep=3, learn=dict( - dataset_path='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl', # TODO + dataset_path='DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl', # TODO learning_rate=1e-4, target_update_freq=100, kappa=1.0, From 1ccb2ece7ee36a4242ea8e4254ff2cef86b58b50 Mon Sep 17 00:00:00 2001 From: luyudong Date: Mon, 31 Jul 2023 13:27:02 +0800 Subject: [PATCH 06/25] Fix dt config files --- ding/envs/env/tests/test_ding_env_wrapper.py | 19 + ding/envs/env_wrappers/env_wrappers.py | 1 + .../middleware/functional/data_processor.py | 17 +- ding/model/template/__init__.py | 2 +- ding/model/template/dt.py | 130 +++++- ding/policy/decision_transformer.py | 428 ------------------ ding/policy/dt.py | 99 ++-- ding/utils/data/dataset.py | 15 +- dizoo/atari/config/pong_dt_config.py | 36 +- .../atari/entry/atari_dt_main.py | 0 .../config/lunarlander_dt_config.py | 16 +- dizoo/d4rl/config/hopper_medium_dt_config.py | 5 +- .../d4rl/entry/d4rl_dt_mujoco.py | 0 13 files changed, 239 insertions(+), 529 deletions(-) delete mode 100644 ding/policy/decision_transformer.py rename ding/example/dt_atari.py => dizoo/atari/entry/atari_dt_main.py (100%) rename ding/example/dt_mujoco.py => dizoo/d4rl/entry/d4rl_dt_mujoco.py (100%) diff --git a/ding/envs/env/tests/test_ding_env_wrapper.py b/ding/envs/env/tests/test_ding_env_wrapper.py index 03c97e437b..5b98a9403c 100644 --- a/ding/envs/env/tests/test_ding_env_wrapper.py +++ b/ding/envs/env/tests/test_ding_env_wrapper.py @@ -180,3 +180,22 @@ def test_hybrid(self): action = ding_env_hybrid.random_action() print('random_action', action) assert isinstance(action, dict) + + @pytest.mark.unittest + def test_AllinObsWrapper(self): + env_cfg = EasyDict(env_id='PongNoFrameskip-v4', env_wrapper='reward_in_obs') + ding_env_aio = DingEnvWrapper(cfg=env_cfg) + + data = ding_env_aio.reset() + assert isinstance(data, dict) + assert 'obs' in data.keys() and 'reward' in data.keys() + assert data['obs'].shape == ding_env_aio.observation_space + while True: + action = ding_env_aio.random_action() + timestep = ding_env_aio.step(action) + # print(timestep.reward) + assert isinstance(timestep.obs,dict) + if timestep.done: + assert 'eval_episode_return' in timestep.info, timestep.info + break + print(ding_env_aio.observation_space, ding_env_aio.action_space, ding_env_aio.reward_space) diff --git a/ding/envs/env_wrappers/env_wrappers.py 
b/ding/envs/env_wrappers/env_wrappers.py index 3b715a0689..1d5752dd95 100644 --- a/ding/envs/env_wrappers/env_wrappers.py +++ b/ding/envs/env_wrappers/env_wrappers.py @@ -1174,6 +1174,7 @@ def reset(self): return self.env.reset() +@ENV_WRAPPER_REGISTRY.register('reward_in_obs') class AllinObsWrapper(gym.Wrapper): """ Overview: diff --git a/ding/framework/middleware/functional/data_processor.py b/ding/framework/middleware/functional/data_processor.py index 1dc9429458..f551cba7ad 100644 --- a/ding/framework/middleware/functional/data_processor.py +++ b/ding/framework/middleware/functional/data_processor.py @@ -192,7 +192,7 @@ def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable: - dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data. """ # collate_fn is executed in policy now - dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x) + dataloader = iter(DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x)) def _fetch(ctx: "OfflineRLContext"): """ @@ -204,11 +204,18 @@ def _fetch(ctx: "OfflineRLContext"): Output of ctx: - train_data (:obj:`List[Tensor]`): The fetched data batch. """ - while True: - for i, data in enumerate(dataloader): - ctx.train_data = data - yield + nonlocal dataloader + try: + ctx.train_data = next(dataloader) + except StopIteration: ctx.train_epoch += 1 + del dataloader + dataloader = iter(DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x)) + ctx.train_data = next(dataloader) + # for i, data in enumerate(dataloader): + # ctx.train_data = data + # yield + # ctx.train_epoch += 1 # TODO apply data update (e.g. priority) in offline setting when necessary return _fetch diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py index dc3ff9d5b4..411be9673a 100755 --- a/ding/model/template/__init__.py +++ b/ding/model/template/__init__.py @@ -21,7 +21,7 @@ from .maqac import MAQAC, ContinuousMAQAC from .madqn import MADQN from .vae import VanillaVAE -from .decision_transformer import DecisionTransformer +from .dt import DecisionTransformer, DecisionTransformerA from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS from .bcq import BCQ from .edac import QACEnsemble diff --git a/ding/model/template/dt.py b/ding/model/template/dt.py index d6d39e74fd..3b57d3e461 100644 --- a/ding/model/template/dt.py +++ b/ding/model/template/dt.py @@ -165,13 +165,83 @@ def forward(self, timesteps, states, actions, returns_to_go): return state_preds, action_preds, return_preds +class GELU(nn.Module): + def forward(self, input): + return F.gelu(input) + + +class CausalSelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. + It is possible to use torch.nn.MultiheadAttention here but I am including an + explicit implementation here to show that there is nothing too scary here. 
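The only subtle piece of these explicit attention modules is the lower-triangular mask; everything else is standard multi-head attention. A tiny self-contained check of what masked_fill plus softmax do to the attention weights (illustrative shapes only):

import torch
import torch.nn.functional as F

T = 4
scores = torch.randn(1, 1, T, T)                        # (B, n_heads, T, T) attention logits
mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)    # ones on and below the diagonal
scores = scores.masked_fill(mask == 0, float('-inf'))   # hide future positions
weights = F.softmax(scores, dim=-1)
print(weights[0, 0])   # row t attends only to positions <= t, and each row sums to 1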
+ """ + + def __init__(self, n_head, block_size, n_embd, attn_pdrop, resid_pdrop): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(n_embd, n_embd) + self.query = nn.Linear(n_embd, n_embd) + self.value = nn.Linear(n_embd, n_embd) + # regularization + self.attn_drop = nn.Dropout(attn_pdrop) + self.resid_drop = nn.Dropout(resid_pdrop) + # output projection + self.proj = nn.Linear(n_embd, n_embd) + # causal mask to ensure that attention is only applied to the left in the input sequence + # self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)) + self.register_buffer("mask", torch.tril(torch.ones(block_size + 1, block_size + 1)).view(1, 1, block_size + 1, block_size + 1)) + self.n_head = n_head + + def forward(self, x, layer_past=None): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y + + +class BlockA(nn.Module): + """ an unassuming Transformer block """ + + def __init__(self, n_head, block_size, n_embd, attn_pdrop, resid_pdrop): + super().__init__() + self.ln1 = nn.LayerNorm(n_embd) + self.ln2 = nn.LayerNorm(n_embd) + self.attn = CausalSelfAttention(n_head, block_size, n_embd, attn_pdrop, resid_pdrop) + self.mlp = nn.Sequential( + nn.Linear(n_embd, 4 * n_embd), + GELU(), + nn.Linear(4 * n_embd, n_embd), + nn.Dropout(resid_pdrop), + ) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + + class DecisionTransformerA(nn.Module): def __init__(self, config): super().__init__() self.config = config self.n_embd = config.n_embd self.block_size = config.context_len * 3 - h_dim = config.h_dim # input embedding stem self.tok_emb = nn.Embedding(config.act_dim, config.n_embd) @@ -181,7 +251,7 @@ def __init__(self, config): self.drop = nn.Dropout(config.embd_pdrop) # transformer - self.blocks = nn.Sequential(*[Block(h_dim, self.block_size, config.n_heads, config.drop_p) for _ in range(config.n_layer)]) + self.blocks = nn.Sequential(*[BlockA(config.n_heads, self.block_size, config.n_embd, config.attn_pdrop, config.resid_pdrop) for _ in range(config.n_layer)]) # decoder head self.ln_f = nn.LayerNorm(config.n_embd) self.head = nn.Linear(config.n_embd, config.act_dim, bias=False) @@ -197,7 +267,7 @@ def __init__(self, config): # state, action, and return - def forward(self, timesteps, states, actions, returns_to_go): + def forward(self, timesteps, states, actions, returns_to_go, tar=None): # states: (batch, block_size, 4*84*84) # actions: (batch, block_size, 1) # rtgs: (batch, block_size, 1) @@ -209,10 +279,10 @@ def forward(self, timesteps, states, actions, returns_to_go): 
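# Editor's sketch (illustration only, not part of the patch): the forward pass below interleaves
# tokens as (R_1, s_1, a_1, R_2, s_2, a_2, ...). When the last action is still unknown at
# inference time (``tar is None``), the trailing action slot is dropped, giving 3*T - 1 tokens.
# A self-contained toy version of that layout, with random tensors standing in for the real
# return / state / action embeddings:
import torch

B, T, n_embd = 2, 4, 8                          # toy batch size, context length, embedding width
rtg_emb = torch.randn(B, T, n_embd)             # stands in for self.ret_emb(rtgs)
state_emb = torch.randn(B, T, n_embd)           # stands in for the model's state embeddings
action_emb = torch.randn(B, T, n_embd)          # stands in for self.action_embeddings(actions)

tar = None                                      # None -> drop the final action slot
tokens = torch.zeros(B, 3 * T - int(tar is None), n_embd)
tokens[:, ::3, :] = rtg_emb                     # positions 0, 3, 6, ... hold returns-to-go
tokens[:, 1::3, :] = state_emb                  # positions 1, 4, 7, ... hold states
tokens[:, 2::3, :] = action_emb[:, -T + int(tar is None):, :]   # remaining slots hold actions
assert tokens.shape[1] == 3 * T - 1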
rtg_embeddings = self.ret_emb(rtgs.type(torch.float32)) action_embeddings = self.action_embeddings(actions.type(torch.long).squeeze(-1)) # (batch, block_size, n_embd) - token_embeddings = torch.zeros((states.shape[0], states.shape[1]*3 - 1, self.config.n_embd), dtype=torch.float32, device=state_embeddings.device) + token_embeddings = torch.zeros((states.shape[0], states.shape[1]*3 - int(tar is None), self.config.n_embd), dtype=torch.float32, device=state_embeddings.device) token_embeddings[:,::3,:] = rtg_embeddings token_embeddings[:,1::3,:] = state_embeddings - token_embeddings[:,2::3,:] = action_embeddings[:,-states.shape[1] + 1:,:] + token_embeddings[:,2::3,:] = action_embeddings[:,-states.shape[1] + int(tar is None):,:] batch_size = states.shape[0] all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, batch_size, dim=0) # batch_size, traj_length, n_embd @@ -226,4 +296,52 @@ def forward(self, timesteps, states, actions, returns_to_go): logits = logits[:, 1::3, :] # only keep predictions from state_embeddings - return None, logits, None \ No newline at end of file + return None, logits, None + + def configure_optimizers(self, weight_decay, learning_rate, betas): + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + # whitelist_weight_modules = (torch.nn.Linear, ) + whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + + if pn.endswith('bias'): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # special case the position embedding parameter in the root GPT module as not decayed + no_decay.add('pos_emb') + no_decay.add('global_pos_emb') + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) + assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" 
\ + % (str(param_dict.keys() - union_params), ) + + # create the pytorch optimizer object + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) + return optimizer \ No newline at end of file diff --git a/ding/policy/decision_transformer.py b/ding/policy/decision_transformer.py deleted file mode 100644 index aeb08e1dbf..0000000000 --- a/ding/policy/decision_transformer.py +++ /dev/null @@ -1,428 +0,0 @@ -"""The code is adapted from https://github.com/nikhilbarhate99/min-decision-transformer -""" - -from typing import List, Dict, Any, Tuple, Union -from collections import namedtuple -from torch.distributions import Normal, Independent -from ding.torch_utils import Adam, to_device -from ditk import logging -from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \ - qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data -from ding.model import model_wrap -from ding.utils.data.dataset import D4RLTrajectoryDataset -from ding.utils import POLICY_REGISTRY -from ding.utils.data import default_collate, default_decollate -from datetime import datetime -from ding.torch_utils import one_hot -import numpy as np -import torch.nn.functional as F -import torch -import gym -import copy -import os -import csv -from .base_policy import Policy - - -@POLICY_REGISTRY.register('dt') -class DTPolicy(Policy): - r""" - Overview: - Policy class of DT algorithm in discrete environments. - """ - config = dict( - # (str) RL policy register name (refer to function "POLICY_REGISTRY"). - type='dt', - # (bool) Whether to use cuda for network. - cuda=False, - # (bool) Whether the RL algorithm is on-policy or off-policy. - on_policy=False, - # (bool) Whether use priority(priority sample, IS weight, update priority) - priority=False, - # (float) Reward's future discount factor, aka. gamma. - discount_factor=0.97, - # (int) N-step reward for target q_value estimation - nstep=1, - obs_shape=4, - action_shape=2, - # encoder_hidden_size_list=[128, 128, 64], - dataset='medium', # medium / medium-replay / medium-expert - rtg_scale=1000, # normalize returns to go - max_eval_ep_len=1000, # max len of one episode - num_eval_ep=10, # num of evaluation episodes - batch_size=64, # training batch size - wt_decay=1e-4, - warmup_steps=10000, - max_train_iters=200, - context_len=20, - n_blocks=3, - embed_dim=128, - dropout_p=0.1, - learn=dict( - - # batch_size=64, - learning_rate=1e-4, - # ============================================================== - # The following configs are algorithm-specific - # ============================================================== - ), - # collect_mode config - collect=dict(), - eval=dict(), - # other config - other=dict(), - ) - - def default_model(self) -> Tuple[str, List[str]]: - return 'dt', ['ding.model.template.decision_transformer'] - - def _init_learn(self) -> None: - r""" - Overview: - Learn mode init method. Called by ``self.__init__``. - Init the optimizer, algorithm config, main and target models. - """ - - self.stop_value = self._cfg.stop_value - self.env_name = self._cfg.env_name - dataset = self._cfg.dataset # medium / medium-replay / medium-expert - # rtg_scale: scale of `return to go` - # rtg_target: max target of `return to go` - # Our goal is normalize `return to go` to (0, 1), which will favour the covergence. 
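# Editor's sketch (illustration only, not part of the patch): "return to go" at step t is the
# sum of all rewards from t to the end of the episode; dividing it by rtg_scale keeps the
# conditioning value roughly inside (0, 1] when rtg_scale is chosen close to the best
# achievable return, which is why rtg_scale and rtg_target are usually set equal:
import numpy as np

rewards = np.array([1.0, 0.0, 2.0, 1.0])        # toy episode rewards
rtg = np.cumsum(rewards[::-1])[::-1]            # returns-to-go: [4., 3., 3., 1.]
rtg_scale = 4.0                                 # here equal to the maximum episode return
normalized_rtg = rtg / rtg_scale                # values fall in (0, 1]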
- # As a result, we usually set rtg_scale == rtg_target. - self.rtg_scale = self._cfg.rtg_target # normalize returns to go - self.rtg_target = self._cfg.rtg_target # max target reward_to_go - self.max_eval_ep_len = self._cfg.max_eval_ep_len # max len of one episode - self.num_eval_ep = self._cfg.num_eval_ep # num of evaluation episodes - - lr = self._cfg.learn.learning_rate # learning rate - wt_decay = self._cfg.wt_decay # weight decay - warmup_steps = self._cfg.warmup_steps # warmup steps for lr scheduler - - max_train_iters = self._cfg.max_train_iters - - self.context_len = self._cfg.context_len # K in decision transformer - n_blocks = self._cfg.n_blocks # num of transformer blocks - embed_dim = self._cfg.embed_dim # embedding (hidden) dim of transformer - dropout_p = self._cfg.dropout_p # dropout probability - - # # load data from this file - # dataset_path = f'{self._cfg.dataset_dir}/{env_d4rl_name}.pkl' - - # saves model and csv in this directory - self.log_dir = self._cfg.log_dir - if not os.path.exists(self.log_dir): - os.makedirs(self.log_dir) - - # training and evaluation device - self.device = torch.device(self._device) - - self.start_time = datetime.now().replace(microsecond=0) - self.start_time_str = self.start_time.strftime("%y-%m-%d-%H-%M-%S") - - # prefix = "dt_" + env_d4rl_name - self.prefix = "dt_" + self.env_name - - save_model_name = self.prefix + "_model_" + self.start_time_str + ".pt" - self.save_model_path = os.path.join(self.log_dir, save_model_name) - self.save_best_model_path = self.save_model_path[:-3] + "_best.pt" - - log_csv_name = self.prefix + "_log_" + self.start_time_str + ".csv" - log_csv_path = os.path.join(self.log_dir, log_csv_name) - - self.csv_writer = csv.writer(open(log_csv_path, 'a', 1)) - csv_header = (["duration", "num_updates", "eval_avg_reward", "eval_avg_ep_len", "eval_d4rl_score"]) - - self.csv_writer.writerow(csv_header) - - dataset_path = self._cfg.learn.dataset_path - logging.info("=" * 60) - logging.info("start time: " + self.start_time_str) - logging.info("=" * 60) - - logging.info("device set to: " + str(self.device)) - logging.info("dataset path: " + dataset_path) - logging.info("model save path: " + self.save_model_path) - logging.info("log csv save path: " + log_csv_path) - - self._env = gym.make(self.env_name) - - self.state_dim = self._cfg.model.state_dim - self.act_dim = self._cfg.model.act_dim - - self._learn_model = self._model - self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay) - - self._scheduler = torch.optim.lr_scheduler.LambdaLR( - self._optimizer, lambda steps: min((steps + 1) / warmup_steps, 1) - ) - - self.max_env_score = -1.0 - - def _forward_learn(self, data: list) -> Dict[str, Any]: - r""" - Overview: - Forward and backward function of learn mode. - Arguments: - - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs'] - Returns: - - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. 
- """ - - self._learn_model.train() - - timesteps, states, actions, returns_to_go, traj_mask = data - - timesteps = timesteps.to(self.device) # B x T - states = states.to(self.device) # B x T x state_dim - actions = actions.to(self.device) # B x T x act_dim - returns_to_go = returns_to_go.to(self.device) # B x T x 1 - traj_mask = traj_mask.to(self.device) # B x T - action_target = torch.clone(actions).detach().to(self.device) - - # The shape of `returns_to_go` may differ with different dataset (B x T or B x T x 1), - # and we need a 3-dim tensor - if len(returns_to_go.shape) == 2: - returns_to_go = returns_to_go.unsqueeze(-1) - - # if discrete - if not self._cfg.model.continuous: - actions = one_hot(actions.squeeze(-1), num=self.act_dim) - - state_preds, action_preds, return_preds = self._learn_model.forward( - timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go - ) - - traj_mask = traj_mask.view(-1, ) - - # only consider non padded elements - action_preds = action_preds.view(-1, self.act_dim)[traj_mask > 0] - - if self._cfg.model.continuous: - action_target = action_target.view(-1, self.act_dim)[traj_mask > 0] - else: - action_target = action_target.view(-1)[traj_mask > 0] - - if self._cfg.model.continuous: - action_loss = F.mse_loss(action_preds, action_target) - else: - action_loss = F.cross_entropy(action_preds, action_target) - - self._optimizer.zero_grad() - action_loss.backward() - torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), 0.25) - self._optimizer.step() - self._scheduler.step() - - return { - 'cur_lr': self._optimizer.state_dict()['param_groups'][0]['lr'], - 'action_loss': action_loss.detach().cpu().item(), - } - - def evaluate_on_env(self, state_mean=None, state_std=None, render=False): - - eval_batch_size = 1 # required for forward pass - - results = {} - total_reward = 0 - total_timesteps = 0 - - # state_dim = env.observation_space.shape[0] - # act_dim = env.action_space.shape[0] - - if state_mean is None: - self.state_mean = torch.zeros((self.state_dim, )).to(self.device) - else: - self.state_mean = torch.from_numpy(state_mean).to(self.device) - - if state_std is None: - self.state_std = torch.ones((self.state_dim, )).to(self.device) - else: - self.state_std = torch.from_numpy(state_std).to(self.device) - - # same as timesteps used for training the transformer - # also, crashes if device is passed to arange() - timesteps = torch.arange(start=0, end=self.max_eval_ep_len, step=1) - timesteps = timesteps.repeat(eval_batch_size, 1).to(self.device) - - self._learn_model.eval() - - with torch.no_grad(): - - for _ in range(self.num_eval_ep): - - # zeros place holders - # continuous action - actions = torch.zeros( - (eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device - ) - - # discrete action # TODO - # actions = torch.randint(0,self.act_dim,[eval_batch_size, self.max_eval_ep_len, 1], - # dtype=torch.long, device=self.device) - - states = torch.zeros( - (eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device - ) - rewards_to_go = torch.zeros( - (eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self.device - ) - - # init episode - running_state = self._env.reset() - running_reward = 0 - running_rtg = self.rtg_target / self.rtg_scale - - for t in range(self.max_eval_ep_len): - - total_timesteps += 1 - - # add state in placeholder and normalize - states[0, t] = torch.from_numpy(running_state).to(self.device) - # states[0, t] = 
(states[0, t].cpu() - self.state_mean.cpu().numpy()) / self.state_std.cpu().numpy() - states[0, t] = (states[0, t] - self.state_mean) / self.state_std - - # calcualate running rtg and add it in placeholder - running_rtg = running_rtg - (running_reward / self.rtg_scale) - rewards_to_go[0, t] = running_rtg - - if t < self.context_len: - _, act_preds, _ = self._learn_model.forward( - timesteps[:, :self.context_len], states[:, :self.context_len], - actions[:, :self.context_len], rewards_to_go[:, :self.context_len] - ) - act = act_preds[0, t].detach() - else: - _, act_preds, _ = self._learn_model.forward( - timesteps[:, t - self.context_len + 1:t + 1], states[:, t - self.context_len + 1:t + 1], - actions[:, t - self.context_len + 1:t + 1], rewards_to_go[:, t - self.context_len + 1:t + 1] - ) - act = act_preds[0, -1].detach() - - # if discrete - if not self._cfg.model.continuous: - act = torch.argmax(act) - running_state, running_reward, done, _ = self._env.step(act.cpu().numpy()) - - # add action in placeholder - actions[0, t] = act - - total_reward += running_reward - - if render: - self._env.render() - if done: - break - - results['eval/avg_reward'] = total_reward / self.num_eval_ep - results['eval/avg_ep_len'] = total_timesteps / self.num_eval_ep - - return results - - def evaluate(self, total_update_times, state_mean=None, state_std=None, render=False): - results = self.evaluate_on_env(state_mean, state_std, render) - - eval_avg_reward = results['eval/avg_reward'] - eval_avg_ep_len = results['eval/avg_ep_len'] - eval_d4rl_score = self.get_d4rl_normalized_score(results['eval/avg_reward'], self.env_name) * 100 - - time_elapsed = str(datetime.now().replace(microsecond=0) - self.start_time) - - log_str = ( - "=" * 60 + '\n' + "time elapsed: " + time_elapsed + '\n' + "num of updates: " + str(total_update_times) + - '\n' + '\n' + "eval avg reward: " + format(eval_avg_reward, ".5f") + '\n' + "eval avg ep len: " + - format(eval_avg_ep_len, ".5f") + '\n' + "eval d4rl score: " + format(eval_d4rl_score, ".5f") - ) - - logging.info(log_str) - - log_data = [time_elapsed, total_update_times, eval_avg_reward, eval_avg_ep_len, eval_d4rl_score] - log_csv_name = self.prefix + "_log_" + self.start_time_str + ".csv" - log_csv_path = os.path.join(self.log_dir, log_csv_name) - - self.csv_writer.writerow(log_data) - - # save model - logging.info("eval_avg_reward: " + format(eval_avg_reward, ".5f")) - eval_env_score = eval_avg_reward - if eval_env_score >= self.max_env_score: - logging.info("saving max env score model at: " + self.save_best_model_path) - torch.save(self._learn_model.state_dict(), self.save_best_model_path) - self.max_env_score = eval_env_score - - logging.info("saving current model at: " + self.save_model_path) - torch.save(self._learn_model.state_dict(), self.save_model_path) - - return self.max_env_score >= self.stop_value - - def get_d4rl_normalized_score(self, score, env_name): - # env_key = env_name.split('-')[0].lower() - # assert env_key in D4RLTrajectoryDataset.REF_MAX_SCORE, \ - # f'no reference score for {env_key} env to calculate d4rl score' - # d4rl_max_score, d4rl_min_score = D4RLTrajectoryDataset.REF_MAX_SCORE, D4RLTrajectoryDataset.REF_MIN_SCORE - # return (score - d4rl_min_score[env_key]) / (d4rl_max_score[env_key] - d4rl_min_score[env_key]) - return 0 - - def _state_dict_learn(self) -> Dict[str, Any]: - return { - 'model': self._learn_model.state_dict(), - # 'target_model': self._target_model.state_dict(), - 'optimizer': self._optimizer.state_dict(), - } - - def 
_load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - self._learn_model.load_state_dict(state_dict['model']) - # self._target_model.load_state_dict(state_dict['target_model']) - self._optimizer.load_state_dict(state_dict['optimizer']) - - def _monitor_vars_learn(self) -> List[str]: - return ['cur_lr', 'action_loss'] - - - def _init_eval(self) -> None: - pass - - - def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: - pass - - - def _init_collect(self) -> None: - pass - - def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: - pass - - def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """ - Overview: - For a given trajectory(transitions, a list of transition) data, process it into a list of sample that \ - can be used for training directly. A train sample can be a processed transition(DQN with nstep TD) \ - or some continuous transitions(DRQN). - Arguments: - - data (:obj:`List[Dict[str, Any]`): The trajectory data(a list of transition), each element is the same \ - format as the return value of ``self._process_transition`` method. - Returns: - - samples (:obj:`dict`): The list of training samples. - - .. note:: - We will vectorize ``process_transition`` and ``get_train_sample`` method in the following release version. \ - And the user can customize the this data processing procecure by overriding this two methods and collector \ - itself. - """ - pass - - def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: namedtuple) -> Dict[str, Any]: - """ - Overview: - Generate a transition(e.g.: ) for this algorithm training. - Arguments: - - obs (:obj:`Any`): Env observation. - - policy_output (:obj:`Dict[str, Any]`): The output of policy collect mode(``self._forward_collect``),\ - including at least ``action``. - - timestep (:obj:`namedtuple`): The output after env step(execute policy output action), including at \ - least ``obs``, ``reward``, ``done``, (here obs indicates obs after env step). - Returns: - - transition (:obj:`dict`): Dict type transition data. - """ - pass \ No newline at end of file diff --git a/ding/policy/dt.py b/ding/policy/dt.py index f5142053ca..2321280781 100644 --- a/ding/policy/dt.py +++ b/ding/policy/dt.py @@ -40,28 +40,20 @@ class DTPolicy(Policy): on_policy=False, # (bool) Whether use priority(priority sample, IS weight, update priority) priority=False, - # (float) Reward's future discount factor, aka. gamma. 
- discount_factor=0.97, # (int) N-step reward for target q_value estimation - nstep=1, obs_shape=4, action_shape=2, # encoder_hidden_size_list=[128, 128, 64], dataset='medium', # medium / medium-replay / medium-expert rtg_scale=1000, # normalize returns to go max_eval_ep_len=1000, # max len of one episode - num_eval_ep=10, # num of evaluation episodes batch_size=64, # training batch size wt_decay=1e-4, warmup_steps=10000, max_train_iters=200, context_len=20, - n_blocks=3, - embed_dim=128, - dropout_p=0.1, log_dir='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps', learn=dict( - dataset_path='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl', # TODO # batch_size=64, learning_rate=1e-4, @@ -93,12 +85,12 @@ def _init_learn(self) -> None: self.rtg_scale = self._cfg.rtg_scale # normalize returns to go self.rtg_target = self._cfg.rtg_target # max target reward_to_go self.max_eval_ep_len = self._cfg.max_eval_ep_len # max len of one episode - self.num_eval_ep = self._cfg.num_eval_ep # num of evaluation episodes lr = self._cfg.learn.learning_rate # learning rate wt_decay = self._cfg.wt_decay # weight decay warmup_steps = self._cfg.warmup_steps # warmup steps for lr scheduler - + + self.clip_grad_norm_p = self._cfg.clip_grad_norm_p self.context_len = self._cfg.context_len # K in decision transformer # # load data from this file @@ -111,7 +103,10 @@ def _init_learn(self) -> None: self.act_dim = self._cfg.model.act_dim self._learn_model = self._model - self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay) + if self.cfg.env_type == 'atari': + self._optimizer = self._learn_model.configure_optimizers(self._cfg.weight_decay, lr, self._cfg.betas) + else: + self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay) self._scheduler = torch.optim.lr_scheduler.LambdaLR( self._optimizer, lambda steps: min((steps + 1) / warmup_steps, 1) @@ -128,7 +123,8 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: Returns: - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. 
""" - + import time + st = time.time() self._learn_model.train() data = [[i[j] for i in data] for j in range(len(data[0]))] @@ -148,10 +144,13 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: # if discrete if not self._cfg.model.continuous and self.cfg.env_type != 'atari': actions = one_hot(actions.squeeze(-1), num=self.act_dim) - - state_preds, action_preds, return_preds = self._learn_model.forward( - timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go - ) + + if self.cfg.env_type == 'atari': + state_preds, action_preds, return_preds = self._learn_model.forward( + timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, tar=1) + else: + state_preds, action_preds, return_preds = self._learn_model.forward( + timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go) if self.cfg.env_type == 'atari': action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1)) @@ -170,7 +169,7 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: self._optimizer.zero_grad() action_loss.backward() - torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), 0.25) + torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), self.clip_grad_norm_p) self._optimizer.step() self._scheduler.step() @@ -198,50 +197,51 @@ def _init_eval(self) -> None: self.context_len = self._cfg.context_len # K in decision transformer self.t = [0 for _ in range(self.eval_batch_size)] - self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] - self.timesteps = torch.arange(start=0, end=self.max_eval_ep_len, step=1).repeat(self.eval_batch_size, 1).to(self.device) if not self._cfg.model.continuous: self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self.device) else: self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device) if self.cfg.env_type == 'atari': self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len,) + tuple(self.state_dim), dtype=torch.float32, device=self.device) + self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)] else: + self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device) self.state_mean = torch.from_numpy(self._cfg.state_mean).to(self.device) self.state_std = torch.from_numpy(self._cfg.state_std).to(self.device) + self.timesteps = torch.arange(start=0, end=self.max_eval_ep_len, step=1).repeat(self.eval_batch_size, 1).to(self.device) self.rewards_to_go = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self.device) def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: # save and forward data_id = list(data.keys()) - data_len = len(data_id) self._eval_model.eval() with torch.no_grad(): - timesteps = torch.zeros((data_len, 1, 1), dtype=torch.long, device=self.device) - if not self._cfg.model.continuous: - actions = torch.zeros((data_len, self.context_len, 1), dtype=torch.long, device=self.device) - else: - actions = torch.zeros((data_len, self.context_len, self.act_dim), dtype=torch.float32, device=self.device) if self.cfg.env_type == 'atari': - states = torch.zeros((data_len, self.context_len,) + tuple(self.state_dim), dtype=torch.float32, device=self.device) + states = 
torch.zeros((self.eval_batch_size, self.context_len,) + tuple(self.state_dim), dtype=torch.float32, device=self.device) + timesteps = torch.zeros((self.eval_batch_size, 1, 1), dtype=torch.long, device=self.device) else: - states = torch.zeros((data_len, self.context_len, self.state_dim), dtype=torch.float32, device=self.device) - rewards_to_go = torch.zeros((data_len, self.context_len, 1), dtype=torch.float32, device=self.device) + states = torch.zeros((self.eval_batch_size, self.context_len, self.state_dim), dtype=torch.float32, device=self.device) + timesteps = torch.zeros((self.eval_batch_size, self.context_len), dtype=torch.long, device=self.device) + if not self._cfg.model.continuous: + actions = torch.zeros((self.eval_batch_size, self.context_len, 1), dtype=torch.long, device=self.device) + else: + actions = torch.zeros((self.eval_batch_size, self.context_len, self.act_dim), dtype=torch.float32, device=self.device) + rewards_to_go = torch.zeros((self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self.device) for i in data_id: if self.cfg.env_type == 'atari': - self.states[i, self.t[i]] = data[i]['obs'].to(self.device) / 255 + self.states[i, self.t[i]] = data[i]['obs'].to(self.device) else: self.states[i, self.t[i]] = (data[i]['obs'].to(self.device) - self.state_mean) / self.state_std # self.states[i, self.t[i]] = torch.tensor(data[i]['obs']) - self.running_rtg[i] = self.running_rtg[i] - (data[i]['reward'].to(self.device) / self.rtg_scale) - # self.running_rtg[i] = self.running_rtg[i] - (data[i]['reward'][0] / self.rtg_scale) + # self.running_rtg[i] = self.running_rtg[i] - (data[i]['reward'].to(self.device) / self.rtg_scale) + self.running_rtg[i] = self.running_rtg[i] - data[i]['reward'].to(self.device) self.rewards_to_go[i, self.t[i]] = self.running_rtg[i] if self.t[i] <= self.context_len: if self.cfg.env_type == 'atari': - timesteps[i] = self.t[i] * torch.ones((1), dtype=torch.int64).to(self.device) + timesteps[i] = self.t[i] * torch.ones((1, 1), dtype=torch.int64).to(self.device) else: timesteps[i] = self.timesteps[i, :self.context_len] states[i] = self.states[i, :self.context_len] @@ -249,7 +249,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: rewards_to_go[i] = self.rewards_to_go[i, :self.context_len] else: if self.cfg.env_type == 'atari': - timesteps[i] = self.t[i] * torch.ones((1), dtype=torch.int64).to(self.device) + timesteps[i] = self.t[i] * torch.ones((1, 1), dtype=torch.int64).to(self.device) else: timesteps[i] = self.timesteps[i, self.t[i] - self.context_len + 1:self.t[i] + 1] states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1] @@ -259,14 +259,24 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: # actions = one_hot(actions.squeeze(-1), num=self.act_dim) _, act_preds, _ = self._eval_model.forward(timesteps, states, actions, rewards_to_go) del timesteps, states, actions, rewards_to_go - act = torch.zeros((self.eval_batch_size, self.act_dim), dtype=torch.float32, device=self.device) - for i in data_id: - act[i] = act_preds[i, self.t[i]].detach() if self.t[i] < self.context_len else act_preds[i, -1].detach() - if not self._cfg.model.continuous: - act = torch.argmax(act, axis=1).unsqueeze(1) - for i in data_id: - self.actions[i, self.t[i]] = act[i] - self.t[i] += 1 + + if self.cfg.env_type == 'atari': + logits = act_preds[:, -1, :] + probs = F.softmax(logits, dim=-1) + act = torch.zeros((self.eval_batch_size, 1), dtype=torch.long, device=self.device) + for i in data_id: + act[i] = 
torch.multinomial(probs[i], num_samples=1) + self.actions[i, self.t[i]] = act[i] + self.t[i] += 1 + else: + act = torch.zeros((self.eval_batch_size, self.act_dim), dtype=torch.float32, device=self.device) + for i in data_id: + act[i] = act_preds[i, self.t[i]].detach() if self.t[i] < self.context_len else act_preds[i, -1].detach() + if not self._cfg.model.continuous: + act = torch.argmax(act, axis=1).unsqueeze(1) + for i in data_id: + self.actions[i, self.t[i]] = act[i] + self.t[i] += 1 if self._cuda: act = to_device(act, 'cpu') output = {'action': act} @@ -276,7 +286,6 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: def _reset_eval(self, data_id: List[int] = None) -> None: # clean data if data_id is None: - self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] self.t = [0 for _ in range(self.eval_batch_size)] self.timesteps = torch.arange(start=0, end=self.max_eval_ep_len, step=1).repeat(self.eval_batch_size, 1).to(self.device) if not self._cfg.model.continuous: @@ -285,12 +294,14 @@ def _reset_eval(self, data_id: List[int] = None) -> None: self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device) if self.cfg.env_type == 'atari': self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len,) + tuple(self.state_dim), dtype=torch.float32, device=self.device) + self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)] else: self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device) + self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] + self.rewards_to_go = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self.device) else: for i in data_id: - self.running_rtg[i] = self.rtg_target / self.rtg_scale self.t[i] = 0 self.timesteps[i] = torch.arange(start=0, end=self.max_eval_ep_len, step=1).to(self.device) if not self._cfg.model.continuous: @@ -299,8 +310,10 @@ def _reset_eval(self, data_id: List[int] = None) -> None: self.actions[i] = torch.zeros((self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device) if self.cfg.env_type == 'atari': self.states[i] = torch.zeros((self.max_eval_ep_len,) + tuple(self.state_dim), dtype=torch.float32, device=self.device) + self.running_rtg[i] = self.rtg_target else: self.states[i] = torch.zeros((self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device) + self.running_rtg[i] = self.rtg_target / self.rtg_scale self.rewards_to_go[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.float32, device=self.device) def get_d4rl_normalized_score(self, score, env_name): diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index 208c386703..e764b4e19f 100644 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -522,12 +522,12 @@ def __init__(self, cfg: dict) -> None: while not done: states, ac, ret, next_states, next_action, next_reward, terminal, indices = frb.sample_transition_batch(batch_size=1, indices=[i]) states = states.transpose((0, 3, 1, 2))[0] # (1, 84, 84, 4) --> (4, 84, 84) - obss += [states] - actions += [ac[0]] - stepwise_returns += [ret[0]] + obss.append(states) + actions.append(ac[0]) + stepwise_returns.append(ret[0]) if terminal[0]: - done_idxs += [len(obss)] - returns += [0] + done_idxs.append(len(obss)) + returns.append(0) if trajectories_to_load == 0: done = True else: @@ -591,7 
+591,7 @@ def __len__(self) -> int: if self.env_type != 'atari': return len(self.trajectories) else: - return len(self.obss) + return len(self.obss) - self.context_len * 3 def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: if self.env_type != 'atari': @@ -640,7 +640,7 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tenso torch.zeros(padding_len, dtype=torch.long)], dim=0 ) return timesteps, states, actions, returns_to_go, traj_mask - else: + else: # mean time cost less than 0.02s block_size = self.context_len done_idx = idx + block_size for i in self.done_idxs: @@ -654,7 +654,6 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tenso rtgs = torch.tensor(self.rtgs[idx:done_idx], dtype=torch.float32).unsqueeze(1) timesteps = torch.tensor(self.timesteps[idx:idx+1], dtype=torch.int64).unsqueeze(1) traj_mask = torch.ones(self.context_len, dtype=torch.long) - return timesteps, states, actions, rtgs, traj_mask diff --git a/dizoo/atari/config/pong_dt_config.py b/dizoo/atari/config/pong_dt_config.py index 2bf3de48c9..1e43a62adc 100644 --- a/dizoo/atari/config/pong_dt_config.py +++ b/dizoo/atari/config/pong_dt_config.py @@ -3,6 +3,7 @@ hopper_dt_config = dict( exp_name='dt_log/atari/Pong/Pong_dt_seed0', + # exp_name='dt_log/atari/Pong/Pong_dt_seed0', env=dict( env_id='PongNoFrameskip-v4', norm_obs=dict(use_norm=False, ), @@ -14,56 +15,47 @@ stop_value=20, frame_stack=4, is_train=False, + episode_num=10000, # stop in breakout ), policy=dict( num_buffers=50, num_steps=500000, - data_dir_prefix='/mnt/nfs/luyd/d4rl_atari/Pong/', + # num_steps=500, + data_dir_prefix='/mnt/nfs/luyd/d4rl_atari/Pong', trajectories_per_buffer=10, env_type='atari', - stop_value=20, - state_mean=None, - state_std=None, - evaluator_env_num=8, + stop_value=105, cuda=True, env_name='PongNoFrameskip-v4', dataset_name='Pong', - rtg_target=20, # max target return to go - rtg_scale=10, + # rtg_target=20, # max target return to go + rtg_target=90, # max target return to go + # rtg_scale=10, max_eval_ep_len=10000, # max lenght of one episode - num_eval_ep=10, # num of evaluation episode wt_decay=1e-4, + clip_grad_norm_p=1.0 + betas = (0.9, 0.95), + weight_decay=0.1, # warmup_steps=100000, warmup_steps=10000, - num_updates_per_iter=100, context_len=30, - n_blocks=6, - embed_dim=128, - n_heads=8, - dropout_p=0.1, model=dict( state_dim=(4, 84, 84), - act_dim=6, - n_blocks=3, - h_dim=128, + # act_dim=6, + act_dim=4, n_embd=128, context_len=30, n_heads=8, n_layer=6, - drop_p=0.1, embd_pdrop=0.1, resid_pdrop = 0.1, attn_pdrop = 0.1, continuous=False, ), - discount_factor=0.999, - nstep=3, learn=dict( batch_size=128, learning_rate=6e-4, target_update_freq=100, - kappa=1.0, - min_q_weight=4.0, ), collect=dict( data_type='d4rl_trajectory', @@ -71,7 +63,7 @@ data_path='/mnt/nfs/luyd/d4rl_atari/Pong', unroll_len=1, ), - eval=dict(evaluator=dict(evalu_freq=100, ), ), + eval=dict(evaluator=dict(eval_freq=100, ), ), other=dict( eps=dict( type='exp', diff --git a/ding/example/dt_atari.py b/dizoo/atari/entry/atari_dt_main.py similarity index 100% rename from ding/example/dt_atari.py rename to dizoo/atari/entry/atari_dt_main.py diff --git a/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py b/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py index 1847f7b7d4..eebbf7d509 100644 --- a/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py +++ b/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py @@ -19,17 +19,11 @@ 
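# Editor's sketch (illustration only, not part of the patch): the atari branch of __getitem__
# above samples `context_len` consecutive transitions and clamps the window so it ends at the
# first episode boundary after idx, so a single training sample never spans two episodes
# (related: the __len__ change above keeps sampled start indices away from the end of the
# flat transition arrays). Toy version of that index arithmetic:
context_len = 3
done_idxs = [5, 11]                  # transition indices where episodes end
idx = 4                              # proposed start index
done_idx = idx + context_len         # = 7, but the first episode ends at 5 ...
for i in done_idxs:
    if i > idx:
        done_idx = min(int(i), done_idx)   # ... so clamp the window end to 5
        break
idx = done_idx - context_len         # shift the start back; the sample covers steps 2, 3, 4
assert (idx, done_idx) == (2, 5)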
rtg_target=300, # max target reward_to_go rtg_scale=150, max_eval_ep_len=1000, # max len of one episode # TODO - num_eval_ep=10, # num of evaluation episodes wt_decay=1e-4, warmup_steps=10000, - num_updates_per_iter=100, context_len=20, # TODO evaluator_env_num=8, - n_blocks=3, - embed_dim=128, - n_heads=1, - dropout_p=0.1, - log_dir='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps', + log_dir='DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps', model=dict( state_dim=8, act_dim=4, @@ -43,16 +37,14 @@ discount_factor=0.999, nstep=3, learn=dict( - dataset_path='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl', # TODO + dataset_path='DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl', # TODO learning_rate=3e-4, batch_size=64, # training batch size target_update_freq=100, - kappa=1.0, - min_q_weight=4.0, ), collect=dict( data_type='d4rl_trajectory', - data_path='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl', + data_path='DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl', unroll_len=1, ), eval=dict(evaluator=dict(eval_freq=100, )), @@ -65,7 +57,7 @@ type='lunarlander', import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], ), - env_manager=dict(type='base'), + env_manager=dict(type='subprocess'), policy=dict(type='dt'), ) lunarlander_dt_create_config = EasyDict(lunarlander_dt_create_config) diff --git a/dizoo/d4rl/config/hopper_medium_dt_config.py b/dizoo/d4rl/config/hopper_medium_dt_config.py index 37e804da04..c49da734b7 100644 --- a/dizoo/d4rl/config/hopper_medium_dt_config.py +++ b/dizoo/d4rl/config/hopper_medium_dt_config.py @@ -30,10 +30,7 @@ warmup_steps=10000, num_updates_per_iter=100, context_len=20, - n_blocks=3, - embed_dim=128, - n_heads=1, - dropout_p=0.1, + clip_grad_norm_p=0.25, model=dict( state_dim=11, act_dim=3, diff --git a/ding/example/dt_mujoco.py b/dizoo/d4rl/entry/d4rl_dt_mujoco.py similarity index 100% rename from ding/example/dt_mujoco.py rename to dizoo/d4rl/entry/d4rl_dt_mujoco.py From 062520a819ca045d1832d64c43f1d533914d2efb Mon Sep 17 00:00:00 2001 From: luyudong Date: Mon, 31 Jul 2023 14:00:32 +0800 Subject: [PATCH 07/25] Fix abs path --- ding/policy/dt.py | 4 ++-- dizoo/atari/config/pong_dt_config.py | 8 ++++---- .../lunarlander/offline_data/collect_dqn_data_config.py | 4 ++-- dizoo/d4rl/config/hopper_expert_dt_config.py | 2 +- dizoo/d4rl/config/hopper_medium_dt_config.py | 4 ++-- dizoo/d4rl/config/hopper_medium_expert_dt_config.py | 2 +- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/ding/policy/dt.py b/ding/policy/dt.py index 2321280781..e6f2f51d5d 100644 --- a/ding/policy/dt.py +++ b/ding/policy/dt.py @@ -52,9 +52,9 @@ class DTPolicy(Policy): warmup_steps=10000, max_train_iters=200, context_len=20, - log_dir='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps', + log_dir='DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps', learn=dict( - dataset_path='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl', # TODO + dataset_path='DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl', # TODO # batch_size=64, learning_rate=1e-4, # ============================================================== diff --git a/dizoo/atari/config/pong_dt_config.py b/dizoo/atari/config/pong_dt_config.py index 1e43a62adc..bb70e3ef4c 100644 --- a/dizoo/atari/config/pong_dt_config.py +++ b/dizoo/atari/config/pong_dt_config.py @@ -21,7 +21,7 @@ 
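# Editor's note (suggestion, not part of the patch): the relative paths introduced by this
# "Fix abs path" commit ('DI-engine/dizoo/...') still depend on the current working directory.
# A common alternative, sketched here with hypothetical names, is to anchor the paths to the
# config file itself:
import os

CONFIG_DIR = os.path.dirname(os.path.abspath(__file__))
dataset_path = os.path.join(CONFIG_DIR, '..', 'offline_data', 'dt_data', 'dqn_data_1000eps.pkl')
log_dir = os.path.join(CONFIG_DIR, '..', 'dt_log_1000eps')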
num_buffers=50, num_steps=500000, # num_steps=500, - data_dir_prefix='/mnt/nfs/luyd/d4rl_atari/Pong', + data_dir_prefix='d4rl_atari/Pong', trajectories_per_buffer=10, env_type='atari', stop_value=105, @@ -33,7 +33,7 @@ # rtg_scale=10, max_eval_ep_len=10000, # max lenght of one episode wt_decay=1e-4, - clip_grad_norm_p=1.0 + clip_grad_norm_p=1.0, betas = (0.9, 0.95), weight_decay=0.1, # warmup_steps=100000, @@ -59,8 +59,8 @@ ), collect=dict( data_type='d4rl_trajectory', - # data_path='/mnt/nfs/luyd/hopper_medium.hdf5', - data_path='/mnt/nfs/luyd/d4rl_atari/Pong', + # data_path='hopper_medium.hdf5', + data_path='d4rl_atari/Pong', unroll_len=1, ), eval=dict(evaluator=dict(eval_freq=100, ), ), diff --git a/dizoo/box2d/lunarlander/offline_data/collect_dqn_data_config.py b/dizoo/box2d/lunarlander/offline_data/collect_dqn_data_config.py index 20050b7340..26b35c189d 100644 --- a/dizoo/box2d/lunarlander/offline_data/collect_dqn_data_config.py +++ b/dizoo/box2d/lunarlander/offline_data/collect_dqn_data_config.py @@ -35,14 +35,14 @@ log_policy=True, hook=dict( load_ckpt_before_run='./ckpt_best.pth.tar', # TODO: syspath modeified in other place, have to use abs path. May be fix in next version. - # load_ckpt_before_run='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/dt_data/ckpt/ckpt_best.pth.tar', + # load_ckpt_before_run='DI-engine/dizoo/box2d/lunarlander/dt_data/ckpt/ckpt_best.pth.tar', log_show_after_iter=100, save_ckpt_after_iter=10000, save_ckpt_after_run=False, ), cfg_type='BaseLearnerDict', load_path='./ckpt_best.pth.tar', # TODO: same like last path. - # load_path='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/dt_data/ckpt/ckpt_best.pth.tar', + # load_path='DI-engine/dizoo/box2d/lunarlander/dt_data/ckpt/ckpt_best.pth.tar', ), update_per_collect=10, batch_size=64, diff --git a/dizoo/d4rl/config/hopper_expert_dt_config.py b/dizoo/d4rl/config/hopper_expert_dt_config.py index 7180ddc717..b84d763b30 100644 --- a/dizoo/d4rl/config/hopper_expert_dt_config.py +++ b/dizoo/d4rl/config/hopper_expert_dt_config.py @@ -52,7 +52,7 @@ ), collect=dict( data_type='d4rl_trajectory', - data_path='/mnt/nfs/luyd/hopper_expert.hdf5', + data_path='hopper_expert.hdf5', unroll_len=1, ), eval=dict(evaluator=dict(evalu_freq=100, ), ), diff --git a/dizoo/d4rl/config/hopper_medium_dt_config.py b/dizoo/d4rl/config/hopper_medium_dt_config.py index c49da734b7..3daae57d42 100644 --- a/dizoo/d4rl/config/hopper_medium_dt_config.py +++ b/dizoo/d4rl/config/hopper_medium_dt_config.py @@ -53,8 +53,8 @@ ), collect=dict( data_type='d4rl_trajectory', - # data_path='/mnt/nfs/luyd/hopper_medium.hdf5', - data_path='/mnt/nfs/luyd/d4rl/hopper_medium-v2.pkl', + # data_path='hopper_medium.hdf5', + data_path='d4rl/hopper_medium-v2.pkl', unroll_len=1, ), eval=dict(evaluator=dict(evalu_freq=100, ), ), diff --git a/dizoo/d4rl/config/hopper_medium_expert_dt_config.py b/dizoo/d4rl/config/hopper_medium_expert_dt_config.py index 869ea10d84..fa78fe73cc 100644 --- a/dizoo/d4rl/config/hopper_medium_expert_dt_config.py +++ b/dizoo/d4rl/config/hopper_medium_expert_dt_config.py @@ -52,7 +52,7 @@ ), collect=dict( data_type='d4rl_trajectory', - data_path='/mnt/nfs/luyd/d4rl/hopper_medium_expert.hdf5', + data_path='d4rl/hopper_medium_expert.hdf5', unroll_len=1, ), eval=dict(evaluator=dict(evalu_freq=100, ), ), From edce163b5172bdd7c5aeb7da904917c155803b9a Mon Sep 17 00:00:00 2001 From: luyudong Date: Mon, 7 Aug 2023 14:51:59 +0800 Subject: [PATCH 08/25] Accelerate DT train iter by replacing dataloader --- ding/envs/env_wrappers/env_wrappers.py | 7 +- 
.../middleware/functional/__init__.py | 2 +- .../middleware/functional/data_processor.py | 44 +++++- .../middleware/tests/test_data_processor.py | 4 +- ding/model/template/__init__.py | 2 +- ding/model/template/decision_transformer.py | 146 ------------------ ding/utils/data/dataloader.py | 5 + ding/utils/data/dataset.py | 32 ++-- 8 files changed, 70 insertions(+), 172 deletions(-) delete mode 100644 ding/model/template/decision_transformer.py diff --git a/ding/envs/env_wrappers/env_wrappers.py b/ding/envs/env_wrappers/env_wrappers.py index 1d5752dd95..f0442509bd 100644 --- a/ding/envs/env_wrappers/env_wrappers.py +++ b/ding/envs/env_wrappers/env_wrappers.py @@ -1192,7 +1192,12 @@ def __init__(self, env): super().__init__(env) def reset(self): - return {'obs':self.env.reset(), 'reward': [0]} + ret = {'obs':self.env.reset(), 'reward': np.array([0])} + self._observation_space = gym.spaces.Dict({ + 'obs': self.env.observation_space, + 'reward': gym.spaces.Box(low=-np.inf, high=np.inf, dtype=np.float32)} + ) + return ret def step(self, action): obs, reward, done, info = self.env.step(action) diff --git a/ding/framework/middleware/functional/__init__.py b/ding/framework/middleware/functional/__init__.py index bf5e965cae..6e3eedce3e 100644 --- a/ding/framework/middleware/functional/__init__.py +++ b/ding/framework/middleware/functional/__init__.py @@ -1,5 +1,5 @@ from .trainer import trainer, multistep_trainer -from .data_processor import offpolicy_data_fetcher, data_pusher, offline_data_fetcher, offline_data_saver, \ +from .data_processor import offpolicy_data_fetcher, data_pusher, offline_data_fetcher, offline_data_saver, offline_data_fetcher_from_mem, \ sqil_data_pusher, buffer_saver from .collector import inferencer, rolloutor, TransitionList from .evaluator import interaction_evaluator, interaction_evaluator_ttorch diff --git a/ding/framework/middleware/functional/data_processor.py b/ding/framework/middleware/functional/data_processor.py index f551cba7ad..1440150123 100644 --- a/ding/framework/middleware/functional/data_processor.py +++ b/ding/framework/middleware/functional/data_processor.py @@ -180,6 +180,45 @@ def _fetch(ctx: "OnlineRLContext"): return _fetch +def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable: + + from threading import Thread + from queue import Queue + import time + stream = torch.cuda.Stream() + def producer(queue, dataset, batch_size, device): + torch.set_num_threads(8) + nonlocal stream + idx_iter = iter(range(len(dataset))) + with torch.cuda.stream(stream): + while True: + if queue.full(): + time.sleep(0.1) + else: + try: + start_idx = next(idx_iter) + except StopIteration: + del idx_iter + idx_iter = iter(range(len(dataset))) + start_idx = next(idx_iter) + data = [dataset.__getitem__(idx) for idx in range(start_idx, start_idx+batch_size)] + data = [[i[j] for i in data] for j in range(len(data[0]))] + data = [torch.stack(x).to(device) for x in data] + queue.put(data) + queue = Queue(maxsize=50) + producer_thread = Thread(target=producer, args=(queue, dataset, cfg.policy.batch_size, 'cuda:0' if cfg.policy.cuda else 'cpu'), name='cuda_fetcher_producer') + + producer_thread.start() + + def _fetch(ctx: "OfflineRLContext"): + nonlocal queue + while queue.empty(): + time.sleep(0.001) + ctx.train_data = queue.get() + + return _fetch + + def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable: """ Overview: @@ -212,12 +251,7 @@ def _fetch(ctx: "OfflineRLContext"): del dataloader dataloader = iter(DataLoader(dataset, 
batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x)) ctx.train_data = next(dataloader) - # for i, data in enumerate(dataloader): - # ctx.train_data = data - # yield - # ctx.train_epoch += 1 # TODO apply data update (e.g. priority) in offline setting when necessary - return _fetch diff --git a/ding/framework/middleware/tests/test_data_processor.py b/ding/framework/middleware/tests/test_data_processor.py index 25e5fca3e2..d63d392943 100644 --- a/ding/framework/middleware/tests/test_data_processor.py +++ b/ding/framework/middleware/tests/test_data_processor.py @@ -153,7 +153,9 @@ def __len__(self): ctx.train_epoch = 0 data_tmp = [] - for i, _ in enumerate(offline_data_fetcher(cfg, MyDataset())(ctx)): + fetch = offline_data_fetcher(cfg, MyDataset()) + for i in range(num_batch): + fetch(ctx) assert i // num_batch == ctx.train_epoch data_tmp.extend(ctx.train_data) diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py index 411be9673a..435d7174a3 100755 --- a/ding/model/template/__init__.py +++ b/ding/model/template/__init__.py @@ -21,7 +21,7 @@ from .maqac import MAQAC, ContinuousMAQAC from .madqn import MADQN from .vae import VanillaVAE -from .dt import DecisionTransformer, DecisionTransformerA +from .dt import DecisionTransformer from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS from .bcq import BCQ from .edac import QACEnsemble diff --git a/ding/model/template/decision_transformer.py b/ding/model/template/decision_transformer.py deleted file mode 100644 index 7efbf79890..0000000000 --- a/ding/model/template/decision_transformer.py +++ /dev/null @@ -1,146 +0,0 @@ -""" -The code is transplanted from https://github.com/nikhilbarhate99/min-decision-transformer -""" - -from ding.utils import MODEL_REGISTRY -from typing import Tuple -from ding.torch_utils.network.transformer import Attention, MaskedCausalAttention -import torch -import torch.nn as nn - - -class BlockM(nn.Module): - def __init__(self, h_dim, max_T, n_heads, drop_p): - super().__init__() - self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p) - self.mlp = nn.Sequential( - nn.Linear(h_dim, 4*h_dim), - nn.GELU(), - nn.Linear(4*h_dim, h_dim), - nn.Dropout(drop_p), - ) - self.ln1 = nn.LayerNorm(h_dim) - self.ln2 = nn.LayerNorm(h_dim) - - def forward(self, x): - # Attention -> LayerNorm -> MLP -> LayerNorm - x = x + self.attention(x) # residual - x = self.ln1(x) - x = x + self.mlp(x) # residual - x = self.ln2(x) - return x - - -class Block(nn.Module): - - def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None: - super().__init__() - self.attention = Attention(h_dim, h_dim, h_dim, n_heads, nn.Dropout(drop_p)) - self.att_drop = nn.Dropout(drop_p) - self.mlp = nn.Sequential( - nn.Linear(h_dim, 4 * h_dim), - nn.GELU(), - nn.Linear(4 * h_dim, h_dim), - nn.Dropout(drop_p), - ) - self.ln1 = nn.LayerNorm(h_dim) - self.ln2 = nn.LayerNorm(h_dim) - - mask = torch.tril(torch.ones((max_T, max_T), dtype=torch.bool)).view(1, 1, max_T, max_T) - # register buffer makes sure mask does not get updated - # during backpropagation - self.register_buffer('mask', mask) - - def forward(self, x: torch.Tensor): - # Attention -> LayerNorm -> MLP -> LayerNorm - - x = x + self.att_drop(self.attention(x, self.mask)) # residual - x = self.ln1(x) - x = x + self.mlp(x) # residual - x = self.ln2(x) - return x - - -@MODEL_REGISTRY.register('dt') -class DecisionTransformer(nn.Module): - - def __init__( - self, - state_dim: int, - act_dim: int, - n_blocks: int, - 
h_dim: int, - context_len: int, - n_heads: int, - drop_p: float, - max_timestep: int = 4096, - continuous: bool = True - ) -> None: - super().__init__() - self.continuous = continuous - self.state_dim = state_dim - self.act_dim = act_dim - self.h_dim = h_dim - - # transformer blocks - # we will serially arrange `return`, `state` and `action`, so here the input_seq_len is 3 * context_len - input_seq_len = 3 * context_len - blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)] - self.transformer = nn.Sequential(*blocks) - - # projection heads (project to embedding) - self.embed_ln = nn.LayerNorm(h_dim) - self.embed_timestep = nn.Embedding(max_timestep, h_dim) - self.embed_rtg = torch.nn.Linear(1, h_dim) - self.embed_state = torch.nn.Linear(state_dim, h_dim) - - if self.continuous: - action_tanh = True # True for continuous actions - self.embed_action = torch.nn.Linear(act_dim, h_dim) - - else: - action_tanh = False # False for discrete actions - self.embed_action = torch.nn.Linear(act_dim, h_dim) - - # prediction heads - self.predict_rtg = torch.nn.Linear(h_dim, 1) - self.predict_state = torch.nn.Linear(h_dim, state_dim) - self.predict_action = nn.Sequential(*([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if action_tanh else []))) - - def forward( - self, timesteps: torch.Tensor, states: torch.Tensor, actions: torch.Tensor, returns_to_go: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - B, T, _ = states.shape - - time_embeddings = self.embed_timestep(timesteps) # shape: (B,context_len/T,h_dim) - - # time embeddings are treated similar to positional embeddings - # shape: (B,context_len,h_dim) - state_embeddings = self.embed_state(states) + time_embeddings - action_embeddings = self.embed_action(actions) + time_embeddings - returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings - - # stack rtg, states and actions and reshape sequence as - # (r1, s1, a1, r2, s2, a2 ...) - # after stack shape: (B, 3, context_len/T, h_dim) - h = torch.stack((returns_embeddings, state_embeddings, action_embeddings), - dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim) - - h = self.embed_ln(h) - - # transformer and prediction - h = self.transformer(h) - - # get h reshaped such that its size = (B x 3 x T x h_dim) and - # h[:, 0, t] is conditioned on r_0, s_0, a_0 ... r_t - # h[:, 1, t] is conditioned on r_0, s_0, a_0 ... r_t, s_t - # h[:, 2, t] is conditioned on r_0, s_0, a_0 ... r_t, s_t, a_t - h = h.reshape(B, T, 3, self.h_dim) - - # get predictions - return_preds = self.predict_rtg(h[..., 2, :]) # predict next rtg given r, s, a - state_preds = self.predict_state(h[..., 2, :]) # predict next state given r, s, a - action_preds = self.predict_action(h[..., 1, :]) # predict action given r, s - - return state_preds, action_preds, return_preds diff --git a/ding/utils/data/dataloader.py b/ding/utils/data/dataloader.py index cb81925fd3..bcc70aac63 100644 --- a/ding/utils/data/dataloader.py +++ b/ding/utils/data/dataloader.py @@ -161,7 +161,11 @@ def _get_data(self, p: tm.multiprocessing.connection, c: tm.multiprocessing.conn break if cmd == 'get_data': # Main worker asks for data. + import time + st = time.time() + print('in data get at', st) data = self.data_source(self.batch_size) + print('already get data at', time.time(), 'cost', time.time()-st) # ``data`` can be callable, e.g. a function to read data from file, therefore we can divide # this job to pieces, assign to every slave worker and accomplish jobs asynchronously. 
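# Editor's sketch (illustration only, not part of the patch): the offline_data_fetcher_from_mem
# middleware added earlier in this commit hides dataset latency behind a background producer
# thread and a bounded queue, so the train loop only blocks when no batch is ready. A minimal
# standalone version of that producer/consumer pattern (the patch polls with time.sleep when
# the queue is full; this sketch simply relies on the blocking put):
import time
from queue import Queue
from threading import Thread

def producer(q: Queue) -> None:
    item = 0
    while True:
        q.put(item)              # blocks once maxsize items are already waiting
        item += 1

q = Queue(maxsize=4)             # bounded buffer, analogous to Queue(maxsize=50) in the patch
Thread(target=producer, args=(q,), daemon=True).start()

for _ in range(3):               # consumer side: each "train step" just pops a ready batch
    batch = q.get()
    time.sleep(0.01)             # stand-in for one training iteration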
# But if we get a list of dicts, which means the data has already been processed and @@ -186,6 +190,7 @@ def _async_loop(self, p: tm.multiprocessing.connection, c: tm.multiprocessing.co - p (:obj:`tm.multiprocessing.connection`): Parent connection. - c (:obj:`tm.multiprocessing.connection`): Child connection. """ + torch.set_num_threads(1) p.close() # Close unused p, only use c while not self.end_flag: if self.num_workers > 1: diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index e764b4e19f..371631bc37 100644 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -329,12 +329,12 @@ class D4RLTrajectoryDataset(Dataset): } def __init__(self, cfg: dict) -> None: - dataset_path = cfg.policy.collect.data_path - rtg_scale = cfg.policy.rtg_scale - self.context_len = cfg.policy.context_len - self.env_type = cfg.policy.env_type + dataset_path = cfg.dataset.data_dir_prefix + rtg_scale = cfg.dataset.rtg_scale + self.context_len = cfg.dataset.context_len + self.env_type = cfg.dataset.env_type - if 'hdf5' in dataset_path: + if 'hdf5' in dataset_path: # for mujoco env try: import h5py import collections @@ -501,12 +501,12 @@ def __init__(self, cfg: dict) -> None: transitions_per_buffer = np.zeros(50, dtype=int) num_trajectories = 0 - while len(obss) < cfg.policy.num_steps: - buffer_num = np.random.choice(np.arange(50 - cfg.policy.num_buffers, 50), 1)[0] + while len(obss) < cfg.dataset.num_steps: + buffer_num = np.random.choice(np.arange(50 - cfg.dataset.num_buffers, 50), 1)[0] i = transitions_per_buffer[buffer_num] print('loading from buffer %d which has %d already loaded' % (buffer_num, i)) frb = FixedReplayBuffer( - data_dir=cfg.policy.data_dir_prefix + '/1/replay_logs', + data_dir=cfg.dataset.data_dir_prefix + '/1/replay_logs', replay_suffix=buffer_num, observation_shape=(84, 84), stack_size=4, @@ -518,7 +518,7 @@ def __init__(self, cfg: dict) -> None: if frb._loaded_buffers: done = False curr_num_transitions = len(obss) - trajectories_to_load = cfg.policy.trajectories_per_buffer + trajectories_to_load = cfg.dataset.trajectories_per_buffer while not done: states, ac, ret, next_states, next_action, next_reward, terminal, indices = frb.sample_transition_batch(batch_size=1, indices=[i]) states = states.transpose((0, 3, 1, 2))[0] # (1, 84, 84, 4) --> (4, 84, 84) @@ -541,7 +541,7 @@ def __init__(self, cfg: dict) -> None: returns[-1] = 0 i = transitions_per_buffer[buffer_num] done = True - num_trajectories += (cfg.policy.trajectories_per_buffer - trajectories_to_load) + num_trajectories += (cfg.dataset.trajectories_per_buffer - trajectories_to_load) transitions_per_buffer[buffer_num] = i print('this buffer has %d loaded transitions and there are now %d transitions total divided into %d trajectories' % (i, len(obss), num_trajectories)) @@ -560,7 +560,6 @@ def __init__(self, cfg: dict) -> None: rtg_j = curr_traj_returns[j-start_index:i-start_index] rtg[j] = sum(rtg_j) start_index = i - print('max rtg is %d' % max(rtg)) # -- create timestep dataset start_index = 0 @@ -569,7 +568,6 @@ def __init__(self, cfg: dict) -> None: i = int(i) timesteps[start_index:i+1] = np.arange(i+1 - start_index) start_index = i+1 - print('max timestep is %d' % max(timesteps)) self.obss = obss self.actions = actions @@ -640,7 +638,7 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tenso torch.zeros(padding_len, dtype=torch.long)], dim=0 ) return timesteps, states, actions, returns_to_go, traj_mask - else: # mean time cost less than 0.02s + else: # mean cost less than 
0.001s block_size = self.context_len done_idx = idx + block_size for i in self.done_idxs: @@ -648,11 +646,11 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tenso done_idx = min(int(i), done_idx) break idx = done_idx - block_size - states = torch.tensor(np.array(self.obss[idx:done_idx]), dtype=torch.float32).reshape(block_size, -1) # (block_size, 4*84*84) + states = torch.as_tensor(np.array(self.obss[idx:done_idx]), dtype=torch.float32).view(block_size, -1) # (block_size, 4*84*84) states = states / 255. - actions = torch.tensor(self.actions[idx:done_idx], dtype=torch.long).unsqueeze(1) # (block_size, 1) - rtgs = torch.tensor(self.rtgs[idx:done_idx], dtype=torch.float32).unsqueeze(1) - timesteps = torch.tensor(self.timesteps[idx:idx+1], dtype=torch.int64).unsqueeze(1) + actions = torch.as_tensor(self.actions[idx:done_idx], dtype=torch.long).unsqueeze(1) # (block_size, 1) + rtgs = torch.as_tensor(self.rtgs[idx:done_idx], dtype=torch.float32).unsqueeze(1) + timesteps = torch.as_tensor(self.timesteps[idx:idx+1], dtype=torch.int64).unsqueeze(1) traj_mask = torch.ones(self.context_len, dtype=torch.long) return timesteps, states, actions, rtgs, traj_mask From 0fe0c0222c7bffc8b63d095ee500c74902e1eb05 Mon Sep 17 00:00:00 2001 From: luyudong Date: Mon, 7 Aug 2023 14:53:01 +0800 Subject: [PATCH 09/25] Simplify dt model and policy and config --- ding/model/template/dt.py | 249 +++--------------- ding/policy/dt.py | 208 +++++---------- .../{ => serial/pong}/pong_dt_config.py | 73 ++--- dizoo/atari/entry/atari_dt_main.py | 31 ++- .../config/halfcheetah_expert_dt_config.py | 54 ++-- .../config/halfcheetah_medium_dt_config.py | 54 ++-- .../halfcheetah_medium_expert_dt_config.py | 54 ++-- .../halfcheetah_medium_replay_dt_config.py | 54 ++-- dizoo/d4rl/config/hopper_expert_dt_config.py | 47 ++-- dizoo/d4rl/config/hopper_medium_dt_config.py | 41 +-- .../config/hopper_medium_expert_dt_config.py | 47 ++-- .../config/hopper_medium_replay_dt_config.py | 54 ++-- .../d4rl/config/walker2d_expert_dt_config.py | 74 +++--- .../d4rl/config/walker2d_medium_dt_config.py | 74 +++--- .../walker2d_medium_expert_dt_config.py | 74 +++--- .../walker2d_medium_replay_dt_config.py | 74 +++--- 16 files changed, 426 insertions(+), 836 deletions(-) rename dizoo/atari/config/{ => serial/pong}/pong_dt_config.py (51%) diff --git a/ding/model/template/dt.py b/ding/model/template/dt.py index 3b57d3e461..7b5a941ebe 100644 --- a/ding/model/template/dt.py +++ b/ding/model/template/dt.py @@ -92,7 +92,7 @@ def forward(self, x): class DecisionTransformer(nn.Module): def __init__(self, state_dim, act_dim, n_blocks, h_dim, context_len, - n_heads, drop_p, max_timestep=4096): + n_heads, drop_p, max_timestep=4096, state_encoder=None, continuous=False): super().__init__() self.state_dim = state_dim @@ -108,34 +108,45 @@ def __init__(self, state_dim, act_dim, n_blocks, h_dim, context_len, self.embed_ln = nn.LayerNorm(h_dim) self.embed_timestep = nn.Embedding(max_timestep, h_dim) self.embed_rtg = torch.nn.Linear(1, h_dim) - self.embed_state = torch.nn.Linear(state_dim, h_dim) - # # discrete actions - # self.embed_action = torch.nn.Embedding(act_dim, h_dim) - # use_action_tanh = False # False for discrete actions - - # continuous actions - self.embed_action = torch.nn.Linear(act_dim, h_dim) - use_action_tanh = True # True for continuous actions + self.pos_emb = nn.Parameter(torch.zeros(1, input_seq_len + 1, self.h_dim)) + self.global_pos_emb = nn.Parameter(torch.zeros(1, max_timestep+1, self.h_dim)) + + if state_encoder == 
None: + self.embed_state = torch.nn.Linear(state_dim, h_dim) + self.predict_rtg = torch.nn.Linear(h_dim, 1) + self.predict_state = torch.nn.Linear(h_dim, state_dim) + else: + self.state_encoder = state_encoder + + if continuous: + # continuous actions + self.embed_action = torch.nn.Linear(act_dim, h_dim) + use_action_tanh = True # True for continuous actions + else: + # discrete actions + self.embed_action = torch.nn.Embedding(act_dim, h_dim) + use_action_tanh = False # False for discrete actions ### prediction heads - self.predict_rtg = torch.nn.Linear(h_dim, 1) - self.predict_state = torch.nn.Linear(h_dim, state_dim) self.predict_action = nn.Sequential( *([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else [])) ) - def forward(self, timesteps, states, actions, returns_to_go): - - B, T, _ = states.shape - - time_embeddings = self.embed_timestep(timesteps) - - # time embeddings are treated similar to positional embeddings - state_embeddings = self.embed_state(states) + time_embeddings - action_embeddings = self.embed_action(actions) + time_embeddings - returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings + B, T = states.shape[0], states.shape[1] + if self.state_encoder == None: + time_embeddings = self.embed_timestep(timesteps) + + # time embeddings are treated similar to positional embeddings + state_embeddings = self.embed_state(states) + time_embeddings + action_embeddings = self.embed_action(actions) + time_embeddings + returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings + else: + state_embeddings = self.state_encoder(states.reshape(-1, 4, 84, 84).type(torch.float32).contiguous()) # (batch * block_size, h_dim) + state_embeddings = state_embeddings.reshape(states.shape[0], states.shape[1], self.h_dim) # (batch, block_size, h_dim) + returns_embeddings = self.embed_rtg(returns_to_go.type(torch.float32)) + action_embeddings = self.embed_action(actions.type(torch.long).squeeze(-1)) # (batch, block_size, h_dim) # stack rtg, states and actions and reshape sequence as # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2 ...) 
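For intuition, the stack/permute/reshape used here is what turns the three per-timestep embedding streams into the interleaved (r_0, s_0, a_0, r_1, s_1, a_1, ...) token sequence fed to the transformer. A minimal standalone check (placeholder values, shapes only; variable names are just for the sketch):

import torch

B, T, h_dim = 2, 3, 4
returns_embeddings = torch.full((B, T, h_dim), 0.)  # stand-in for rtg embeddings
state_embeddings = torch.full((B, T, h_dim), 1.)    # stand-in for state embeddings
action_embeddings = torch.full((B, T, h_dim), 2.)   # stand-in for action embeddings

h = torch.stack(
    (returns_embeddings, state_embeddings, action_embeddings), dim=1
).permute(0, 2, 1, 3).reshape(B, 3 * T, h_dim)

print(h.shape)     # torch.Size([2, 9, 4])
print(h[0, :, 0])  # tensor([0., 1., 2., 0., 1., 2., 0., 1., 2.]) -> r, s, a for each step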
@@ -143,7 +154,11 @@ def forward(self, timesteps, states, actions, returns_to_go): (returns_embeddings, state_embeddings, action_embeddings), dim=1 ).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim) - h = self.embed_ln(h) + all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, B, dim=0) # batch_size, traj_length, h_dim + position_embeddings = torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.h_dim, dim=-1)) + position_embeddings = position_embeddings + self.pos_emb[:, :h.shape[1], :] + + h = self.embed_ln(h + position_embeddings) # transformer and prediction h = self.transformer(h) @@ -158,190 +173,12 @@ def forward(self, timesteps, states, actions, returns_to_go): h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3) # get predictions - return_preds = self.predict_rtg(h[:,2]) # predict next rtg given r, s, a - state_preds = self.predict_state(h[:,2]) # predict next state given r, s, a + if self.state_encoder == None: + return_preds = self.predict_rtg(h[:,2]) # predict next rtg given r, s, a + state_preds = self.predict_state(h[:,2]) # predict next state given r, s, a + else: + return_preds = None + state_preds = None action_preds = self.predict_action(h[:,1]) # predict action given r, s return state_preds, action_preds, return_preds - - -class GELU(nn.Module): - def forward(self, input): - return F.gelu(input) - - -class CausalSelfAttention(nn.Module): - """ - A vanilla multi-head masked self-attention layer with a projection at the end. - It is possible to use torch.nn.MultiheadAttention here but I am including an - explicit implementation here to show that there is nothing too scary here. - """ - - def __init__(self, n_head, block_size, n_embd, attn_pdrop, resid_pdrop): - super().__init__() - assert n_embd % n_head == 0 - # key, query, value projections for all heads - self.key = nn.Linear(n_embd, n_embd) - self.query = nn.Linear(n_embd, n_embd) - self.value = nn.Linear(n_embd, n_embd) - # regularization - self.attn_drop = nn.Dropout(attn_pdrop) - self.resid_drop = nn.Dropout(resid_pdrop) - # output projection - self.proj = nn.Linear(n_embd, n_embd) - # causal mask to ensure that attention is only applied to the left in the input sequence - # self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)) - self.register_buffer("mask", torch.tril(torch.ones(block_size + 1, block_size + 1)).view(1, 1, block_size + 1, block_size + 1)) - self.n_head = n_head - - def forward(self, x, layer_past=None): - B, T, C = x.size() - - # calculate query, key, values for all heads in batch and move head forward to be the batch dim - k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - - # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) - att = F.softmax(att, dim=-1) - att = self.attn_drop(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side - - # output projection - y = self.resid_drop(self.proj(y)) - return y - - -class BlockA(nn.Module): - """ an unassuming Transformer block """ - - def __init__(self, n_head, 
block_size, n_embd, attn_pdrop, resid_pdrop): - super().__init__() - self.ln1 = nn.LayerNorm(n_embd) - self.ln2 = nn.LayerNorm(n_embd) - self.attn = CausalSelfAttention(n_head, block_size, n_embd, attn_pdrop, resid_pdrop) - self.mlp = nn.Sequential( - nn.Linear(n_embd, 4 * n_embd), - GELU(), - nn.Linear(4 * n_embd, n_embd), - nn.Dropout(resid_pdrop), - ) - - def forward(self, x): - x = x + self.attn(self.ln1(x)) - x = x + self.mlp(self.ln2(x)) - return x - - -class DecisionTransformerA(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.n_embd = config.n_embd - self.block_size = config.context_len * 3 - - # input embedding stem - self.tok_emb = nn.Embedding(config.act_dim, config.n_embd) - # self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) - self.pos_emb = nn.Parameter(torch.zeros(1, self.block_size + 1, config.n_embd)) - self.global_pos_emb = nn.Parameter(torch.zeros(1, config.max_timestep+1, config.n_embd)) - self.drop = nn.Dropout(config.embd_pdrop) - - # transformer - self.blocks = nn.Sequential(*[BlockA(config.n_heads, self.block_size, config.n_embd, config.attn_pdrop, config.resid_pdrop) for _ in range(config.n_layer)]) - # decoder head - self.ln_f = nn.LayerNorm(config.n_embd) - self.head = nn.Linear(config.n_embd, config.act_dim, bias=False) - - self.state_encoder = nn.Sequential(nn.Conv2d(4, 32, 8, stride=4, padding=0), nn.ReLU(), - nn.Conv2d(32, 64, 4, stride=2, padding=0), nn.ReLU(), - nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU(), - nn.Flatten(), nn.Linear(3136, config.n_embd), nn.Tanh()) - - self.ret_emb = nn.Sequential(nn.Linear(1, config.n_embd), nn.Tanh()) - - self.action_embeddings = nn.Sequential(nn.Embedding(config.act_dim, config.n_embd), nn.Tanh()) - - - # state, action, and return - def forward(self, timesteps, states, actions, returns_to_go, tar=None): - # states: (batch, block_size, 4*84*84) - # actions: (batch, block_size, 1) - # rtgs: (batch, block_size, 1) - # timesteps: (batch, 1, 1) - rtgs = returns_to_go - state_embeddings = self.state_encoder(states.reshape(-1, 4, 84, 84).type(torch.float32).contiguous()) # (batch * block_size, n_embd) - state_embeddings = state_embeddings.reshape(states.shape[0], states.shape[1], self.config.n_embd) # (batch, block_size, n_embd) - - rtg_embeddings = self.ret_emb(rtgs.type(torch.float32)) - action_embeddings = self.action_embeddings(actions.type(torch.long).squeeze(-1)) # (batch, block_size, n_embd) - - token_embeddings = torch.zeros((states.shape[0], states.shape[1]*3 - int(tar is None), self.config.n_embd), dtype=torch.float32, device=state_embeddings.device) - token_embeddings[:,::3,:] = rtg_embeddings - token_embeddings[:,1::3,:] = state_embeddings - token_embeddings[:,2::3,:] = action_embeddings[:,-states.shape[1] + int(tar is None):,:] - - batch_size = states.shape[0] - all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, batch_size, dim=0) # batch_size, traj_length, n_embd - - position_embeddings = torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.config.n_embd, dim=-1)) + self.pos_emb[:, :token_embeddings.shape[1], :] - - x = self.drop(token_embeddings + position_embeddings) - x = self.blocks(x) - x = self.ln_f(x) - logits = self.head(x) - - logits = logits[:, 1::3, :] # only keep predictions from state_embeddings - - return None, logits, None - - def configure_optimizers(self, weight_decay, learning_rate, betas): - """ - This long function is unfortunately doing something very simple and is being very 
defensive: - We are separating out all parameters of the model into two buckets: those that will experience - weight decay for regularization and those that won't (biases, and layernorm/embedding weights). - We are then returning the PyTorch optimizer object. - """ - - # separate out all parameters to those that will and won't experience regularizing weight decay - decay = set() - no_decay = set() - # whitelist_weight_modules = (torch.nn.Linear, ) - whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d) - blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) - for mn, m in self.named_modules(): - for pn, p in m.named_parameters(): - fpn = '%s.%s' % (mn, pn) if mn else pn # full param name - - if pn.endswith('bias'): - # all biases will not be decayed - no_decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): - # weights of whitelist modules will be weight decayed - decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): - # weights of blacklist modules will NOT be weight decayed - no_decay.add(fpn) - - # special case the position embedding parameter in the root GPT module as not decayed - no_decay.add('pos_emb') - no_decay.add('global_pos_emb') - - # validate that we considered every parameter - param_dict = {pn: p for pn, p in self.named_parameters()} - inter_params = decay & no_decay - union_params = decay | no_decay - assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) - assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ - % (str(param_dict.keys() - union_params), ) - - # create the pytorch optimizer object - optim_groups = [ - {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, - {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, - ] - optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) - return optimizer \ No newline at end of file diff --git a/ding/policy/dt.py b/ding/policy/dt.py index e6f2f51d5d..97a78a2be1 100644 --- a/ding/policy/dt.py +++ b/ding/policy/dt.py @@ -1,26 +1,14 @@ """The code is adapted from https://github.com/nikhilbarhate99/min-decision-transformer """ -from typing import List, Dict, Any, Tuple, Union +from typing import List, Dict, Any, Tuple from collections import namedtuple -from torch.distributions import Normal, Independent -from ding.torch_utils import Adam, to_device -from ditk import logging -from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \ - qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data -from ding.model import model_wrap -from ding.utils.data.dataset import D4RLTrajectoryDataset -from ding.utils import POLICY_REGISTRY -from ding.utils.data import default_collate, default_decollate -from datetime import datetime -from ding.torch_utils import one_hot -import numpy as np import torch.nn.functional as F import torch -import gym -import copy -import os -import csv +from ding.torch_utils import to_device +from ding.utils import POLICY_REGISTRY +from ding.utils.data import default_decollate +from ding.torch_utils import one_hot from .base_policy import Policy @@ -43,33 +31,17 @@ class DTPolicy(Policy): # (int) N-step reward for target q_value estimation obs_shape=4, action_shape=2, - # encoder_hidden_size_list=[128, 128, 64], - dataset='medium', # medium / medium-replay / medium-expert rtg_scale=1000, # 
normalize returns to go max_eval_ep_len=1000, # max len of one episode batch_size=64, # training batch size - wt_decay=1e-4, - warmup_steps=10000, - max_train_iters=200, - context_len=20, - log_dir='DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps', - learn=dict( - dataset_path='DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl', # TODO - # batch_size=64, - learning_rate=1e-4, - # ============================================================== - # The following configs are algorithm-specific - # ============================================================== - ), - # collect_mode config - collect=dict(), - eval=dict(), - # other config - other=dict(), + wt_decay=1e-4, # decay weight in optimizer + warmup_steps=10000, # steps for learning rate warmup + context_len=20, # length of transformer input + learning_rate=1e-4, ) def default_model(self) -> Tuple[str, List[str]]: - return 'dt', ['ding.model.template.decision_transformer'] + return 'dt', ['ding.model.template.dt'] def _init_learn(self) -> None: r""" @@ -86,27 +58,18 @@ def _init_learn(self) -> None: self.rtg_target = self._cfg.rtg_target # max target reward_to_go self.max_eval_ep_len = self._cfg.max_eval_ep_len # max len of one episode - lr = self._cfg.learn.learning_rate # learning rate + lr = self._cfg.learning_rate # learning rate wt_decay = self._cfg.wt_decay # weight decay warmup_steps = self._cfg.warmup_steps # warmup steps for lr scheduler self.clip_grad_norm_p = self._cfg.clip_grad_norm_p - self.context_len = self._cfg.context_len # K in decision transformer - - # # load data from this file - # dataset_path = f'{self._cfg.dataset_dir}/{env_d4rl_name}.pkl' - - # training and evaluation device - self.device = torch.device(self._device) + self.context_len = self._cfg.model.context_len # K in decision transformer self.state_dim = self._cfg.model.state_dim self.act_dim = self._cfg.model.act_dim self._learn_model = self._model - if self.cfg.env_type == 'atari': - self._optimizer = self._learn_model.configure_optimizers(self._cfg.weight_decay, lr, self._cfg.betas) - else: - self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay) + self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay) self._scheduler = torch.optim.lr_scheduler.LambdaLR( self._optimizer, lambda steps: min((steps + 1) / warmup_steps, 1) @@ -123,18 +86,9 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: Returns: - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. 
""" - import time - st = time.time() - self._learn_model.train() - - data = [[i[j] for i in data] for j in range(len(data[0]))] + timesteps, states, actions, returns_to_go, traj_mask = data - timesteps = torch.stack(timesteps).to(self.device) # B x T - states = torch.stack(states).to(self.device) # B x T x state_dim - actions = torch.stack(actions).to(self.device) # B x T x act_dim - returns_to_go = torch.stack(returns_to_go).to(self.device) # B x T x 1 - traj_mask = torch.stack(traj_mask).to(self.device) # B x T - action_target = torch.clone(actions).detach().to(self.device) + action_target = torch.clone(actions).detach().to(self._device) # The shape of `returns_to_go` may differ with different dataset (B x T or B x T x 1), # and we need a 3-dim tensor @@ -142,17 +96,13 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: returns_to_go = returns_to_go.unsqueeze(-1) # if discrete - if not self._cfg.model.continuous and self.cfg.env_type != 'atari': + if not self._cfg.model.continuous and 'state_mean' in self._cfg: actions = one_hot(actions.squeeze(-1), num=self.act_dim) - if self.cfg.env_type == 'atari': - state_preds, action_preds, return_preds = self._learn_model.forward( - timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, tar=1) - else: - state_preds, action_preds, return_preds = self._learn_model.forward( + state_preds, action_preds, return_preds = self._learn_model.forward( timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go) - if self.cfg.env_type == 'atari': + if 'state_mean' not in self._cfg: action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1)) else: traj_mask = traj_mask.view(-1, ) @@ -187,30 +137,30 @@ def _init_eval(self) -> None: self._eval_model = self._model # self._eval_model.reset() # init data - self.device = torch.device(self._device) + self._device = torch.device(self._device) self.rtg_scale = self._cfg.rtg_scale # normalize returns to go self.rtg_target = self._cfg.rtg_target # max target reward_to_go self.state_dim = self._cfg.model.state_dim self.act_dim = self._cfg.model.act_dim self.eval_batch_size = self._cfg.evaluator_env_num self.max_eval_ep_len = self._cfg.max_eval_ep_len - self.context_len = self._cfg.context_len # K in decision transformer + self.context_len = self._cfg.model.context_len # K in decision transformer self.t = [0 for _ in range(self.eval_batch_size)] - if not self._cfg.model.continuous: - self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self.device) + if self._cfg.model.continuous: + self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device) else: - self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device) - if self.cfg.env_type == 'atari': - self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len,) + tuple(self.state_dim), dtype=torch.float32, device=self.device) + self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device) + if 'state_mean' not in self._cfg: + self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len,) + tuple(self.state_dim), dtype=torch.float32, device=self._device) self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)] else: self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] - 
self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device) - self.state_mean = torch.from_numpy(self._cfg.state_mean).to(self.device) - self.state_std = torch.from_numpy(self._cfg.state_std).to(self.device) - self.timesteps = torch.arange(start=0, end=self.max_eval_ep_len, step=1).repeat(self.eval_batch_size, 1).to(self.device) - self.rewards_to_go = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self.device) + self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device) + self.state_mean = torch.from_numpy(self._cfg.state_mean).to(self._device) + self.state_std = torch.from_numpy(self._cfg.state_std).to(self._device) + self.timesteps = torch.arange(start=0, end=self.max_eval_ep_len, step=1).repeat(self.eval_batch_size, 1).to(self._device) + self.rewards_to_go = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device) def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: # save and forward @@ -218,38 +168,36 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: self._eval_model.eval() with torch.no_grad(): - if self.cfg.env_type == 'atari': - states = torch.zeros((self.eval_batch_size, self.context_len,) + tuple(self.state_dim), dtype=torch.float32, device=self.device) - timesteps = torch.zeros((self.eval_batch_size, 1, 1), dtype=torch.long, device=self.device) + if 'state_mean' not in self._cfg: + states = torch.zeros((self.eval_batch_size, self.context_len,) + tuple(self.state_dim), dtype=torch.float32, device=self._device) + timesteps = torch.zeros((self.eval_batch_size, 1, 1), dtype=torch.long, device=self._device) else: - states = torch.zeros((self.eval_batch_size, self.context_len, self.state_dim), dtype=torch.float32, device=self.device) - timesteps = torch.zeros((self.eval_batch_size, self.context_len), dtype=torch.long, device=self.device) + states = torch.zeros((self.eval_batch_size, self.context_len, self.state_dim), dtype=torch.float32, device=self._device) + timesteps = torch.zeros((self.eval_batch_size, self.context_len), dtype=torch.long, device=self._device) if not self._cfg.model.continuous: - actions = torch.zeros((self.eval_batch_size, self.context_len, 1), dtype=torch.long, device=self.device) + actions = torch.zeros((self.eval_batch_size, self.context_len, 1), dtype=torch.long, device=self._device) else: - actions = torch.zeros((self.eval_batch_size, self.context_len, self.act_dim), dtype=torch.float32, device=self.device) - rewards_to_go = torch.zeros((self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self.device) + actions = torch.zeros((self.eval_batch_size, self.context_len, self.act_dim), dtype=torch.float32, device=self._device) + rewards_to_go = torch.zeros((self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self._device) for i in data_id: - if self.cfg.env_type == 'atari': - self.states[i, self.t[i]] = data[i]['obs'].to(self.device) + if 'state_mean' not in self._cfg: + self.states[i, self.t[i]] = data[i]['obs'].to(self._device) else: - self.states[i, self.t[i]] = (data[i]['obs'].to(self.device) - self.state_mean) / self.state_std - # self.states[i, self.t[i]] = torch.tensor(data[i]['obs']) - # self.running_rtg[i] = self.running_rtg[i] - (data[i]['reward'].to(self.device) / self.rtg_scale) - self.running_rtg[i] = self.running_rtg[i] - data[i]['reward'].to(self.device) + 
self.states[i, self.t[i]] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std + self.running_rtg[i] = self.running_rtg[i] - data[i]['reward'].to(self._device) self.rewards_to_go[i, self.t[i]] = self.running_rtg[i] if self.t[i] <= self.context_len: - if self.cfg.env_type == 'atari': - timesteps[i] = self.t[i] * torch.ones((1, 1), dtype=torch.int64).to(self.device) + if 'state_mean' not in self._cfg: + timesteps[i] = min(self.t[i], self._cfg.max_timestep) * torch.ones((1, 1), dtype=torch.int64).to(self._device) else: timesteps[i] = self.timesteps[i, :self.context_len] states[i] = self.states[i, :self.context_len] actions[i] = self.actions[i, :self.context_len] rewards_to_go[i] = self.rewards_to_go[i, :self.context_len] else: - if self.cfg.env_type == 'atari': - timesteps[i] = self.t[i] * torch.ones((1, 1), dtype=torch.int64).to(self.device) + if 'state_mean' not in self._cfg: + timesteps[i] = min(self.t[i], self._cfg.max_timestep) * torch.ones((1, 1), dtype=torch.int64).to(self._device) else: timesteps[i] = self.timesteps[i, self.t[i] - self.context_len + 1:self.t[i] + 1] states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1] @@ -259,24 +207,14 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: # actions = one_hot(actions.squeeze(-1), num=self.act_dim) _, act_preds, _ = self._eval_model.forward(timesteps, states, actions, rewards_to_go) del timesteps, states, actions, rewards_to_go - - if self.cfg.env_type == 'atari': - logits = act_preds[:, -1, :] - probs = F.softmax(logits, dim=-1) - act = torch.zeros((self.eval_batch_size, 1), dtype=torch.long, device=self.device) - for i in data_id: - act[i] = torch.multinomial(probs[i], num_samples=1) - self.actions[i, self.t[i]] = act[i] - self.t[i] += 1 - else: - act = torch.zeros((self.eval_batch_size, self.act_dim), dtype=torch.float32, device=self.device) - for i in data_id: - act[i] = act_preds[i, self.t[i]].detach() if self.t[i] < self.context_len else act_preds[i, -1].detach() - if not self._cfg.model.continuous: - act = torch.argmax(act, axis=1).unsqueeze(1) - for i in data_id: - self.actions[i, self.t[i]] = act[i] - self.t[i] += 1 + + logits = act_preds[:, -1, :] + if not self._cfg.model.continuous: + act = torch.argmax(logits, axis=1).unsqueeze(1) + for i in data_id: + self.actions[i, self.t[i]] = act[i] + self.t[i] += 1 + if self._cuda: act = to_device(act, 'cpu') output = {'action': act} @@ -287,41 +225,34 @@ def _reset_eval(self, data_id: List[int] = None) -> None: # clean data if data_id is None: self.t = [0 for _ in range(self.eval_batch_size)] - self.timesteps = torch.arange(start=0, end=self.max_eval_ep_len, step=1).repeat(self.eval_batch_size, 1).to(self.device) + self.timesteps = torch.arange(start=0, end=self.max_eval_ep_len, step=1).repeat(self.eval_batch_size, 1).to(self._device) if not self._cfg.model.continuous: - self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self.device) + self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device) else: - self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device) - if self.cfg.env_type == 'atari': - self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len,) + tuple(self.state_dim), dtype=torch.float32, device=self.device) + self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device) 
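        # (Illustrative sketch, not part of this patch.) At evaluation time the model is
        # conditioned on a running return-to-go: it starts from the configured rtg_target
        # (divided by rtg_scale in the state-normalized branch) and is reduced by every
        # reward the env returns, mirroring ``self.running_rtg[i] -= reward`` in
        # ``_forward_eval`` above. With hypothetical numbers:
        rtg_target = 3600.0                      # hypothetical target return
        rewards = [1.0, 2.0, 0.5]                # hypothetical per-step rewards
        running_rtg = rtg_target
        rtg_conditioning = []
        for r in rewards:
            running_rtg -= r                     # same update rule as in _forward_eval
            rtg_conditioning.append(running_rtg)
        # rtg_conditioning == [3599.0, 3597.0, 3596.5]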
+ if 'state_mean' not in self._cfg: + self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len,) + tuple(self.state_dim), dtype=torch.float32, device=self._device) self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)] else: - self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device) + self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device) self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] - self.rewards_to_go = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self.device) + self.rewards_to_go = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device) else: for i in data_id: self.t[i] = 0 - self.timesteps[i] = torch.arange(start=0, end=self.max_eval_ep_len, step=1).to(self.device) if not self._cfg.model.continuous: - self.actions[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.long, device=self.device) + self.actions[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.long, device=self._device) else: - self.actions[i] = torch.zeros((self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self.device) - if self.cfg.env_type == 'atari': - self.states[i] = torch.zeros((self.max_eval_ep_len,) + tuple(self.state_dim), dtype=torch.float32, device=self.device) + self.actions[i] = torch.zeros((self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device) + if 'state_mean' not in self._cfg: + self.states[i] = torch.zeros((self.max_eval_ep_len,) + tuple(self.state_dim), dtype=torch.float32, device=self._device) self.running_rtg[i] = self.rtg_target else: - self.states[i] = torch.zeros((self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self.device) + self.states[i] = torch.zeros((self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device) self.running_rtg[i] = self.rtg_target / self.rtg_scale - self.rewards_to_go[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.float32, device=self.device) - - def get_d4rl_normalized_score(self, score, env_name): - env_key = env_name.split('-')[0].lower() - assert env_key in D4RLTrajectoryDataset.REF_MAX_SCORE, \ - f'no reference score for {env_key} env to calculate d4rl score' - d4rl_max_score, d4rl_min_score = D4RLTrajectoryDataset.REF_MAX_SCORE, D4RLTrajectoryDataset.REF_MIN_SCORE - return (score - d4rl_min_score[env_key]) / (d4rl_max_score[env_key] - d4rl_min_score[env_key]) + self.timesteps[i] = torch.arange(start=0, end=self.max_eval_ep_len, step=1).to(self._device) + self.rewards_to_go[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device) def _state_dict_learn(self) -> Dict[str, Any]: return { @@ -332,13 +263,10 @@ def _state_dict_learn(self) -> Dict[str, Any]: def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: self._learn_model.load_state_dict(state_dict['model']) - # self._target_model.load_state_dict(state_dict['target_model']) self._optimizer.load_state_dict(state_dict['optimizer']) def _load_state_dict_eval(self, state_dict: Dict[str, Any]) -> None: self._eval_model.load_state_dict(state_dict) - # self._target_model.load_state_dict(state_dict['target_model']) - # self._optimizer.load_state_dict(state_dict['optimizer']) def _monitor_vars_learn(self) -> List[str]: return ['cur_lr', 'action_loss'] diff --git 
a/dizoo/atari/config/pong_dt_config.py b/dizoo/atari/config/serial/pong/pong_dt_config.py similarity index 51% rename from dizoo/atari/config/pong_dt_config.py rename to dizoo/atari/config/serial/pong/pong_dt_config.py index bb70e3ef4c..9d9387036d 100644 --- a/dizoo/atari/config/pong_dt_config.py +++ b/dizoo/atari/config/serial/pong/pong_dt_config.py @@ -1,13 +1,10 @@ from easydict import EasyDict from copy import deepcopy -hopper_dt_config = dict( +Pong_dt_config = dict( exp_name='dt_log/atari/Pong/Pong_dt_seed0', - # exp_name='dt_log/atari/Pong/Pong_dt_seed0', env=dict( env_id='PongNoFrameskip-v4', - norm_obs=dict(use_norm=False, ), - norm_reward=dict(use_norm=False, ), collector_env_num=1, evaluator_env_num=8, use_act_scale=True, @@ -17,68 +14,50 @@ is_train=False, episode_num=10000, # stop in breakout ), - policy=dict( + dataset=dict( + env_type='atari', + # num_steps=500000, + num_steps=500, num_buffers=50, - num_steps=500000, - # num_steps=500, - data_dir_prefix='d4rl_atari/Pong', + rtg_scale=None, + context_len=30, + data_dir_prefix='/mnt/nfs/luyd/d4rl_atari/Pong', trajectories_per_buffer=10, - env_type='atari', - stop_value=105, + ), + policy=dict( cuda=True, + stop_value=20, + evaluator_env_num=8, env_name='PongNoFrameskip-v4', - dataset_name='Pong', - # rtg_target=20, # max target return to go - rtg_target=90, # max target return to go - # rtg_scale=10, + rtg_target=20, # max target return to go max_eval_ep_len=10000, # max lenght of one episode wt_decay=1e-4, clip_grad_norm_p=1.0, - betas = (0.9, 0.95), weight_decay=0.1, - # warmup_steps=100000, warmup_steps=10000, - context_len=30, model=dict( state_dim=(4, 84, 84), - # act_dim=6, - act_dim=4, - n_embd=128, + act_dim=6, + n_blocks=6, + h_dim=128, context_len=30, n_heads=8, - n_layer=6, - embd_pdrop=0.1, - resid_pdrop = 0.1, - attn_pdrop = 0.1, + drop_p=0.1, continuous=False, - ), - learn=dict( - batch_size=128, - learning_rate=6e-4, - target_update_freq=100, - ), + ), + batch_size=128, + learning_rate=6e-4, + eval=dict(evaluator=dict(eval_freq=100, ), ), collect=dict( data_type='d4rl_trajectory', - # data_path='hopper_medium.hdf5', - data_path='d4rl_atari/Pong', unroll_len=1, ), - eval=dict(evaluator=dict(eval_freq=100, ), ), - other=dict( - eps=dict( - type='exp', - start=0.95, - end=0.1, - decay=10000, - ), - replay_buffer=dict(replay_buffer_size=1000, ), - ), ), ) -hopper_dt_config = EasyDict(hopper_dt_config) -main_config = hopper_dt_config -hopper_dt_create_config = dict( +Pong_dt_config = EasyDict(Pong_dt_config) +main_config = Pong_dt_config +Pong_dt_create_config = dict( env=dict( type='atari', import_names=['dizoo.atari.envs.atari_env'], @@ -86,8 +65,8 @@ env_manager=dict(type='subprocess'), policy=dict(type='dt'), ) -hopper_dt_create_config = EasyDict(hopper_dt_create_config) -create_config = hopper_dt_create_config +Pong_dt_create_config = EasyDict(Pong_dt_create_config) +create_config = Pong_dt_create_config if __name__ == "__main__": from ding.entry import serial_pipeline_dt diff --git a/dizoo/atari/entry/atari_dt_main.py b/dizoo/atari/entry/atari_dt_main.py index 88e66f9eb1..ccf919488b 100644 --- a/dizoo/atari/entry/atari_dt_main.py +++ b/dizoo/atari/entry/atari_dt_main.py @@ -1,42 +1,50 @@ import gym import torch import numpy as np +import torch.nn as nn from ditk import logging -from ding.model.template.dt import DecisionTransformer, DecisionTransformerA +from ding.model.template.dt import DecisionTransformer from ding.policy import DTPolicy -from ding.envs import BaseEnvManagerV2 +from ding.envs import 
BaseEnvManagerV2, SyncSubprocessEnvManager, SubprocessEnvManagerV2 from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper from ding.data import create_dataset from ding.config import compile_config from ding.framework import task, ding_init from ding.framework.context import OfflineRLContext -from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger, termination_checker +from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger, termination_checker, offline_data_fetcher_from_mem from ding.utils import set_pkg_seed from dizoo.atari.envs import AtariEnv -from dizoo.atari.config.pong_dt_config import main_config, create_config - +from dizoo.atari.config.serial.pong.pong_dt_config import main_config, create_config +import os +from functools import partial +os.environ['CUDA_LAUNCH_BLOCKING'] = "1" def main(): # If you don't have offline data, you need to prepare if first and set the data_path in config # For demostration, we also can train a RL policy (e.g. SAC) and collect some data logging.getLogger().setLevel(logging.INFO) cfg = compile_config(main_config, create_cfg=create_config, auto=True) - # ding_init(cfg) + ding_init(cfg) with task.start(async_mode=False, ctx=OfflineRLContext()): - evaluator_env = BaseEnvManagerV2( + evaluator_env = SubprocessEnvManagerV2( env_fn=[lambda: AllinObsWrapper(AtariEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) dataset = create_dataset(cfg) - cfg.policy.model.max_timestep = dataset.get_max_timestep() - # model = DecisionTransformer(**cfg.policy.model) - model = DecisionTransformerA(cfg.policy.model) + cfg.policy.max_timestep = dataset.get_max_timestep() + # dataset = get_data_source(dataset) + state_encoder = nn.Sequential(nn.Conv2d(4, 32, 8, stride=4, padding=0), nn.ReLU(), + nn.Conv2d(32, 64, 4, stride=2, padding=0), nn.ReLU(), + nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU(), + nn.Flatten(), nn.Linear(3136, cfg.policy.model.h_dim), nn.Tanh()) + + model = DecisionTransformer(**cfg.policy.model, state_encoder=state_encoder) policy = DTPolicy(cfg.policy, model=model) task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) - task.use(offline_data_fetcher(cfg, dataset)) + task.use(offline_data_fetcher_from_mem(cfg, dataset)) task.use(trainer(cfg, policy.learn_mode)) task.use(termination_checker(max_train_iter=1e5)) task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) @@ -46,4 +54,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/dizoo/d4rl/config/halfcheetah_expert_dt_config.py b/dizoo/d4rl/config/halfcheetah_expert_dt_config.py index 6c8e219bc0..617d17bc73 100644 --- a/dizoo/d4rl/config/halfcheetah_expert_dt_config.py +++ b/dizoo/d4rl/config/halfcheetah_expert_dt_config.py @@ -2,37 +2,38 @@ from copy import deepcopy halfcheetah_dt_config = dict( - exp_name='halfcheetah_expert_dt_seed0', + exp_name='dt_log/d4rl/halfcheetah/halfcheetah_expert_dt_seed0', env=dict( env_id='HalfCheetah-v3', - norm_obs=dict(use_norm=False, ), - norm_reward=dict(use_norm=False, ), collector_env_num=1, evaluator_env_num=8, use_act_scale=True, n_evaluator_episode=8, stop_value=6000, ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=30, + data_dir_prefix='d4rl/halfcheetah_expert-v2.pkl', + ), policy=dict( - stop_value=6000, cuda=True, + stop_value=6000, + state_mean=None, + state_std=None, + evaluator_env_num=8, 
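As a reference for the Atari entry above: the convolutional state_encoder built in atari_dt_main.py flattens a stacked (4, 84, 84) observation into 3136 features before the final Linear(3136, h_dim) and Tanh. A standalone shape check (sketch only; layer arguments are copied from that entry file, variable names are just for the sketch):

import torch
import torch.nn as nn

conv_trunk = nn.Sequential(
    nn.Conv2d(4, 32, 8, stride=4, padding=0), nn.ReLU(),
    nn.Conv2d(32, 64, 4, stride=2, padding=0), nn.ReLU(),
    nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU(),
    nn.Flatten(),
)
dummy_obs = torch.zeros(1, 4, 84, 84)
print(conv_trunk(dummy_obs).shape)  # torch.Size([1, 3136]) -> 64 channels * 7 * 7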
env_name='HalfCheetah-v3', rtg_target=6000, # max target return to go max_eval_ep_len=1000, # max lenght of one episode - num_eval_ep=10, # num of evaluation episode - batch_size=64, wt_decay=1e-4, warmup_steps=10000, - num_updates_per_iter=100, context_len=20, - n_blocks=3, - embed_dim=128, - n_heads=1, - dropout_p=0.1, - log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/d4rl/dt_data/halfcheetah_expert_dt_log', + weight_decay=0.1, + clip_grad_norm_p=0.25, model=dict( - state_dim=17, - act_dim=6, + state_dim=11, + act_dim=3, n_blocks=3, h_dim=128, context_len=20, @@ -40,26 +41,13 @@ drop_p=0.1, continuous=True, ), - discount_factor=0.999, - nstep=3, - learn=dict( - dataset_path='/mnt/lustre/wangzilin/d4rl_data/halfcheetah-expert-v2.pkl', - learning_rate=0.0001, - target_update_freq=100, - kappa=1.0, - min_q_weight=4.0, - ), - collect=dict(unroll_len=1, ), - eval=dict(evaluator=dict(evalu_freq=100, ), ), - other=dict( - eps=dict( - type='exp', - start=0.95, - end=0.1, - decay=10000, - ), - replay_buffer=dict(replay_buffer_size=1000, ), + batch_size=64, + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, ), + eval=dict(evaluator=dict(eval_freq=100, ), ), ), ) diff --git a/dizoo/d4rl/config/halfcheetah_medium_dt_config.py b/dizoo/d4rl/config/halfcheetah_medium_dt_config.py index f61ca5e0c2..7521af6dd5 100644 --- a/dizoo/d4rl/config/halfcheetah_medium_dt_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_dt_config.py @@ -2,37 +2,38 @@ from copy import deepcopy halfcheetah_dt_config = dict( - exp_name='halfcheetah_medium_dt_seed0', + exp_name='dt_log/d4rl/halfcheetah/halfcheetah_medium_dt_seed0', env=dict( env_id='HalfCheetah-v3', - norm_obs=dict(use_norm=False, ), - norm_reward=dict(use_norm=False, ), collector_env_num=1, evaluator_env_num=8, use_act_scale=True, n_evaluator_episode=8, stop_value=6000, ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=30, + data_dir_prefix='d4rl/halfcheetah_medium-v2.pkl', + ), policy=dict( - stop_value=6000, cuda=True, + stop_value=6000, + state_mean=None, + state_std=None, + evaluator_env_num=8, env_name='HalfCheetah-v3', rtg_target=6000, # max target return to go max_eval_ep_len=1000, # max lenght of one episode - num_eval_ep=10, # num of evaluation episode - batch_size=64, wt_decay=1e-4, warmup_steps=10000, - num_updates_per_iter=100, context_len=20, - n_blocks=3, - embed_dim=128, - n_heads=1, - dropout_p=0.1, - log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/d4rl/dt_data/halfcheetah_medium_dt_log', + weight_decay=0.1, + clip_grad_norm_p=0.25, model=dict( - state_dim=17, - act_dim=6, + state_dim=11, + act_dim=3, n_blocks=3, h_dim=128, context_len=20, @@ -40,26 +41,13 @@ drop_p=0.1, continuous=True, ), - discount_factor=0.999, - nstep=3, - learn=dict( - dataset_path='/mnt/lustre/wangzilin/d4rl_data/halfcheetah-medium-v2.pkl', - learning_rate=0.0001, - target_update_freq=100, - kappa=1.0, - min_q_weight=4.0, - ), - collect=dict(unroll_len=1, ), - eval=dict(evaluator=dict(evalu_freq=100, ), ), - other=dict( - eps=dict( - type='exp', - start=0.95, - end=0.1, - decay=10000, - ), - replay_buffer=dict(replay_buffer_size=1000, ), + batch_size=64, + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, ), + eval=dict(evaluator=dict(eval_freq=100, ), ), ), ) diff --git a/dizoo/d4rl/config/halfcheetah_medium_expert_dt_config.py b/dizoo/d4rl/config/halfcheetah_medium_expert_dt_config.py index 9d5d5c6aa2..1f9c636d20 100644 --- 
a/dizoo/d4rl/config/halfcheetah_medium_expert_dt_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_expert_dt_config.py @@ -2,37 +2,38 @@ from copy import deepcopy halfcheetah_dt_config = dict( - exp_name='halfcheetah_medium_expert_dt_seed0', + exp_name='dt_log/d4rl/halfcheetah/halfcheetah_medium_expert_dt_seed0', env=dict( env_id='HalfCheetah-v3', - norm_obs=dict(use_norm=False, ), - norm_reward=dict(use_norm=False, ), collector_env_num=1, evaluator_env_num=8, use_act_scale=True, n_evaluator_episode=8, stop_value=6000, ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=30, + data_dir_prefix='d4rl/halfcheetah_medium_expert-v2.pkl', + ), policy=dict( - stop_value=6000, cuda=True, + stop_value=6000, + state_mean=None, + state_std=None, + evaluator_env_num=8, env_name='HalfCheetah-v3', rtg_target=6000, # max target return to go max_eval_ep_len=1000, # max lenght of one episode - num_eval_ep=10, # num of evaluation episode - batch_size=64, wt_decay=1e-4, warmup_steps=10000, - num_updates_per_iter=100, context_len=20, - n_blocks=3, - embed_dim=128, - n_heads=1, - dropout_p=0.1, - log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/d4rl/dt_data/halfcheetah_medium_expert_dt_log', + weight_decay=0.1, + clip_grad_norm_p=0.25, model=dict( - state_dim=17, - act_dim=6, + state_dim=11, + act_dim=3, n_blocks=3, h_dim=128, context_len=20, @@ -40,26 +41,13 @@ drop_p=0.1, continuous=True, ), - discount_factor=0.999, - nstep=3, - learn=dict( - dataset_path='/mnt/lustre/wangzilin/d4rl_data/halfcheetah-medium-expert-v2.pkl', - learning_rate=0.0001, - target_update_freq=100, - kappa=1.0, - min_q_weight=4.0, - ), - collect=dict(unroll_len=1, ), - eval=dict(evaluator=dict(evalu_freq=100, ), ), - other=dict( - eps=dict( - type='exp', - start=0.95, - end=0.1, - decay=10000, - ), - replay_buffer=dict(replay_buffer_size=1000, ), + batch_size=64, + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, ), + eval=dict(evaluator=dict(eval_freq=100, ), ), ), ) diff --git a/dizoo/d4rl/config/halfcheetah_medium_replay_dt_config.py b/dizoo/d4rl/config/halfcheetah_medium_replay_dt_config.py index 98260c4dd3..aa07e22280 100644 --- a/dizoo/d4rl/config/halfcheetah_medium_replay_dt_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_replay_dt_config.py @@ -2,37 +2,38 @@ from copy import deepcopy halfcheetah_dt_config = dict( - exp_name='halfcheetah_medium_replay_dt_seed0', + exp_name='dt_log/d4rl/halfcheetah/halfcheetah_medium_replay_dt_seed0', env=dict( env_id='HalfCheetah-v3', - norm_obs=dict(use_norm=False, ), - norm_reward=dict(use_norm=False, ), collector_env_num=1, evaluator_env_num=8, use_act_scale=True, n_evaluator_episode=8, stop_value=6000, ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=30, + data_dir_prefix='d4rl/halfcheetah_medium_replay-v2.pkl', + ), policy=dict( - stop_value=6000, cuda=True, + stop_value=6000, + state_mean=None, + state_std=None, + evaluator_env_num=8, env_name='HalfCheetah-v3', rtg_target=6000, # max target return to go max_eval_ep_len=1000, # max lenght of one episode - num_eval_ep=10, # num of evaluation episode - batch_size=64, wt_decay=1e-4, warmup_steps=10000, - num_updates_per_iter=100, context_len=20, - n_blocks=3, - embed_dim=128, - n_heads=1, - dropout_p=0.1, - log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/d4rl/dt_data/halfcheetah_medium_replay_dt_log', + weight_decay=0.1, + clip_grad_norm_p=0.25, model=dict( - state_dim=17, - act_dim=6, + state_dim=11, + act_dim=3, n_blocks=3, h_dim=128, context_len=20, 
@@ -40,26 +41,13 @@ drop_p=0.1, continuous=True, ), - discount_factor=0.999, - nstep=3, - learn=dict( - dataset_path='/mnt/lustre/wangzilin/d4rl_data/halfcheetah-medium-replay-v2.pkl', - learning_rate=0.0001, - target_update_freq=100, - kappa=1.0, - min_q_weight=4.0, - ), - collect=dict(unroll_len=1, ), - eval=dict(evaluator=dict(evalu_freq=100, ), ), - other=dict( - eps=dict( - type='exp', - start=0.95, - end=0.1, - decay=10000, - ), - replay_buffer=dict(replay_buffer_size=1000, ), + batch_size=64, + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, ), + eval=dict(evaluator=dict(eval_freq=100, ), ), ), ) diff --git a/dizoo/d4rl/config/hopper_expert_dt_config.py b/dizoo/d4rl/config/hopper_expert_dt_config.py index b84d763b30..11ee61f473 100644 --- a/dizoo/d4rl/config/hopper_expert_dt_config.py +++ b/dizoo/d4rl/config/hopper_expert_dt_config.py @@ -5,32 +5,32 @@ exp_name='dt_log/d4rl/hopper/hopper_expert_dt_seed0', env=dict( env_id='Hopper-v3', - norm_obs=dict(use_norm=False, ), - norm_reward=dict(use_norm=False, ), collector_env_num=1, evaluator_env_num=8, use_act_scale=True, n_evaluator_episode=8, - stop_value=6000, + stop_value=3600, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=30, + data_dir_prefix='d4rl/hopper_expert-v2.pkl', ), policy=dict( - stop_value=6000, + cuda=True, + stop_value=3600, state_mean=None, state_std=None, evaluator_env_num=8, - cuda=True, env_name='Hopper-v3', - rtg_target=6000, # max target return to go + rtg_target=3600, # max target return to go max_eval_ep_len=1000, # max lenght of one episode - num_eval_ep=10, # num of evaluation episode wt_decay=1e-4, warmup_steps=10000, - num_updates_per_iter=100, context_len=20, - n_blocks=3, - embed_dim=128, - n_heads=1, - dropout_p=0.1, + weight_decay=0.1, + clip_grad_norm_p=0.25, model=dict( state_dim=11, act_dim=3, @@ -41,30 +41,13 @@ drop_p=0.1, continuous=True, ), - discount_factor=0.999, - nstep=3, - learn=dict( - batch_size=64, - learning_rate=0.0001, - target_update_freq=100, - kappa=1.0, - min_q_weight=4.0, - ), + batch_size=64, + learning_rate=1e-4, collect=dict( data_type='d4rl_trajectory', - data_path='hopper_expert.hdf5', unroll_len=1, ), - eval=dict(evaluator=dict(evalu_freq=100, ), ), - other=dict( - eps=dict( - type='exp', - start=0.95, - end=0.1, - decay=10000, - ), - replay_buffer=dict(replay_buffer_size=1000, ), - ), + eval=dict(evaluator=dict(eval_freq=100, ), ), ), ) diff --git a/dizoo/d4rl/config/hopper_medium_dt_config.py b/dizoo/d4rl/config/hopper_medium_dt_config.py index 3daae57d42..a9ce705529 100644 --- a/dizoo/d4rl/config/hopper_medium_dt_config.py +++ b/dizoo/d4rl/config/hopper_medium_dt_config.py @@ -5,31 +5,31 @@ exp_name='dt_log/d4rl/hopper/hopper_medium_dt_seed0', env=dict( env_id='Hopper-v3', - norm_obs=dict(use_norm=False, ), - norm_reward=dict(use_norm=False, ), collector_env_num=1, evaluator_env_num=8, use_act_scale=True, n_evaluator_episode=8, stop_value=3600, ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=30, + data_dir_prefix='d4rl/hopper_medium-v2.pkl', + ), policy=dict( + cuda=True, stop_value=3600, state_mean=None, state_std=None, evaluator_env_num=8, - cuda=True, env_name='Hopper-v3', - dataset_name='hopper-medium-v2', rtg_target=3600, # max target return to go - rtg_scale=1000, max_eval_ep_len=1000, # max lenght of one episode - num_eval_ep=10, # num of evaluation episode wt_decay=1e-4, - # warmup_steps=100000, warmup_steps=10000, - num_updates_per_iter=100, context_len=20, + weight_decay=0.1, 
clip_grad_norm_p=0.25, model=dict( state_dim=11, @@ -39,34 +39,15 @@ context_len=20, n_heads=1, drop_p=0.1, - max_timestep=0, continuous=True, ), - discount_factor=0.999, - nstep=3, - learn=dict( - batch_size=64, - learning_rate=1e-4, - target_update_freq=100, - kappa=1.0, - min_q_weight=4.0, - ), + batch_size=64, + learning_rate=1e-4, collect=dict( data_type='d4rl_trajectory', - # data_path='hopper_medium.hdf5', - data_path='d4rl/hopper_medium-v2.pkl', unroll_len=1, ), - eval=dict(evaluator=dict(evalu_freq=100, ), ), - other=dict( - eps=dict( - type='exp', - start=0.95, - end=0.1, - decay=10000, - ), - replay_buffer=dict(replay_buffer_size=1000, ), - ), + eval=dict(evaluator=dict(eval_freq=100, ), ), ), ) diff --git a/dizoo/d4rl/config/hopper_medium_expert_dt_config.py b/dizoo/d4rl/config/hopper_medium_expert_dt_config.py index fa78fe73cc..0592a89228 100644 --- a/dizoo/d4rl/config/hopper_medium_expert_dt_config.py +++ b/dizoo/d4rl/config/hopper_medium_expert_dt_config.py @@ -5,32 +5,32 @@ exp_name='dt_log/d4rl/hopper/hopper_medium_expert_dt_seed0', env=dict( env_id='Hopper-v3', - norm_obs=dict(use_norm=False, ), - norm_reward=dict(use_norm=False, ), collector_env_num=1, evaluator_env_num=8, use_act_scale=True, n_evaluator_episode=8, - stop_value=6000, + stop_value=3600, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=30, + data_dir_prefix='d4rl/hopper_medium_expert-v2.pkl', ), policy=dict( - stop_value=6000, + cuda=True, + stop_value=3600, state_mean=None, state_std=None, evaluator_env_num=8, - cuda=True, env_name='Hopper-v3', - rtg_target=6000, # max target return to go + rtg_target=3600, # max target return to go max_eval_ep_len=1000, # max lenght of one episode - num_eval_ep=10, # num of evaluation episode wt_decay=1e-4, warmup_steps=10000, - num_updates_per_iter=100, context_len=20, - n_blocks=3, - embed_dim=128, - n_heads=1, - dropout_p=0.1, + weight_decay=0.1, + clip_grad_norm_p=0.25, model=dict( state_dim=11, act_dim=3, @@ -41,30 +41,13 @@ drop_p=0.1, continuous=True, ), - discount_factor=0.999, - nstep=3, - learn=dict( - batch_size=64, - learning_rate=0.0001, - target_update_freq=100, - kappa=1.0, - min_q_weight=4.0, - ), + batch_size=64, + learning_rate=1e-4, collect=dict( data_type='d4rl_trajectory', - data_path='d4rl/hopper_medium_expert.hdf5', unroll_len=1, ), - eval=dict(evaluator=dict(evalu_freq=100, ), ), - other=dict( - eps=dict( - type='exp', - start=0.95, - end=0.1, - decay=10000, - ), - replay_buffer=dict(replay_buffer_size=1000, ), - ), + eval=dict(evaluator=dict(eval_freq=100, ), ), ), ) diff --git a/dizoo/d4rl/config/hopper_medium_replay_dt_config.py b/dizoo/d4rl/config/hopper_medium_replay_dt_config.py index fd238c915e..a2615ba1b9 100644 --- a/dizoo/d4rl/config/hopper_medium_replay_dt_config.py +++ b/dizoo/d4rl/config/hopper_medium_replay_dt_config.py @@ -2,34 +2,35 @@ from copy import deepcopy hopper_dt_config = dict( - exp_name='hopper_medium_replay_dt_seed0', + exp_name='dt_log/d4rl/hopper/hopper_medium_replay_dt_seed0', env=dict( env_id='Hopper-v3', - norm_obs=dict(use_norm=False, ), - norm_reward=dict(use_norm=False, ), collector_env_num=1, evaluator_env_num=8, use_act_scale=True, n_evaluator_episode=8, - stop_value=6000, + stop_value=3600, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=30, + data_dir_prefix='d4rl/hopper_medium_replay-v2.pkl', ), policy=dict( - stop_value=6000, cuda=True, + stop_value=3600, + state_mean=None, + state_std=None, + evaluator_env_num=8, env_name='Hopper-v3', - rtg_target=6000, # 
max target return to go + rtg_target=3600, # max target return to go max_eval_ep_len=1000, # max lenght of one episode - num_eval_ep=10, # num of evaluation episode - batch_size=64, wt_decay=1e-4, warmup_steps=10000, - num_updates_per_iter=100, context_len=20, - n_blocks=3, - embed_dim=128, - n_heads=1, - dropout_p=0.1, - log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/d4rl/dt_data/hopper_medium_replay_dt_log', + weight_decay=0.1, + clip_grad_norm_p=0.25, model=dict( state_dim=11, act_dim=3, @@ -40,26 +41,13 @@ drop_p=0.1, continuous=True, ), - discount_factor=0.999, - nstep=3, - learn=dict( - dataset_path='/mnt/lustre/wangzilin/d4rl_data/hopper-medium-replay-v2.pkl', - learning_rate=0.0001, - target_update_freq=100, - kappa=1.0, - min_q_weight=4.0, - ), - collect=dict(unroll_len=1, ), - eval=dict(evaluator=dict(evalu_freq=100, ), ), - other=dict( - eps=dict( - type='exp', - start=0.95, - end=0.1, - decay=10000, - ), - replay_buffer=dict(replay_buffer_size=1000, ), + batch_size=64, + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, ), + eval=dict(evaluator=dict(eval_freq=100, ), ), ), ) diff --git a/dizoo/d4rl/config/walker2d_expert_dt_config.py b/dizoo/d4rl/config/walker2d_expert_dt_config.py index d50670dad7..3658f8ce03 100644 --- a/dizoo/d4rl/config/walker2d_expert_dt_config.py +++ b/dizoo/d4rl/config/walker2d_expert_dt_config.py @@ -1,38 +1,39 @@ from easydict import EasyDict from copy import deepcopy -walker2d_dt_config = dict( - exp_name='walker2d_expert_dt_seed0', +walk2d_dt_config = dict( + exp_name='dt_log/d4rl/walk2d/walk2d_expert_dt_seed0', env=dict( - env_id='Walker2d-v3', - norm_obs=dict(use_norm=False, ), - norm_reward=dict(use_norm=False, ), + env_id='Walk2d-v3', collector_env_num=1, evaluator_env_num=8, use_act_scale=True, n_evaluator_episode=8, - stop_value=6000, + stop_value=5000, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=30, + data_dir_prefix='d4rl/walk2d_expert-v2.pkl', ), policy=dict( - stop_value=6000, cuda=True, - env_name='Walker2d-v3', - rtg_target=6000, # max target return to go + stop_value=5000, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='Walk2d-v3', + rtg_target=5000, # max target return to go max_eval_ep_len=1000, # max lenght of one episode - num_eval_ep=10, # num of evaluation episode - batch_size=64, wt_decay=1e-4, warmup_steps=10000, - num_updates_per_iter=100, context_len=20, - n_blocks=3, - embed_dim=128, - n_heads=1, - dropout_p=0.1, - log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/d4rl/dt_data/walker2d_expert_dt_log', + weight_decay=0.1, + clip_grad_norm_p=0.25, model=dict( - state_dim=17, - act_dim=6, + state_dim=11, + act_dim=3, n_blocks=3, h_dim=128, context_len=20, @@ -40,32 +41,19 @@ drop_p=0.1, continuous=True, ), - discount_factor=0.999, - nstep=3, - learn=dict( - dataset_path='/mnt/lustre/wangzilin/d4rl_data/walker2d-expert-v2.pkl', - learning_rate=0.0001, - target_update_freq=100, - kappa=1.0, - min_q_weight=4.0, - ), - collect=dict(unroll_len=1, ), - eval=dict(evaluator=dict(evalu_freq=100, ), ), - other=dict( - eps=dict( - type='exp', - start=0.95, - end=0.1, - decay=10000, - ), - replay_buffer=dict(replay_buffer_size=1000, ), + batch_size=64, + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, ), + eval=dict(evaluator=dict(eval_freq=100, ), ), ), ) -walker2d_dt_config = EasyDict(walker2d_dt_config) -main_config = walker2d_dt_config -walker2d_dt_create_config = dict( +walk2d_dt_config = 
EasyDict(walk2d_dt_config) +main_config = walk2d_dt_config +walk2d_dt_create_config = dict( env=dict( type='mujoco', import_names=['dizoo.mujoco.envs.mujoco_env'], @@ -73,8 +61,8 @@ env_manager=dict(type='subprocess'), policy=dict(type='dt'), ) -walker2d_dt_create_config = EasyDict(walker2d_dt_create_config) -create_config = walker2d_dt_create_config +walk2d_dt_create_config = EasyDict(walk2d_dt_create_config) +create_config = walk2d_dt_create_config if __name__ == "__main__": from ding.entry import serial_pipeline_dt diff --git a/dizoo/d4rl/config/walker2d_medium_dt_config.py b/dizoo/d4rl/config/walker2d_medium_dt_config.py index e3b741a129..57a93c0ab5 100644 --- a/dizoo/d4rl/config/walker2d_medium_dt_config.py +++ b/dizoo/d4rl/config/walker2d_medium_dt_config.py @@ -1,38 +1,39 @@ from easydict import EasyDict from copy import deepcopy -walker2d_dt_config = dict( - exp_name='walker2d_medium_dt_seed0', +walk2d_dt_config = dict( + exp_name='dt_log/d4rl/walk2d/walk2d_medium_dt_seed0', env=dict( - env_id='Walker2d-v3', - norm_obs=dict(use_norm=False, ), - norm_reward=dict(use_norm=False, ), + env_id='Walk2d-v3', collector_env_num=1, evaluator_env_num=8, use_act_scale=True, n_evaluator_episode=8, - stop_value=6000, + stop_value=5000, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=30, + data_dir_prefix='d4rl/walk2d_medium-v2.pkl', ), policy=dict( - stop_value=6000, cuda=True, - env_name='Walker2d-v3', - rtg_target=6000, # max target return to go + stop_value=5000, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='Walk2d-v3', + rtg_target=5000, # max target return to go max_eval_ep_len=1000, # max lenght of one episode - num_eval_ep=10, # num of evaluation episode - batch_size=64, wt_decay=1e-4, warmup_steps=10000, - num_updates_per_iter=100, context_len=20, - n_blocks=3, - embed_dim=128, - n_heads=1, - dropout_p=0.1, - log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/d4rl/dt_data/walker2d_medium_dt_log', + weight_decay=0.1, + clip_grad_norm_p=0.25, model=dict( - state_dim=17, - act_dim=6, + state_dim=11, + act_dim=3, n_blocks=3, h_dim=128, context_len=20, @@ -40,32 +41,19 @@ drop_p=0.1, continuous=True, ), - discount_factor=0.999, - nstep=3, - learn=dict( - dataset_path='/mnt/lustre/wangzilin/d4rl_data/walker2d-medium-v2.pkl', - learning_rate=0.0001, - target_update_freq=100, - kappa=1.0, - min_q_weight=4.0, - ), - collect=dict(unroll_len=1, ), - eval=dict(evaluator=dict(evalu_freq=100, ), ), - other=dict( - eps=dict( - type='exp', - start=0.95, - end=0.1, - decay=10000, - ), - replay_buffer=dict(replay_buffer_size=1000, ), + batch_size=64, + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, ), + eval=dict(evaluator=dict(eval_freq=100, ), ), ), ) -walker2d_dt_config = EasyDict(walker2d_dt_config) -main_config = walker2d_dt_config -walker2d_dt_create_config = dict( +walk2d_dt_config = EasyDict(walk2d_dt_config) +main_config = walk2d_dt_config +walk2d_dt_create_config = dict( env=dict( type='mujoco', import_names=['dizoo.mujoco.envs.mujoco_env'], @@ -73,8 +61,8 @@ env_manager=dict(type='subprocess'), policy=dict(type='dt'), ) -walker2d_dt_create_config = EasyDict(walker2d_dt_create_config) -create_config = walker2d_dt_create_config +walk2d_dt_create_config = EasyDict(walk2d_dt_create_config) +create_config = walk2d_dt_create_config if __name__ == "__main__": from ding.entry import serial_pipeline_dt diff --git a/dizoo/d4rl/config/walker2d_medium_expert_dt_config.py 
b/dizoo/d4rl/config/walker2d_medium_expert_dt_config.py index deb6051d10..225d00c2e3 100644 --- a/dizoo/d4rl/config/walker2d_medium_expert_dt_config.py +++ b/dizoo/d4rl/config/walker2d_medium_expert_dt_config.py @@ -1,38 +1,39 @@ from easydict import EasyDict from copy import deepcopy -walker2d_dt_config = dict( - exp_name='walker2d_medium_expert_dt_seed0', +walk2d_dt_config = dict( + exp_name='dt_log/d4rl/walk2d/walk2d_medium_expert_dt_seed0', env=dict( - env_id='Walker2d-v3', - norm_obs=dict(use_norm=False, ), - norm_reward=dict(use_norm=False, ), + env_id='Walk2d-v3', collector_env_num=1, evaluator_env_num=8, use_act_scale=True, n_evaluator_episode=8, - stop_value=6000, + stop_value=5000, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=30, + data_dir_prefix='d4rl/walk2d_medium_expert-v2.pkl', ), policy=dict( - stop_value=6000, cuda=True, - env_name='Walker2d-v3', - rtg_target=6000, # max target return to go + stop_value=5000, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='Walk2d-v3', + rtg_target=5000, # max target return to go max_eval_ep_len=1000, # max lenght of one episode - num_eval_ep=10, # num of evaluation episode - batch_size=64, wt_decay=1e-4, warmup_steps=10000, - num_updates_per_iter=100, context_len=20, - n_blocks=3, - embed_dim=128, - n_heads=1, - dropout_p=0.1, - log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/d4rl/dt_data/walker2d_medium_expert_dt_log', + weight_decay=0.1, + clip_grad_norm_p=0.25, model=dict( - state_dim=17, - act_dim=6, + state_dim=11, + act_dim=3, n_blocks=3, h_dim=128, context_len=20, @@ -40,32 +41,19 @@ drop_p=0.1, continuous=True, ), - discount_factor=0.999, - nstep=3, - learn=dict( - dataset_path='/mnt/lustre/wangzilin/d4rl_data/walker2d-medium-expert-v2.pkl', - learning_rate=0.0001, - target_update_freq=100, - kappa=1.0, - min_q_weight=4.0, - ), - collect=dict(unroll_len=1, ), - eval=dict(evaluator=dict(evalu_freq=100, ), ), - other=dict( - eps=dict( - type='exp', - start=0.95, - end=0.1, - decay=10000, - ), - replay_buffer=dict(replay_buffer_size=1000, ), + batch_size=64, + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, ), + eval=dict(evaluator=dict(eval_freq=100, ), ), ), ) -walker2d_dt_config = EasyDict(walker2d_dt_config) -main_config = walker2d_dt_config -walker2d_dt_create_config = dict( +walk2d_dt_config = EasyDict(walk2d_dt_config) +main_config = walk2d_dt_config +walk2d_dt_create_config = dict( env=dict( type='mujoco', import_names=['dizoo.mujoco.envs.mujoco_env'], @@ -73,8 +61,8 @@ env_manager=dict(type='subprocess'), policy=dict(type='dt'), ) -walker2d_dt_create_config = EasyDict(walker2d_dt_create_config) -create_config = walker2d_dt_create_config +walk2d_dt_create_config = EasyDict(walk2d_dt_create_config) +create_config = walk2d_dt_create_config if __name__ == "__main__": from ding.entry import serial_pipeline_dt diff --git a/dizoo/d4rl/config/walker2d_medium_replay_dt_config.py b/dizoo/d4rl/config/walker2d_medium_replay_dt_config.py index eb471c5370..b96375b242 100644 --- a/dizoo/d4rl/config/walker2d_medium_replay_dt_config.py +++ b/dizoo/d4rl/config/walker2d_medium_replay_dt_config.py @@ -1,38 +1,39 @@ from easydict import EasyDict from copy import deepcopy -walker2d_dt_config = dict( - exp_name='walker2d_medium_replay_dt_seed0', +walk2d_dt_config = dict( + exp_name='dt_log/d4rl/walk2d/walk2d_medium_replay_dt_seed0', env=dict( - env_id='Walker2d-v3', - norm_obs=dict(use_norm=False, ), - norm_reward=dict(use_norm=False, ), + 
env_id='Walk2d-v3', collector_env_num=1, evaluator_env_num=8, use_act_scale=True, n_evaluator_episode=8, - stop_value=6000, + stop_value=5000, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=30, + data_dir_prefix='d4rl/walk2d_medium_replay-v2.pkl', ), policy=dict( - stop_value=6000, cuda=True, - env_name='Walker2d-v3', - rtg_target=6000, # max target return to go + stop_value=5000, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='Walk2d-v3', + rtg_target=5000, # max target return to go max_eval_ep_len=1000, # max lenght of one episode - num_eval_ep=10, # num of evaluation episode - batch_size=64, wt_decay=1e-4, warmup_steps=10000, - num_updates_per_iter=100, context_len=20, - n_blocks=3, - embed_dim=128, - n_heads=1, - dropout_p=0.1, - log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/d4rl/dt_data/walker2d_medium_replay_dt_log', + weight_decay=0.1, + clip_grad_norm_p=0.25, model=dict( - state_dim=17, - act_dim=6, + state_dim=11, + act_dim=3, n_blocks=3, h_dim=128, context_len=20, @@ -40,32 +41,19 @@ drop_p=0.1, continuous=True, ), - discount_factor=0.999, - nstep=3, - learn=dict( - dataset_path='/mnt/lustre/wangzilin/d4rl_data/walker2d-medium-replay-v2.pkl', - learning_rate=0.0001, - target_update_freq=100, - kappa=1.0, - min_q_weight=4.0, - ), - collect=dict(unroll_len=1, ), - eval=dict(evaluator=dict(evalu_freq=100, ), ), - other=dict( - eps=dict( - type='exp', - start=0.95, - end=0.1, - decay=10000, - ), - replay_buffer=dict(replay_buffer_size=1000, ), + batch_size=64, + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, ), + eval=dict(evaluator=dict(eval_freq=100, ), ), ), ) -walker2d_dt_config = EasyDict(walker2d_dt_config) -main_config = walker2d_dt_config -walker2d_dt_create_config = dict( +walk2d_dt_config = EasyDict(walk2d_dt_config) +main_config = walk2d_dt_config +walk2d_dt_create_config = dict( env=dict( type='mujoco', import_names=['dizoo.mujoco.envs.mujoco_env'], @@ -73,8 +61,8 @@ env_manager=dict(type='subprocess'), policy=dict(type='dt'), ) -walker2d_dt_create_config = EasyDict(walker2d_dt_create_config) -create_config = walker2d_dt_create_config +walk2d_dt_create_config = EasyDict(walk2d_dt_create_config) +create_config = walk2d_dt_create_config if __name__ == "__main__": from ding.entry import serial_pipeline_dt From 07601f964b3a6a9a9add6dcb3dc549cda7e783e2 Mon Sep 17 00:00:00 2001 From: luyudong Date: Mon, 7 Aug 2023 15:30:15 +0800 Subject: [PATCH 10/25] reformat --- ding/envs/env/tests/test_ding_env_wrapper.py | 2 +- ding/envs/env_wrappers/env_wrappers.py | 15 +- ding/example/dt.py | 3 +- .../middleware/functional/data_processor.py | 15 +- .../framework/middleware/functional/logger.py | 6 +- ding/model/template/dt.py | 89 ++++++---- ding/policy/dt.py | 132 +++++++++++---- ding/torch_utils/network/transformer.py | 19 ++- ding/utils/data/dataloader.py | 4 - ding/utils/data/dataset.py | 155 +++++++++--------- .../config/serial/pong/pong_dt_config.py | 4 +- dizoo/atari/entry/spaceinvaders_dqn_eval.py | 5 +- dizoo/atari/example/atari_dqn_dist_ddp.py | 1 - .../config/bipedalwalker_ddpg_config.py | 8 +- .../config/bipedalwalker_sac_config.py | 12 +- .../config/bipedalwalker_td3_config.py | 8 +- .../carracing/config/carracing_dqn_config.py | 9 +- dizoo/box2d/carracing/envs/carracing_env.py | 1 - .../carracing/envs/test_carracing_env.py | 10 +- .../config/lunarlander_cont_sac_config.py | 4 +- .../lunarlander_decision_transformer.py | 2 +- .../config/lunarlander_dt_config.py | 2 +- 
.../offline_data/collect_dqn_data_config.py | 5 +- .../cartpole/config/cartpole_bc_config.py | 2 +- .../config/mtcar_rainbow_config.py | 95 ++++++----- .../mountain_car/envs/__init__.py | 2 +- .../pendulum/config/pendulum_ibc_config.py | 13 +- .../pendulum/config/pendulum_td3_bc_config.py | 2 +- .../pendulum/entry/pendulum_dqn_eval.py | 5 +- .../config/cliffwalking_dqn_config.py | 2 +- dizoo/cliffwalking/envs/cliffwalking_env.py | 2 +- .../envs/test_cliffwalking_env.py | 1 + .../config/hopper_medium_expert_bc_config.py | 6 +- .../hopper_medium_expert_ibc_ar_config.py | 14 +- .../config/hopper_medium_expert_ibc_config.py | 14 +- .../hopper_medium_expert_ibc_mcmc_config.py | 14 +- .../d4rl/config/kitchen_complete_bc_config.py | 8 +- .../config/kitchen_complete_ibc_ar_config.py | 14 +- .../config/kitchen_complete_ibc_config.py | 14 +- .../kitchen_complete_ibc_mcmc_config.py | 14 +- dizoo/d4rl/config/pen_human_bc_config.py | 6 +- dizoo/d4rl/config/pen_human_ibc_ar_config.py | 14 +- dizoo/d4rl/config/pen_human_ibc_config.py | 14 +- .../d4rl/config/pen_human_ibc_mcmc_config.py | 14 +- dizoo/d4rl/entry/d4rl_cql_main.py | 2 +- dizoo/d4rl/entry/d4rl_dt_mujoco.py | 7 +- dizoo/d4rl/entry/d4rl_td3_bc_main.py | 2 +- dizoo/dmc2gym/config/dmc2gym_ppo_config.py | 1 - dizoo/dmc2gym/entry/dmc2gym_sac_pixel_main.py | 33 ++-- dizoo/dmc2gym/entry/dmc2gym_sac_state_main.py | 33 ++-- dizoo/dmc2gym/envs/dmc2gym_env.py | 2 + dizoo/dmc2gym/envs/test_dmc2gym_env.py | 1 - .../evogym/envs/test/visualize_simple_env.py | 1 - .../config/stocks_dqn_config.py | 6 +- .../worker/trading_serial_evaluator.py | 26 +-- .../envs/gym-hybrid/gym_hybrid/__init__.py | 3 +- dizoo/gym_hybrid/envs/gym-hybrid/setup.py | 9 +- .../envs/gym-hybrid/tests/moving.py | 1 - dizoo/gym_hybrid/envs/test_gym_hybrid_env.py | 12 +- .../entry/imagenet_res18_config.py | 4 +- dizoo/league_demo/league_demo_collector.py | 14 +- dizoo/maze/entry/maze_bc_main.py | 14 +- dizoo/minigrid/utils/eval.py | 10 +- dizoo/mujoco/config/halfcheetah_bdq_config.py | 7 +- dizoo/mujoco/config/hopper_bdq_config.py | 6 +- dizoo/mujoco/envs/mujoco_wrappers.py | 8 +- .../config/ant_mappo_config.py | 1 - .../config/ant_masac_config.py | 4 +- .../config/ptz_simple_spread_madqn_config.py | 8 +- dizoo/rocket/entry/rocket_hover_ppo_main.py | 6 +- dizoo/rocket/entry/rocket_landing_ppo_main.py | 8 +- dizoo/rocket/envs/test_rocket_env.py | 6 +- dizoo/smac/config/smac_3s5z_madqn_config.py | 12 +- .../config/smac_3s5zvs3s6z_madqn_config.py | 12 +- dizoo/smac/config/smac_5m6m_madqn_config.py | 11 +- dizoo/smac/config/smac_8m9m_madqn_config.py | 11 +- dizoo/smac/config/smac_MMM2_madqn_config.py | 12 +- dizoo/smac/config/smac_MMM_madqn_config.py | 12 +- dizoo/smac/utils/eval.py | 10 +- 79 files changed, 555 insertions(+), 541 deletions(-) diff --git a/ding/envs/env/tests/test_ding_env_wrapper.py b/ding/envs/env/tests/test_ding_env_wrapper.py index 5b98a9403c..a99e4ace23 100644 --- a/ding/envs/env/tests/test_ding_env_wrapper.py +++ b/ding/envs/env/tests/test_ding_env_wrapper.py @@ -194,7 +194,7 @@ def test_AllinObsWrapper(self): action = ding_env_aio.random_action() timestep = ding_env_aio.step(action) # print(timestep.reward) - assert isinstance(timestep.obs,dict) + assert isinstance(timestep.obs, dict) if timestep.done: assert 'eval_episode_return' in timestep.info, timestep.info break diff --git a/ding/envs/env_wrappers/env_wrappers.py b/ding/envs/env_wrappers/env_wrappers.py index f0442509bd..8f8af4ddec 100644 --- a/ding/envs/env_wrappers/env_wrappers.py +++ 
b/ding/envs/env_wrappers/env_wrappers.py @@ -1192,22 +1192,25 @@ def __init__(self, env): super().__init__(env) def reset(self): - ret = {'obs':self.env.reset(), 'reward': np.array([0])} - self._observation_space = gym.spaces.Dict({ - 'obs': self.env.observation_space, - 'reward': gym.spaces.Box(low=-np.inf, high=np.inf, dtype=np.float32)} + ret = {'obs': self.env.reset(), 'reward': np.array([0])} + self._observation_space = gym.spaces.Dict( + { + 'obs': self.env.observation_space, + 'reward': gym.spaces.Box(low=-np.inf, high=np.inf, dtype=np.float32) + } ) return ret def step(self, action): obs, reward, done, info = self.env.step(action) - obs = {'obs':obs, 'reward': reward} + obs = {'obs': obs, 'reward': reward} from ding.envs import BaseEnvTimestep return BaseEnvTimestep(obs, reward, done, info) - + def seed(self, seed: int, dynamic_seed: bool = True) -> None: self.env.seed(seed, dynamic_seed) + def update_shape(obs_shape, act_shape, rew_shape, wrapper_names): """ Overview: diff --git a/ding/example/dt.py b/ding/example/dt.py index 30969ce5da..10884c3ec7 100644 --- a/ding/example/dt.py +++ b/ding/example/dt.py @@ -22,7 +22,8 @@ def main(): ding_init(cfg) with task.start(async_mode=False, ctx=OfflineRLContext()): evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: AllinObsWrapper(LunarLanderEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + env_fn=[lambda: AllinObsWrapper(LunarLanderEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)], + cfg=cfg.env.manager ) set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) diff --git a/ding/framework/middleware/functional/data_processor.py b/ding/framework/middleware/functional/data_processor.py index 1440150123..0c008370c0 100644 --- a/ding/framework/middleware/functional/data_processor.py +++ b/ding/framework/middleware/functional/data_processor.py @@ -186,6 +186,7 @@ def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable: from queue import Queue import time stream = torch.cuda.Stream() + def producer(queue, dataset, batch_size, device): torch.set_num_threads(8) nonlocal stream @@ -201,12 +202,17 @@ def producer(queue, dataset, batch_size, device): del idx_iter idx_iter = iter(range(len(dataset))) start_idx = next(idx_iter) - data = [dataset.__getitem__(idx) for idx in range(start_idx, start_idx+batch_size)] + data = [dataset.__getitem__(idx) for idx in range(start_idx, start_idx + batch_size)] data = [[i[j] for i in data] for j in range(len(data[0]))] data = [torch.stack(x).to(device) for x in data] queue.put(data) + queue = Queue(maxsize=50) - producer_thread = Thread(target=producer, args=(queue, dataset, cfg.policy.batch_size, 'cuda:0' if cfg.policy.cuda else 'cpu'), name='cuda_fetcher_producer') + producer_thread = Thread( + target=producer, + args=(queue, dataset, cfg.policy.batch_size, 'cuda:0,1' if cfg.policy.cuda else 'cpu'), + name='cuda_fetcher_producer' + ) producer_thread.start() @@ -249,9 +255,12 @@ def _fetch(ctx: "OfflineRLContext"): except StopIteration: ctx.train_epoch += 1 del dataloader - dataloader = iter(DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x)) + dataloader = iter( + DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x) + ) ctx.train_data = next(dataloader) # TODO apply data update (e.g. 
priority) in offline setting when necessary + return _fetch diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index dfecfb7d22..4e760ffbc3 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -93,12 +93,10 @@ def _logger(ctx: "OnlineRLContext"): return _logger -def offline_logger( - exp_name: str = None -) -> Callable: +def offline_logger(exp_name: str = None) -> Callable: if task.router.is_active and not task.has_role(task.role.LEARNER): return task.void() - writer = SummaryWriter(logdir = exp_name) + writer = SummaryWriter(logdir=exp_name) def _logger(ctx: "OfflineRLContext"): if task.finish: diff --git a/ding/model/template/dt.py b/ding/model/template/dt.py index 7b5a941ebe..8b898dcc89 100644 --- a/ding/model/template/dt.py +++ b/ding/model/template/dt.py @@ -19,6 +19,7 @@ class MaskedCausalAttention(nn.Module): + def __init__(self, h_dim, max_T, n_heads, drop_p): super().__init__() @@ -39,22 +40,22 @@ def __init__(self, h_dim, max_T, n_heads, drop_p): # register buffer makes sure mask does not get updated # during backpropagation - self.register_buffer('mask',mask) + self.register_buffer('mask', mask) def forward(self, x): - B, T, C = x.shape # batch size, seq length, h_dim * n_heads + B, T, C = x.shape # batch size, seq length, h_dim * n_heads - N, D = self.n_heads, C // self.n_heads # N = num heads, D = attention dim + N, D = self.n_heads, C // self.n_heads # N = num heads, D = attention dim # rearrange q, k, v as (B, N, T, D) - q = self.q_net(x).view(B, T, N, D).transpose(1,2) - k = self.k_net(x).view(B, T, N, D).transpose(1,2) - v = self.v_net(x).view(B, T, N, D).transpose(1,2) + q = self.q_net(x).view(B, T, N, D).transpose(1, 2) + k = self.k_net(x).view(B, T, N, D).transpose(1, 2) + v = self.v_net(x).view(B, T, N, D).transpose(1, 2) # weights (B, N, T, T) - weights = q @ k.transpose(2,3) / math.sqrt(D) + weights = q @ k.transpose(2, 3) / math.sqrt(D) # causal mask applied to weights - weights = weights.masked_fill(self.mask[...,:T,:T] == 0, float('-inf')) + weights = weights.masked_fill(self.mask[..., :T, :T] == 0, float('-inf')) # normalize weights, all -inf -> 0 after softmax normalized_weights = F.softmax(weights, dim=-1) @@ -62,37 +63,50 @@ def forward(self, x): attention = self.att_drop(normalized_weights @ v) # gather heads and project (B, N, T, D) -> (B, T, N*D) - attention = attention.transpose(1, 2).contiguous().view(B,T,N*D) + attention = attention.transpose(1, 2).contiguous().view(B, T, N * D) out = self.proj_drop(self.proj_net(attention)) return out class Block(nn.Module): + def __init__(self, h_dim, max_T, n_heads, drop_p): super().__init__() self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p) self.mlp = nn.Sequential( - nn.Linear(h_dim, 4*h_dim), - nn.GELU(), - nn.Linear(4*h_dim, h_dim), - nn.Dropout(drop_p), - ) + nn.Linear(h_dim, 4 * h_dim), + nn.GELU(), + nn.Linear(4 * h_dim, h_dim), + nn.Dropout(drop_p), + ) self.ln1 = nn.LayerNorm(h_dim) self.ln2 = nn.LayerNorm(h_dim) def forward(self, x): # Attention -> LayerNorm -> MLP -> LayerNorm - x = x + self.attention(x) # residual + x = x + self.attention(x) # residual x = self.ln1(x) - x = x + self.mlp(x) # residual + x = x + self.mlp(x) # residual x = self.ln2(x) return x class DecisionTransformer(nn.Module): - def __init__(self, state_dim, act_dim, n_blocks, h_dim, context_len, - n_heads, drop_p, max_timestep=4096, state_encoder=None, continuous=False): + + def __init__( + self, + 
state_dim, + act_dim, + n_blocks, + h_dim, + context_len, + n_heads, + drop_p, + max_timestep=4096, + state_encoder=None, + continuous=False + ): super().__init__() self.state_dim = state_dim @@ -110,7 +124,7 @@ def __init__(self, state_dim, act_dim, n_blocks, h_dim, context_len, self.embed_rtg = torch.nn.Linear(1, h_dim) self.pos_emb = nn.Parameter(torch.zeros(1, input_seq_len + 1, self.h_dim)) - self.global_pos_emb = nn.Parameter(torch.zeros(1, max_timestep+1, self.h_dim)) + self.global_pos_emb = nn.Parameter(torch.zeros(1, max_timestep + 1, self.h_dim)) if state_encoder == None: self.embed_state = torch.nn.Linear(state_dim, h_dim) @@ -122,16 +136,14 @@ def __init__(self, state_dim, act_dim, n_blocks, h_dim, context_len, if continuous: # continuous actions self.embed_action = torch.nn.Linear(act_dim, h_dim) - use_action_tanh = True # True for continuous actions + use_action_tanh = True # True for continuous actions else: # discrete actions self.embed_action = torch.nn.Embedding(act_dim, h_dim) - use_action_tanh = False # False for discrete actions + use_action_tanh = False # False for discrete actions ### prediction heads - self.predict_action = nn.Sequential( - *([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else [])) - ) + self.predict_action = nn.Sequential(*([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else []))) def forward(self, timesteps, states, actions, returns_to_go): B, T = states.shape[0], states.shape[1] @@ -143,19 +155,24 @@ def forward(self, timesteps, states, actions, returns_to_go): action_embeddings = self.embed_action(actions) + time_embeddings returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings else: - state_embeddings = self.state_encoder(states.reshape(-1, 4, 84, 84).type(torch.float32).contiguous()) # (batch * block_size, h_dim) - state_embeddings = state_embeddings.reshape(states.shape[0], states.shape[1], self.h_dim) # (batch, block_size, h_dim) + state_embeddings = self.state_encoder( + states.reshape(-1, 4, 84, 84).type(torch.float32).contiguous() + ) # (batch * block_size, h_dim) + state_embeddings = state_embeddings.reshape( + states.shape[0], states.shape[1], self.h_dim + ) # (batch, block_size, h_dim) returns_embeddings = self.embed_rtg(returns_to_go.type(torch.float32)) - action_embeddings = self.embed_action(actions.type(torch.long).squeeze(-1)) # (batch, block_size, h_dim) + action_embeddings = self.embed_action(actions.type(torch.long).squeeze(-1)) # (batch, block_size, h_dim) # stack rtg, states and actions and reshape sequence as # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2 ...) 
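# A minimal, self-contained sketch (shapes are illustrative, not taken from this patch) of the
# interleaving that the stack/permute/reshape below performs: three (B, T, h_dim) embedding
# tensors are merged into one (B, 3*T, h_dim) token sequence ordered (r_0, s_0, a_0, r_1, s_1, a_1, ...).
#
#     import torch
#     B, T, h_dim = 2, 4, 8
#     r = torch.zeros(B, T, h_dim)          # return-to-go embeddings
#     s = torch.ones(B, T, h_dim)           # state embeddings
#     a = torch.full((B, T, h_dim), 2.0)    # action embeddings
#     h = torch.stack((r, s, a), dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, h_dim)
#     # positions 0, 1, 2 of the merged sequence are r_0, s_0, a_0 respectively
#     assert torch.equal(h[0, 0], r[0, 0]) and torch.equal(h[0, 1], s[0, 0]) and torch.equal(h[0, 2], a[0, 0])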
- h = torch.stack( - (returns_embeddings, state_embeddings, action_embeddings), dim=1 - ).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim) + h = torch.stack((returns_embeddings, state_embeddings, action_embeddings), + dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim) - all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, B, dim=0) # batch_size, traj_length, h_dim - position_embeddings = torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.h_dim, dim=-1)) + all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, B, dim=0) # batch_size, traj_length, h_dim + position_embeddings = torch.gather( + all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.h_dim, dim=-1) + ) position_embeddings = position_embeddings + self.pos_emb[:, :h.shape[1], :] h = self.embed_ln(h + position_embeddings) @@ -168,17 +185,17 @@ def forward(self, timesteps, states, actions, returns_to_go): # h[:, 1, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t # h[:, 2, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t, a_t # that is, for each timestep (t) we have 3 output embeddings from the transformer, - # each conditioned on all previous timesteps plus + # each conditioned on all previous timesteps plus # the 3 input variables at that timestep (r_t, s_t, a_t) in sequence. h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3) # get predictions if self.state_encoder == None: - return_preds = self.predict_rtg(h[:,2]) # predict next rtg given r, s, a - state_preds = self.predict_state(h[:,2]) # predict next state given r, s, a + return_preds = self.predict_rtg(h[:, 2]) # predict next rtg given r, s, a + state_preds = self.predict_state(h[:, 2]) # predict next state given r, s, a else: return_preds = None state_preds = None - action_preds = self.predict_action(h[:,1]) # predict action given r, s + action_preds = self.predict_action(h[:, 1]) # predict action given r, s return state_preds, action_preds, return_preds diff --git a/ding/policy/dt.py b/ding/policy/dt.py index 97a78a2be1..f478a1b78d 100644 --- a/ding/policy/dt.py +++ b/ding/policy/dt.py @@ -35,9 +35,9 @@ class DTPolicy(Policy): max_eval_ep_len=1000, # max len of one episode batch_size=64, # training batch size wt_decay=1e-4, # decay weight in optimizer - warmup_steps=10000, # steps for learning rate warmup + warmup_steps=10000, # steps for learning rate warmup context_len=20, # length of transformer input - learning_rate=1e-4, + learning_rate=1e-4, ) def default_model(self) -> Tuple[str, List[str]]: @@ -54,14 +54,14 @@ def _init_learn(self) -> None: # rtg_target: max target of `return to go` # Our goal is normalize `return to go` to (0, 1), which will favour the covergence. # As a result, we usually set rtg_scale == rtg_target. - self.rtg_scale = self._cfg.rtg_scale # normalize returns to go + self.rtg_scale = self._cfg.rtg_scale # normalize returns to go self.rtg_target = self._cfg.rtg_target # max target reward_to_go self.max_eval_ep_len = self._cfg.max_eval_ep_len # max len of one episode lr = self._cfg.learning_rate # learning rate wt_decay = self._cfg.wt_decay # weight decay warmup_steps = self._cfg.warmup_steps # warmup steps for lr scheduler - + self.clip_grad_norm_p = self._cfg.clip_grad_norm_p self.context_len = self._cfg.model.context_len # K in decision transformer @@ -86,7 +86,7 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: Returns: - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. 
""" - + timesteps, states, actions, returns_to_go, traj_mask = data action_target = torch.clone(actions).detach().to(self._device) @@ -98,9 +98,10 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: # if discrete if not self._cfg.model.continuous and 'state_mean' in self._cfg: actions = one_hot(actions.squeeze(-1), num=self.act_dim) - + state_preds, action_preds, return_preds = self._learn_model.forward( - timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go) + timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go + ) if 'state_mean' not in self._cfg: action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1)) @@ -145,40 +146,72 @@ def _init_eval(self) -> None: self.eval_batch_size = self._cfg.evaluator_env_num self.max_eval_ep_len = self._cfg.max_eval_ep_len self.context_len = self._cfg.model.context_len # K in decision transformer - + self.t = [0 for _ in range(self.eval_batch_size)] if self._cfg.model.continuous: - self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device) + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device + ) else: - self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device) + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device + ) if 'state_mean' not in self._cfg: - self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len,) + tuple(self.state_dim), dtype=torch.float32, device=self._device) + self.states = torch.zeros( + ( + self.eval_batch_size, + self.max_eval_ep_len, + ) + tuple(self.state_dim), + dtype=torch.float32, + device=self._device + ) self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)] else: self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] - self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device) + self.states = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device + ) self.state_mean = torch.from_numpy(self._cfg.state_mean).to(self._device) self.state_std = torch.from_numpy(self._cfg.state_std).to(self._device) - self.timesteps = torch.arange(start=0, end=self.max_eval_ep_len, step=1).repeat(self.eval_batch_size, 1).to(self._device) - self.rewards_to_go = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device) + self.timesteps = torch.arange( + start=0, end=self.max_eval_ep_len, step=1 + ).repeat(self.eval_batch_size, 1).to(self._device) + self.rewards_to_go = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device + ) def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: # save and forward data_id = list(data.keys()) - + self._eval_model.eval() with torch.no_grad(): if 'state_mean' not in self._cfg: - states = torch.zeros((self.eval_batch_size, self.context_len,) + tuple(self.state_dim), dtype=torch.float32, device=self._device) + states = torch.zeros( + ( + self.eval_batch_size, + self.context_len, + ) + tuple(self.state_dim), + dtype=torch.float32, + device=self._device + ) timesteps = torch.zeros((self.eval_batch_size, 1, 1), dtype=torch.long, 
device=self._device) else: - states = torch.zeros((self.eval_batch_size, self.context_len, self.state_dim), dtype=torch.float32, device=self._device) + states = torch.zeros( + (self.eval_batch_size, self.context_len, self.state_dim), dtype=torch.float32, device=self._device + ) timesteps = torch.zeros((self.eval_batch_size, self.context_len), dtype=torch.long, device=self._device) if not self._cfg.model.continuous: - actions = torch.zeros((self.eval_batch_size, self.context_len, 1), dtype=torch.long, device=self._device) + actions = torch.zeros( + (self.eval_batch_size, self.context_len, 1), dtype=torch.long, device=self._device + ) else: - actions = torch.zeros((self.eval_batch_size, self.context_len, self.act_dim), dtype=torch.float32, device=self._device) - rewards_to_go = torch.zeros((self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self._device) + actions = torch.zeros( + (self.eval_batch_size, self.context_len, self.act_dim), dtype=torch.float32, device=self._device + ) + rewards_to_go = torch.zeros( + (self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self._device + ) for i in data_id: if 'state_mean' not in self._cfg: self.states[i, self.t[i]] = data[i]['obs'].to(self._device) @@ -186,10 +219,12 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: self.states[i, self.t[i]] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std self.running_rtg[i] = self.running_rtg[i] - data[i]['reward'].to(self._device) self.rewards_to_go[i, self.t[i]] = self.running_rtg[i] - + if self.t[i] <= self.context_len: if 'state_mean' not in self._cfg: - timesteps[i] = min(self.t[i], self._cfg.max_timestep) * torch.ones((1, 1), dtype=torch.int64).to(self._device) + timesteps[i] = min(self.t[i], + self._cfg.max_timestep) * torch.ones((1, 1), + dtype=torch.int64).to(self._device) else: timesteps[i] = self.timesteps[i, :self.context_len] states[i] = self.states[i, :self.context_len] @@ -197,7 +232,9 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: rewards_to_go[i] = self.rewards_to_go[i, :self.context_len] else: if 'state_mean' not in self._cfg: - timesteps[i] = min(self.t[i], self._cfg.max_timestep) * torch.ones((1, 1), dtype=torch.int64).to(self._device) + timesteps[i] = min(self.t[i], + self._cfg.max_timestep) * torch.ones((1, 1), + dtype=torch.int64).to(self._device) else: timesteps[i] = self.timesteps[i, self.t[i] - self.context_len + 1:self.t[i] + 1] states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1] @@ -225,31 +262,58 @@ def _reset_eval(self, data_id: List[int] = None) -> None: # clean data if data_id is None: self.t = [0 for _ in range(self.eval_batch_size)] - self.timesteps = torch.arange(start=0, end=self.max_eval_ep_len, step=1).repeat(self.eval_batch_size, 1).to(self._device) + self.timesteps = torch.arange( + start=0, end=self.max_eval_ep_len, step=1 + ).repeat(self.eval_batch_size, 1).to(self._device) if not self._cfg.model.continuous: - self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device) + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device + ) else: - self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device) + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, self.act_dim), + dtype=torch.float32, + device=self._device + ) if 'state_mean' not in self._cfg: 
- self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len,) + tuple(self.state_dim), dtype=torch.float32, device=self._device) + self.states = torch.zeros( + ( + self.eval_batch_size, + self.max_eval_ep_len, + ) + tuple(self.state_dim), + dtype=torch.float32, + device=self._device + ) self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)] else: - self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device) + self.states = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, self.state_dim), + dtype=torch.float32, + device=self._device + ) self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] - self.rewards_to_go = torch.zeros((self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device) + self.rewards_to_go = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device + ) else: for i in data_id: self.t[i] = 0 if not self._cfg.model.continuous: self.actions[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.long, device=self._device) else: - self.actions[i] = torch.zeros((self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device) + self.actions[i] = torch.zeros( + (self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device + ) if 'state_mean' not in self._cfg: - self.states[i] = torch.zeros((self.max_eval_ep_len,) + tuple(self.state_dim), dtype=torch.float32, device=self._device) + self.states[i] = torch.zeros( + (self.max_eval_ep_len, ) + tuple(self.state_dim), dtype=torch.float32, device=self._device + ) self.running_rtg[i] = self.rtg_target else: - self.states[i] = torch.zeros((self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device) + self.states[i] = torch.zeros( + (self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device + ) self.running_rtg[i] = self.rtg_target / self.rtg_scale self.timesteps[i] = torch.arange(start=0, end=self.max_eval_ep_len, step=1).to(self._device) self.rewards_to_go[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device) @@ -264,7 +328,7 @@ def _state_dict_learn(self) -> Dict[str, Any]: def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: self._learn_model.load_state_dict(state_dict['model']) self._optimizer.load_state_dict(state_dict['optimizer']) - + def _load_state_dict_eval(self, state_dict: Dict[str, Any]) -> None: self._eval_model.load_state_dict(state_dict) @@ -281,4 +345,4 @@ def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: pass def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: namedtuple) -> Dict[str, Any]: - pass \ No newline at end of file + pass diff --git a/ding/torch_utils/network/transformer.py b/ding/torch_utils/network/transformer.py index d93457ba9a..a5da9ac6ce 100644 --- a/ding/torch_utils/network/transformer.py +++ b/ding/torch_utils/network/transformer.py @@ -82,6 +82,7 @@ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch class MaskedCausalAttention(nn.Module): + def __init__(self, h_dim, max_T, n_heads, drop_p): super().__init__() @@ -102,22 +103,22 @@ def __init__(self, h_dim, max_T, n_heads, drop_p): # register buffer makes sure mask does not get updated # during backpropagation - self.register_buffer('mask',mask) + self.register_buffer('mask', mask) def forward(self, x): - B, T, C = x.shape # 
batch size, seq length, h_dim * n_heads + B, T, C = x.shape # batch size, seq length, h_dim * n_heads - N, D = self.n_heads, C // self.n_heads # N = num heads, D = attention dim + N, D = self.n_heads, C // self.n_heads # N = num heads, D = attention dim # rearrange q, k, v as (B, N, T, D) - q = self.q_net(x).view(B, T, N, D).transpose(1,2) - k = self.k_net(x).view(B, T, N, D).transpose(1,2) - v = self.v_net(x).view(B, T, N, D).transpose(1,2) + q = self.q_net(x).view(B, T, N, D).transpose(1, 2) + k = self.k_net(x).view(B, T, N, D).transpose(1, 2) + v = self.v_net(x).view(B, T, N, D).transpose(1, 2) # weights (B, N, T, T) - weights = q @ k.transpose(2,3) / math.sqrt(D) + weights = q @ k.transpose(2, 3) / math.sqrt(D) # causal mask applied to weights - weights = weights.masked_fill(self.mask[...,:T,:T] == 0, float('-inf')) + weights = weights.masked_fill(self.mask[..., :T, :T] == 0, float('-inf')) # normalize weights, all -inf -> 0 after softmax normalized_weights = F.softmax(weights, dim=-1) @@ -125,7 +126,7 @@ def forward(self, x): attention = self.att_drop(normalized_weights @ v) # gather heads and project (B, N, T, D) -> (B, T, N*D) - attention = attention.transpose(1, 2).contiguous().view(B,T,N*D) + attention = attention.transpose(1, 2).contiguous().view(B, T, N * D) out = self.proj_drop(self.proj_net(attention)) return out diff --git a/ding/utils/data/dataloader.py b/ding/utils/data/dataloader.py index bcc70aac63..0670d0db7f 100644 --- a/ding/utils/data/dataloader.py +++ b/ding/utils/data/dataloader.py @@ -161,11 +161,7 @@ def _get_data(self, p: tm.multiprocessing.connection, c: tm.multiprocessing.conn break if cmd == 'get_data': # Main worker asks for data. - import time - st = time.time() - print('in data get at', st) data = self.data_source(self.batch_size) - print('already get data at', time.time(), 'cost', time.time()-st) # ``data`` can be callable, e.g. a function to read data from file, therefore we can divide # this job to pieces, assign to every slave worker and accomplish jobs asynchronously. 
# But if we get a list of dicts, which means the data has already been processed and diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index 371631bc37..bfbc777f4d 100644 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -334,7 +334,7 @@ def __init__(self, cfg: dict) -> None: self.context_len = cfg.dataset.context_len self.env_type = cfg.dataset.env_type - if 'hdf5' in dataset_path: # for mujoco env + if 'hdf5' in dataset_path: # for mujoco env try: import h5py import collections @@ -358,7 +358,7 @@ def __init__(self, cfg: dict) -> None: if use_timeouts: final_timestep = dataset['timeouts'][i] else: - final_timestep = (episode_step == 1000-1) + final_timestep = (episode_step == 1000 - 1) for k in ['observations', 'actions', 'rewards', 'terminals']: data_[k].append(dataset[k][i]) if done_bool or final_timestep: @@ -374,7 +374,7 @@ def __init__(self, cfg: dict) -> None: # calculate min len of traj, state mean and variance # and returns_to_go for all traj - min_len = 10**6 + min_len = 10 ** 6 states = [] for traj in self.trajectories: traj_len = traj['observations'].shape[0] @@ -399,7 +399,7 @@ def __init__(self, cfg: dict) -> None: # self.trajectories[k] = np.expand_dims(dataset[k][:], axis=1) # else: # self.trajectories[k] = dataset[k][:] - + # # used for input normalization # states = np.concatenate(self.trajectories['observations'], axis=0) # self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 @@ -413,7 +413,7 @@ def __init__(self, cfg: dict) -> None: # use_timeouts = False # if 'timeouts' in dataset: # use_timeouts = True - + # data_ = collections.defaultdict(list) # episode_step = 0 # trajectories_tmp = [] @@ -451,7 +451,8 @@ def __init__(self, cfg: dict) -> None: for transition_index in range(len(self.trajectories[eps_index])) ], axis=0 - ) for key, o_key in zip(keys, original_keys) + ) + for key, o_key in zip(keys, original_keys) } for eps_index in range(len(self.trajectories)) ] self.trajectories = trajectories_tmp @@ -475,7 +476,7 @@ def __init__(self, cfg: dict) -> None: with open(dataset_path, 'rb') as f: self.trajectories = pickle.load(f) - min_len = 10**6 + min_len = 10 ** 6 states = [] for traj in self.trajectories: traj_len = traj['observations'].shape[0] @@ -514,14 +515,17 @@ def __init__(self, cfg: dict) -> None: gamma=0.99, observation_dtype=np.uint8, batch_size=32, - replay_capacity=100000) + replay_capacity=100000 + ) if frb._loaded_buffers: done = False curr_num_transitions = len(obss) trajectories_to_load = cfg.dataset.trajectories_per_buffer while not done: - states, ac, ret, next_states, next_action, next_reward, terminal, indices = frb.sample_transition_batch(batch_size=1, indices=[i]) - states = states.transpose((0, 3, 1, 2))[0] # (1, 84, 84, 4) --> (4, 84, 84) + states, ac, ret, next_states, next_action, next_reward, terminal, indices = frb.sample_transition_batch( + batch_size=1, indices=[i] + ) + states = states.transpose((0, 3, 1, 2))[0] # (1, 84, 84, 4) --> (4, 84, 84) obss.append(states) actions.append(ac[0]) stepwise_returns.append(ret[0]) @@ -543,7 +547,10 @@ def __init__(self, cfg: dict) -> None: done = True num_trajectories += (cfg.dataset.trajectories_per_buffer - trajectories_to_load) transitions_per_buffer[buffer_num] = i - print('this buffer has %d loaded transitions and there are now %d transitions total divided into %d trajectories' % (i, len(obss), num_trajectories)) + print( + 'this buffer has %d loaded transitions and there are now %d transitions total divided into %d 
trajectories' + % (i, len(obss), num_trajectories) + ) actions = np.array(actions) returns = np.array(returns) @@ -556,18 +563,18 @@ def __init__(self, cfg: dict) -> None: for i in done_idxs: i = int(i) curr_traj_returns = stepwise_returns[start_index:i] - for j in range(i-1, start_index-1, -1): # start from i-1 - rtg_j = curr_traj_returns[j-start_index:i-start_index] + for j in range(i - 1, start_index - 1, -1): # start from i-1 + rtg_j = curr_traj_returns[j - start_index:i - start_index] rtg[j] = sum(rtg_j) start_index = i # -- create timestep dataset start_index = 0 - timesteps = np.zeros(len(actions)+1, dtype=int) + timesteps = np.zeros(len(actions) + 1, dtype=int) for i in done_idxs: i = int(i) - timesteps[start_index:i+1] = np.arange(i+1 - start_index) - start_index = i+1 + timesteps[start_index:i + 1] = np.arange(i + 1 - start_index) + start_index = i + 1 self.obss = obss self.actions = actions @@ -575,7 +582,7 @@ def __init__(self, cfg: dict) -> None: self.rtgs = rtg self.timesteps = timesteps # return obss, actions, returns, done_idxs, rtg, timesteps - + def get_max_timestep(self) -> int: return max(self.timesteps) @@ -635,31 +642,33 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tenso traj_mask = torch.cat( [torch.ones(traj_len, dtype=torch.long), - torch.zeros(padding_len, dtype=torch.long)], dim=0 + torch.zeros(padding_len, dtype=torch.long)], dim=0 ) return timesteps, states, actions, returns_to_go, traj_mask - else: # mean cost less than 0.001s + else: # mean cost less than 0.001s block_size = self.context_len done_idx = idx + block_size for i in self.done_idxs: - if i > idx: # first done_idx greater than idx + if i > idx: # first done_idx greater than idx done_idx = min(int(i), done_idx) break idx = done_idx - block_size - states = torch.as_tensor(np.array(self.obss[idx:done_idx]), dtype=torch.float32).view(block_size, -1) # (block_size, 4*84*84) + states = torch.as_tensor( + np.array(self.obss[idx:done_idx]), dtype=torch.float32 + ).view(block_size, -1) # (block_size, 4*84*84) states = states / 255. - actions = torch.as_tensor(self.actions[idx:done_idx], dtype=torch.long).unsqueeze(1) # (block_size, 1) + actions = torch.as_tensor(self.actions[idx:done_idx], dtype=torch.long).unsqueeze(1) # (block_size, 1) rtgs = torch.as_tensor(self.rtgs[idx:done_idx], dtype=torch.float32).unsqueeze(1) - timesteps = torch.as_tensor(self.timesteps[idx:idx+1], dtype=torch.int64).unsqueeze(1) + timesteps = torch.as_tensor(self.timesteps[idx:idx + 1], dtype=torch.int64).unsqueeze(1) traj_mask = torch.ones(self.context_len, dtype=torch.long) return timesteps, states, actions, rtgs, traj_mask - + class FixedReplayBuffer(object): - """Object composed of a list of OutofGraphReplayBuffers.""" + """Object composed of a list of OutofGraphReplayBuffers.""" - def __init__(self, data_dir, replay_suffix, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg - """Initialize the FixedReplayBuffer class. + def __init__(self, data_dir, replay_suffix, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg + """Initialize the FixedReplayBuffer class. Args: data_dir: str, log Directory from which to load the replay buffer. replay_suffix: int, If not None, then only load the replay buffer @@ -667,53 +676,51 @@ def __init__(self, data_dir, replay_suffix, *args, **kwargs): # pylint: disable *args: Arbitrary extra arguments. **kwargs: Arbitrary keyword arguments. 
""" - self._args = args - self._kwargs = kwargs - self._data_dir = data_dir - self._loaded_buffers = False - self.add_count = np.array(0) - self._replay_suffix = replay_suffix - if not self._loaded_buffers: - if replay_suffix is not None: - assert replay_suffix >= 0, 'Please pass a non-negative replay suffix' - self.load_single_buffer(replay_suffix) - else: - pass - # self._load_replay_buffers(num_buffers=50) - - def load_single_buffer(self, suffix): - """Load a single replay buffer.""" - replay_buffer = self._load_buffer(suffix) - if replay_buffer is not None: - self._replay_buffers = [replay_buffer] - self.add_count = replay_buffer.add_count - self._num_replay_buffers = 1 - self._loaded_buffers = True - - def _load_buffer(self, suffix): - """Loads a OutOfGraphReplayBuffer replay buffer.""" - try: - from dopamine.replay_memory import circular_replay_buffer - STORE_FILENAME_PREFIX = circular_replay_buffer.STORE_FILENAME_PREFIX - # pytype: disable=attribute-error - replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer( - *self._args, **self._kwargs) - replay_buffer.load(self._data_dir, suffix) - print('Loaded replay buffer ckpt {} from {}'.format( - suffix, self._data_dir)) - # pytype: enable=attribute-error - return replay_buffer - # except tf.errors.NotFoundError: - except: - raise('can not load') - - def get_transition_elements(self): - return self._replay_buffers[0].get_transition_elements() - - def sample_transition_batch(self, batch_size=None, indices=None): - buffer_index = np.random.randint(self._num_replay_buffers) - return self._replay_buffers[buffer_index].sample_transition_batch( - batch_size=batch_size, indices=indices) + self._args = args + self._kwargs = kwargs + self._data_dir = data_dir + self._loaded_buffers = False + self.add_count = np.array(0) + self._replay_suffix = replay_suffix + if not self._loaded_buffers: + if replay_suffix is not None: + assert replay_suffix >= 0, 'Please pass a non-negative replay suffix' + self.load_single_buffer(replay_suffix) + else: + pass + # self._load_replay_buffers(num_buffers=50) + + def load_single_buffer(self, suffix): + """Load a single replay buffer.""" + replay_buffer = self._load_buffer(suffix) + if replay_buffer is not None: + self._replay_buffers = [replay_buffer] + self.add_count = replay_buffer.add_count + self._num_replay_buffers = 1 + self._loaded_buffers = True + + def _load_buffer(self, suffix): + """Loads a OutOfGraphReplayBuffer replay buffer.""" + try: + from dopamine.replay_memory import circular_replay_buffer + STORE_FILENAME_PREFIX = circular_replay_buffer.STORE_FILENAME_PREFIX + # pytype: disable=attribute-error + replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer(*self._args, **self._kwargs) + replay_buffer.load(self._data_dir, suffix) + print('Loaded replay buffer ckpt {} from {}'.format(suffix, self._data_dir)) + # pytype: enable=attribute-error + return replay_buffer + # except tf.errors.NotFoundError: + except: + raise ('can not load') + + def get_transition_elements(self): + return self._replay_buffers[0].get_transition_elements() + + def sample_transition_batch(self, batch_size=None, indices=None): + buffer_index = np.random.randint(self._num_replay_buffers) + return self._replay_buffers[buffer_index].sample_transition_batch(batch_size=batch_size, indices=indices) + class PCDataset(Dataset): diff --git a/dizoo/atari/config/serial/pong/pong_dt_config.py b/dizoo/atari/config/serial/pong/pong_dt_config.py index 9d9387036d..d4e8222434 100644 --- 
a/dizoo/atari/config/serial/pong/pong_dt_config.py +++ b/dizoo/atari/config/serial/pong/pong_dt_config.py @@ -12,7 +12,7 @@ stop_value=20, frame_stack=4, is_train=False, - episode_num=10000, # stop in breakout + episode_num=10000, # stop in breakout ), dataset=dict( env_type='atari', @@ -44,7 +44,7 @@ n_heads=8, drop_p=0.1, continuous=False, - ), + ), batch_size=128, learning_rate=6e-4, eval=dict(evaluator=dict(eval_freq=100, ), ), diff --git a/dizoo/atari/entry/spaceinvaders_dqn_eval.py b/dizoo/atari/entry/spaceinvaders_dqn_eval.py index d8bfde290d..35e15a578c 100644 --- a/dizoo/atari/entry/spaceinvaders_dqn_eval.py +++ b/dizoo/atari/entry/spaceinvaders_dqn_eval.py @@ -15,8 +15,9 @@ from ding.rl_utils import get_epsilon_greedy_fn from dizoo.atari.config.serial.spaceinvaders.spaceinvaders_dqn_config import main_config, create_config + def main(rl_cfg, seed=0): - main_cfg, create_cfg =rl_cfg + main_cfg, create_cfg = rl_cfg cfg = compile_config( main_cfg, BaseEnvManager, @@ -56,4 +57,4 @@ def main(rl_cfg, seed=0): if __name__ == "__main__": - main(rl_cfg=(main_config, create_config),seed=0) + main(rl_cfg=(main_config, create_config), seed=0) diff --git a/dizoo/atari/example/atari_dqn_dist_ddp.py b/dizoo/atari/example/atari_dqn_dist_ddp.py index f194c326bc..5dbfc4e65c 100644 --- a/dizoo/atari/example/atari_dqn_dist_ddp.py +++ b/dizoo/atari/example/atari_dqn_dist_ddp.py @@ -14,7 +14,6 @@ from dizoo.atari.envs.atari_env import AtariEnv from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config - logging.getLogger().setLevel(logging.INFO) main_config.exp_name = 'pong_dqn_seed0_ditask_dist_ddp' diff --git a/dizoo/box2d/bipedalwalker/config/bipedalwalker_ddpg_config.py b/dizoo/box2d/bipedalwalker/config/bipedalwalker_ddpg_config.py index 492713b01a..de70a09c86 100644 --- a/dizoo/box2d/bipedalwalker/config/bipedalwalker_ddpg_config.py +++ b/dizoo/box2d/bipedalwalker/config/bipedalwalker_ddpg_config.py @@ -29,13 +29,9 @@ learning_rate_critic=0.0003, target_theta=0.005, discount_factor=0.99, - learner=dict( - hook=dict(log_show_after_iter=1000, ) - ) - ), - collect=dict( - n_sample=64, + learner=dict(hook=dict(log_show_after_iter=1000, )) ), + collect=dict(n_sample=64, ), other=dict(replay_buffer=dict(replay_buffer_size=300000, ), ), ), ) diff --git a/dizoo/box2d/bipedalwalker/config/bipedalwalker_sac_config.py b/dizoo/box2d/bipedalwalker/config/bipedalwalker_sac_config.py index 5d00178359..f905c4031b 100644 --- a/dizoo/box2d/bipedalwalker/config/bipedalwalker_sac_config.py +++ b/dizoo/box2d/bipedalwalker/config/bipedalwalker_sac_config.py @@ -31,13 +31,9 @@ target_theta=0.005, discount_factor=0.99, auto_alpha=True, - learner=dict( - hook=dict(log_show_after_iter=1000, ) - ) - ), - collect=dict( - n_sample=64, + learner=dict(hook=dict(log_show_after_iter=1000, )) ), + collect=dict(n_sample=64, ), other=dict(replay_buffer=dict(replay_buffer_size=300000, ), ), ), ) @@ -49,9 +45,7 @@ import_names=['dizoo.box2d.bipedalwalker.envs.bipedalwalker_env'], ), env_manager=dict(type='subprocess'), - policy=dict( - type='sac', - ), + policy=dict(type='sac', ), replay_buffer=dict(type='naive', ), ) bipedalwalker_sac_create_config = EasyDict(bipedalwalker_sac_create_config) diff --git a/dizoo/box2d/bipedalwalker/config/bipedalwalker_td3_config.py b/dizoo/box2d/bipedalwalker/config/bipedalwalker_td3_config.py index 95b846b051..09cc3d1bf1 100644 --- a/dizoo/box2d/bipedalwalker/config/bipedalwalker_td3_config.py +++ b/dizoo/box2d/bipedalwalker/config/bipedalwalker_td3_config.py @@ -36,13 
+36,9 @@ min=-0.5, max=0.5, ), - learner=dict( - hook=dict(log_show_after_iter=1000, ) - ) - ), - collect=dict( - n_sample=64, + learner=dict(hook=dict(log_show_after_iter=1000, )) ), + collect=dict(n_sample=64, ), other=dict(replay_buffer=dict(replay_buffer_size=300000, ), ), ), ) diff --git a/dizoo/box2d/carracing/config/carracing_dqn_config.py b/dizoo/box2d/carracing/config/carracing_dqn_config.py index 31dd42fca8..1792056a83 100644 --- a/dizoo/box2d/carracing/config/carracing_dqn_config.py +++ b/dizoo/box2d/carracing/config/carracing_dqn_config.py @@ -29,17 +29,14 @@ learning_rate=0.0001, target_update_freq=100, ), - collect=dict( - n_sample=64, - ), + collect=dict(n_sample=64, ), other=dict( eps=dict( type='exp', start=0.95, end=0.1, decay=50000, - ), - replay_buffer=dict(replay_buffer_size=100000, ) + ), replay_buffer=dict(replay_buffer_size=100000, ) ), ), ) @@ -60,4 +57,4 @@ if __name__ == "__main__": # or you can enter `ding -m serial -c carracing_dqn_config.py -s 0` from ding.entry import serial_pipeline - serial_pipeline([main_config, create_config], seed=0) \ No newline at end of file + serial_pipeline([main_config, create_config], seed=0) diff --git a/dizoo/box2d/carracing/envs/carracing_env.py b/dizoo/box2d/carracing/envs/carracing_env.py index 39b82a2502..60ebaa97d1 100644 --- a/dizoo/box2d/carracing/envs/carracing_env.py +++ b/dizoo/box2d/carracing/envs/carracing_env.py @@ -2,7 +2,6 @@ import copy import os - import gym import numpy as np from easydict import EasyDict diff --git a/dizoo/box2d/carracing/envs/test_carracing_env.py b/dizoo/box2d/carracing/envs/test_carracing_env.py index 7eb4a75039..47a5fa4638 100644 --- a/dizoo/box2d/carracing/envs/test_carracing_env.py +++ b/dizoo/box2d/carracing/envs/test_carracing_env.py @@ -5,15 +5,7 @@ @pytest.mark.envtest -@pytest.mark.parametrize( - 'cfg', [ - EasyDict({ - 'env_id': 'CarRacing-v2', - 'continuous': False, - 'act_scale': False - }) - ] -) +@pytest.mark.parametrize('cfg', [EasyDict({'env_id': 'CarRacing-v2', 'continuous': False, 'act_scale': False})]) class TestCarRacing: def test_naive(self, cfg): diff --git a/dizoo/box2d/lunarlander/config/lunarlander_cont_sac_config.py b/dizoo/box2d/lunarlander/config/lunarlander_cont_sac_config.py index 0e60fce608..f8a8ab47e7 100644 --- a/dizoo/box2d/lunarlander/config/lunarlander_cont_sac_config.py +++ b/dizoo/box2d/lunarlander/config/lunarlander_cont_sac_config.py @@ -28,9 +28,7 @@ learning_rate_alpha=3e-4, auto_alpha=True, ), - collect=dict( - n_sample=256, - ), + collect=dict(n_sample=256, ), eval=dict(evaluator=dict(eval_freq=1000, ), ), other=dict(replay_buffer=dict(replay_buffer_size=int(1e5), ), ), ), diff --git a/dizoo/box2d/lunarlander/config/lunarlander_decision_transformer.py b/dizoo/box2d/lunarlander/config/lunarlander_decision_transformer.py index 1cd9ed2018..cd3b2884e6 100644 --- a/dizoo/box2d/lunarlander/config/lunarlander_decision_transformer.py +++ b/dizoo/box2d/lunarlander/config/lunarlander_decision_transformer.py @@ -27,7 +27,7 @@ embed_dim=128, n_heads=1, dropout_p=0.1, - log_dir='DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps', + log_dir='DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps', model=dict( state_dim=8, act_dim=4, diff --git a/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py b/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py index eebbf7d509..3ea4022727 100644 --- a/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py +++ b/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py @@ -23,7 +23,7 @@ warmup_steps=10000, context_len=20, # 
TODO evaluator_env_num=8, - log_dir='DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps', + log_dir='DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps', model=dict( state_dim=8, act_dim=4, diff --git a/dizoo/box2d/lunarlander/offline_data/collect_dqn_data_config.py b/dizoo/box2d/lunarlander/offline_data/collect_dqn_data_config.py index 26b35c189d..e7cc7b383d 100644 --- a/dizoo/box2d/lunarlander/offline_data/collect_dqn_data_config.py +++ b/dizoo/box2d/lunarlander/offline_data/collect_dqn_data_config.py @@ -34,14 +34,15 @@ dataloader=dict(num_workers=0, ), log_policy=True, hook=dict( - load_ckpt_before_run='./ckpt_best.pth.tar', # TODO: syspath modeified in other place, have to use abs path. May be fix in next version. + load_ckpt_before_run= + './ckpt_best.pth.tar', # TODO: syspath modeified in other place, have to use abs path. May be fix in next version. # load_ckpt_before_run='DI-engine/dizoo/box2d/lunarlander/dt_data/ckpt/ckpt_best.pth.tar', log_show_after_iter=100, save_ckpt_after_iter=10000, save_ckpt_after_run=False, ), cfg_type='BaseLearnerDict', - load_path='./ckpt_best.pth.tar', # TODO: same like last path. + load_path='./ckpt_best.pth.tar', # TODO: same like last path. # load_path='DI-engine/dizoo/box2d/lunarlander/dt_data/ckpt/ckpt_best.pth.tar', ), update_per_collect=10, diff --git a/dizoo/classic_control/cartpole/config/cartpole_bc_config.py b/dizoo/classic_control/cartpole/config/cartpole_bc_config.py index 8315e934fe..b1975718f3 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_bc_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_bc_config.py @@ -20,7 +20,7 @@ batch_size=64, learning_rate=0.01, learner=dict(hook=dict(save_ckpt_after_iter=1000)), - train_epoch = 20, + train_epoch=20, ), eval=dict(evaluator=dict(eval_freq=40, )) ), diff --git a/dizoo/classic_control/mountain_car/config/mtcar_rainbow_config.py b/dizoo/classic_control/mountain_car/config/mtcar_rainbow_config.py index c6c4fb4db0..b293d44494 100644 --- a/dizoo/classic_control/mountain_car/config/mtcar_rainbow_config.py +++ b/dizoo/classic_control/mountain_car/config/mtcar_rainbow_config.py @@ -1,58 +1,63 @@ from easydict import EasyDict # DI-Engine uses EasyDict for configuration, by convention -mtcar_rainbow_config = EasyDict(dict( - exp_name='mtcar_rainbow_seed0', - env=dict( - collector_env_num=8, - evaluator_env_num=5, - n_evaluator_episode=5, - stop_value=195, - ), - policy=dict( - cuda=False, - priority=True, - discount_factor=0.97, - nstep=3, - model=dict( - obs_shape=2, - action_shape=3, - encoder_hidden_size_list=[128, 128, 64], +mtcar_rainbow_config = EasyDict( + dict( + exp_name='mtcar_rainbow_seed0', + env=dict( + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + stop_value=195, ), - learn=dict( - update_per_collect=3, - batch_size=64, - learning_rate=0.001, - target_update_freq=100, + policy=dict( + cuda=False, + priority=True, + discount_factor=0.97, + nstep=3, + model=dict( + obs_shape=2, + action_shape=3, + encoder_hidden_size_list=[128, 128, 64], + ), + learn=dict( + update_per_collect=3, + batch_size=64, + learning_rate=0.001, + target_update_freq=100, + ), + collect=dict( + n_sample=80, + unroll_len=1, + ), + other=dict( + eps=dict( + type='exp', + start=0.95, + end=0.1, + decay=10000, + ), + replay_buffer=dict(replay_buffer_size=20000, ) + ), ), - collect=dict( - n_sample=80, - unroll_len=1, - ), - other=dict( - eps=dict( - type='exp', - start=0.95, - end=0.1, - decay=10000, - ), replay_buffer=dict(replay_buffer_size=20000, ) - ), - ), -)) + ) +) 
main_config = mtcar_rainbow_config -mtcar_rainbow_create_config = EasyDict(dict( - env=dict( - type='mountain_car', - import_names=['dizoo.classic_control.mountain_car.envs.mtcar_env'], - ), - env_manager=dict(type='base'), - policy=dict(type='rainbow'), -)) +mtcar_rainbow_create_config = EasyDict( + dict( + env=dict( + type='mountain_car', + import_names=['dizoo.classic_control.mountain_car.envs.mtcar_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='rainbow'), + ) +) create_config = mtcar_rainbow_create_config if __name__ == "__main__": from ding.entry import serial_pipeline - serial_pipeline((main_config, create_config), seed=0) \ No newline at end of file + serial_pipeline((main_config, create_config), seed=0) diff --git a/dizoo/classic_control/mountain_car/envs/__init__.py b/dizoo/classic_control/mountain_car/envs/__init__.py index 19f7eaf1cc..9e8ca86d5f 100644 --- a/dizoo/classic_control/mountain_car/envs/__init__.py +++ b/dizoo/classic_control/mountain_car/envs/__init__.py @@ -1 +1 @@ -from .mtcar_env import MountainCarEnv \ No newline at end of file +from .mtcar_env import MountainCarEnv diff --git a/dizoo/classic_control/pendulum/config/pendulum_ibc_config.py b/dizoo/classic_control/pendulum/config/pendulum_ibc_config.py index 247fdad045..7c56f283fe 100644 --- a/dizoo/classic_control/pendulum/config/pendulum_ibc_config.py +++ b/dizoo/classic_control/pendulum/config/pendulum_ibc_config.py @@ -13,16 +13,15 @@ ), policy=dict( cuda=cuda, - model=dict( - obs_shape=3, - action_shape=1, - stochastic_optim=dict(type='mcmc', cuda=cuda,) - ), + model=dict(obs_shape=3, action_shape=1, stochastic_optim=dict( + type='mcmc', + cuda=cuda, + )), learn=dict( multi_gpu=multi_gpu, train_epoch=15, batch_size=256, - optim=dict(learning_rate=1e-5,), + optim=dict(learning_rate=1e-5, ), learner=dict(hook=dict(log_show_after_iter=1000)), ), collect=dict( @@ -30,7 +29,7 @@ data_path='./pendulum_sac_data_generation/expert_demos.hdf5', collector_logit=False, ), - eval=dict(evaluator=dict(eval_freq=-1,)), + eval=dict(evaluator=dict(eval_freq=-1, )), ), ) pendulum_ibc_config = EasyDict(main_config) diff --git a/dizoo/classic_control/pendulum/config/pendulum_td3_bc_config.py b/dizoo/classic_control/pendulum/config/pendulum_td3_bc_config.py index 82a44f034e..8583fc6ada 100644 --- a/dizoo/classic_control/pendulum/config/pendulum_td3_bc_config.py +++ b/dizoo/classic_control/pendulum/config/pendulum_td3_bc_config.py @@ -6,7 +6,7 @@ collector_env_num=8, evaluator_env_num=5, norm_obs=dict( - use_norm=True, + use_norm=True, offline_stats=dict(use_offline_stats=True, ), ), # (bool) Scale output action into legal range. 
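The hunks above and below this point are almost entirely mechanical restyling, presumably produced by the project's code formatter: short keyword-only dict(...) literals are collapsed onto one line while keeping the trailing comma and the space before the closing parenthesis, long literals are re-wrapped to the line-length limit, and missing final newlines are added. A minimal sketch of the convention, using an illustrative config key taken from the hunks above rather than any one specific file, to show that only the layout changes while the resulting config is identical:

    from easydict import EasyDict

    # Layout before restyling: a single-entry dict spread over several lines.
    cfg_before = EasyDict(dict(collect=dict(
        n_sample=64,
    )))

    # Layout after restyling: collapsed onto one line, trailing comma kept.
    cfg_after = EasyDict(dict(collect=dict(n_sample=64, )))

    # Both spellings build the same configuration object; the diff is cosmetic.
    assert cfg_before == cfg_after

The larger rewraps in this series (for example the mtcar_rainbow_config block above and the model=dict(...) literals in the d4rl IBC configs below) follow the same rule: arguments are regrouped to fit the line-length limit and the key/value content itself is unchanged.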
diff --git a/dizoo/classic_control/pendulum/entry/pendulum_dqn_eval.py b/dizoo/classic_control/pendulum/entry/pendulum_dqn_eval.py index a5a7b9ab32..fb80ad42ad 100644 --- a/dizoo/classic_control/pendulum/entry/pendulum_dqn_eval.py +++ b/dizoo/classic_control/pendulum/entry/pendulum_dqn_eval.py @@ -15,8 +15,9 @@ from ding.rl_utils import get_epsilon_greedy_fn from dizoo.classic_control.pendulum.config.pendulum_dqn_config import main_config, create_config + def main(rl_cfg, seed=0): - main_cfg, create_cfg =rl_cfg + main_cfg, create_cfg = rl_cfg cfg = compile_config( main_cfg, BaseEnvManager, @@ -56,4 +57,4 @@ def main(rl_cfg, seed=0): if __name__ == "__main__": - main(rl_cfg=(main_config, create_config),seed=0) + main(rl_cfg=(main_config, create_config), seed=0) diff --git a/dizoo/cliffwalking/config/cliffwalking_dqn_config.py b/dizoo/cliffwalking/config/cliffwalking_dqn_config.py index 974275d6ac..c852858ab7 100644 --- a/dizoo/cliffwalking/config/cliffwalking_dqn_config.py +++ b/dizoo/cliffwalking/config/cliffwalking_dqn_config.py @@ -6,7 +6,7 @@ collector_env_num=8, evaluator_env_num=8, n_evaluator_episode=8, - stop_value=-13, # the optimal value of cliffwalking env + stop_value=-13, # the optimal value of cliffwalking env max_episode_steps=300, ), policy=dict( diff --git a/dizoo/cliffwalking/envs/cliffwalking_env.py b/dizoo/cliffwalking/envs/cliffwalking_env.py index 0d07866574..1bbe5958b4 100644 --- a/dizoo/cliffwalking/envs/cliffwalking_env.py +++ b/dizoo/cliffwalking/envs/cliffwalking_env.py @@ -17,7 +17,7 @@ def __init__(self, cfg: dict) -> None: self._cfg = EasyDict( env_id='CliffWalking', render_mode='rgb_array', - max_episode_steps=300, # default max trajectory length to truncate possible infinite attempts + max_episode_steps=300, # default max trajectory length to truncate possible infinite attempts ) self._cfg.update(cfg) self._init_flag = False diff --git a/dizoo/cliffwalking/envs/test_cliffwalking_env.py b/dizoo/cliffwalking/envs/test_cliffwalking_env.py index e9ead67b36..b378d1a1a8 100644 --- a/dizoo/cliffwalking/envs/test_cliffwalking_env.py +++ b/dizoo/cliffwalking/envs/test_cliffwalking_env.py @@ -2,6 +2,7 @@ import pytest from dizoo.cliffwalking.envs import CliffWalkingEnv + @pytest.mark.envtest class TestCliffWalkingEnv: diff --git a/dizoo/d4rl/config/hopper_medium_expert_bc_config.py b/dizoo/d4rl/config/hopper_medium_expert_bc_config.py index e04bd28069..348361dd2d 100644 --- a/dizoo/d4rl/config/hopper_medium_expert_bc_config.py +++ b/dizoo/d4rl/config/hopper_medium_expert_bc_config.py @@ -8,7 +8,7 @@ env=dict( env_id='hopper-medium-expert-v0', norm_obs=dict( - use_norm=True, + use_norm=True, offline_stats=dict(use_offline_stats=True, ), ), evaluator_env_num=8, @@ -38,7 +38,7 @@ data_type='d4rl', data_path=None, ), - eval=dict(evaluator=dict(eval_freq=-1,)), + eval=dict(evaluator=dict(eval_freq=-1, )), ), ) main_config = EasyDict(main_config) @@ -48,7 +48,7 @@ type='d4rl', import_names=['dizoo.d4rl.envs.d4rl_env'], ), - env_manager=dict(type='base',), + env_manager=dict(type='base', ), policy=dict( type='bc', import_names=['ding.policy.bc'], diff --git a/dizoo/d4rl/config/hopper_medium_expert_ibc_ar_config.py b/dizoo/d4rl/config/hopper_medium_expert_ibc_ar_config.py index 061b8b53a6..5d1090dc77 100644 --- a/dizoo/d4rl/config/hopper_medium_expert_ibc_ar_config.py +++ b/dizoo/d4rl/config/hopper_medium_expert_ibc_ar_config.py @@ -8,7 +8,7 @@ env=dict( env_id='hopper-medium-expert-v0', norm_obs=dict( - use_norm=True, + use_norm=True, 
offline_stats=dict(use_offline_stats=True, ), ), evaluator_env_num=8, @@ -18,23 +18,19 @@ ), policy=dict( cuda=cuda, - model=dict( - obs_shape=11, - action_shape=3, - stochastic_optim=dict(type='ardfo',) - ), + model=dict(obs_shape=11, action_shape=3, stochastic_optim=dict(type='ardfo', )), learn=dict( multi_gpu=multi_gpu, train_epoch=15, batch_size=256, - optim=dict(learning_rate=1e-5,), + optim=dict(learning_rate=1e-5, ), learner=dict(hook=dict(log_show_after_iter=1000)), ), collect=dict( data_type='d4rl', data_path=None, ), - eval=dict(evaluator=dict(eval_freq=-1,)), + eval=dict(evaluator=dict(eval_freq=-1, )), ), ) main_config = EasyDict(main_config) @@ -44,7 +40,7 @@ type='d4rl', import_names=['dizoo.d4rl.envs.d4rl_env'], ), - env_manager=dict(type='base',), + env_manager=dict(type='base', ), policy=dict( type='ibc', import_names=['ding.policy.ibc'], diff --git a/dizoo/d4rl/config/hopper_medium_expert_ibc_config.py b/dizoo/d4rl/config/hopper_medium_expert_ibc_config.py index e7a72984b6..0f040970e6 100644 --- a/dizoo/d4rl/config/hopper_medium_expert_ibc_config.py +++ b/dizoo/d4rl/config/hopper_medium_expert_ibc_config.py @@ -8,7 +8,7 @@ env=dict( env_id='hopper-medium-expert-v0', norm_obs=dict( - use_norm=True, + use_norm=True, offline_stats=dict(use_offline_stats=True, ), ), evaluator_env_num=8, @@ -18,23 +18,19 @@ ), policy=dict( cuda=cuda, - model=dict( - obs_shape=11, - action_shape=3, - stochastic_optim=dict(type='dfo',) - ), + model=dict(obs_shape=11, action_shape=3, stochastic_optim=dict(type='dfo', )), learn=dict( multi_gpu=multi_gpu, train_epoch=15, batch_size=256, - optim=dict(learning_rate=1e-5,), + optim=dict(learning_rate=1e-5, ), learner=dict(hook=dict(log_show_after_iter=1000)), ), collect=dict( data_type='d4rl', data_path=None, ), - eval=dict(evaluator=dict(eval_freq=-1,)), + eval=dict(evaluator=dict(eval_freq=-1, )), ), ) main_config = EasyDict(main_config) @@ -44,7 +40,7 @@ type='d4rl', import_names=['dizoo.d4rl.envs.d4rl_env'], ), - env_manager=dict(type='base',), + env_manager=dict(type='base', ), policy=dict( type='ibc', import_names=['ding.policy.ibc'], diff --git a/dizoo/d4rl/config/hopper_medium_expert_ibc_mcmc_config.py b/dizoo/d4rl/config/hopper_medium_expert_ibc_mcmc_config.py index e5f6f3dbb1..478e0c5d44 100644 --- a/dizoo/d4rl/config/hopper_medium_expert_ibc_mcmc_config.py +++ b/dizoo/d4rl/config/hopper_medium_expert_ibc_mcmc_config.py @@ -8,7 +8,7 @@ env=dict( env_id='hopper-medium-expert-v0', norm_obs=dict( - use_norm=True, + use_norm=True, offline_stats=dict(use_offline_stats=True, ), ), evaluator_env_num=8, @@ -18,23 +18,19 @@ ), policy=dict( cuda=cuda, - model=dict( - obs_shape=11, - action_shape=3, - stochastic_optim=dict(type='mcmc',) - ), + model=dict(obs_shape=11, action_shape=3, stochastic_optim=dict(type='mcmc', )), learn=dict( multi_gpu=multi_gpu, train_epoch=15, batch_size=256, - optim=dict(learning_rate=1e-5,), + optim=dict(learning_rate=1e-5, ), learner=dict(hook=dict(log_show_after_iter=1000)), ), collect=dict( data_type='d4rl', data_path=None, ), - eval=dict(evaluator=dict(eval_freq=-1,)), + eval=dict(evaluator=dict(eval_freq=-1, )), ), ) main_config = EasyDict(main_config) @@ -44,7 +40,7 @@ type='d4rl', import_names=['dizoo.d4rl.envs.d4rl_env'], ), - env_manager=dict(type='base',), + env_manager=dict(type='base', ), policy=dict( type='ibc', import_names=['ding.policy.ibc'], diff --git a/dizoo/d4rl/config/kitchen_complete_bc_config.py b/dizoo/d4rl/config/kitchen_complete_bc_config.py index 7160885da3..413696993d 100644 --- 
a/dizoo/d4rl/config/kitchen_complete_bc_config.py +++ b/dizoo/d4rl/config/kitchen_complete_bc_config.py @@ -8,7 +8,7 @@ env=dict( env_id='kitchen-complete-v0', norm_obs=dict( - use_norm=True, + use_norm=True, offline_stats=dict(use_offline_stats=True, ), ), evaluator_env_num=8, @@ -19,7 +19,7 @@ policy=dict( cuda=cuda, continuous=True, - loss_type='mse_loss', + loss_type='mse_loss', model=dict( obs_shape=60, action_shape=9, @@ -38,7 +38,7 @@ data_type='d4rl', data_path=None, ), - eval=dict(evaluator=dict(eval_freq=1000,)), + eval=dict(evaluator=dict(eval_freq=1000, )), ), ) main_config = EasyDict(main_config) @@ -48,7 +48,7 @@ type='d4rl', import_names=['dizoo.d4rl.envs.d4rl_env'], ), - env_manager=dict(type='base',), + env_manager=dict(type='base', ), policy=dict( type='bc', import_names=['ding.policy.bc'], diff --git a/dizoo/d4rl/config/kitchen_complete_ibc_ar_config.py b/dizoo/d4rl/config/kitchen_complete_ibc_ar_config.py index 403dc52eff..bbb7198af0 100644 --- a/dizoo/d4rl/config/kitchen_complete_ibc_ar_config.py +++ b/dizoo/d4rl/config/kitchen_complete_ibc_ar_config.py @@ -8,7 +8,7 @@ env=dict( env_id='kitchen-complete-v0', norm_obs=dict( - use_norm=True, + use_norm=True, offline_stats=dict(use_offline_stats=True, ), ), evaluator_env_num=8, @@ -18,23 +18,19 @@ ), policy=dict( cuda=cuda, - model=dict( - obs_shape=60, - action_shape=9, - stochastic_optim=dict(type='ardfo',) - ), + model=dict(obs_shape=60, action_shape=9, stochastic_optim=dict(type='ardfo', )), learn=dict( multi_gpu=multi_gpu, train_epoch=1000, batch_size=256, - optim=dict(learning_rate=1e-5,), + optim=dict(learning_rate=1e-5, ), learner=dict(hook=dict(log_show_after_iter=100)), ), collect=dict( data_type='d4rl', data_path=None, ), - eval=dict(evaluator=dict(eval_freq=1000,)), + eval=dict(evaluator=dict(eval_freq=1000, )), ), ) main_config = EasyDict(main_config) @@ -44,7 +40,7 @@ type='d4rl', import_names=['dizoo.d4rl.envs.d4rl_env'], ), - env_manager=dict(type='base',), + env_manager=dict(type='base', ), policy=dict( type='ibc', import_names=['ding.policy.ibc'], diff --git a/dizoo/d4rl/config/kitchen_complete_ibc_config.py b/dizoo/d4rl/config/kitchen_complete_ibc_config.py index 5c02f04a81..1606cb7792 100644 --- a/dizoo/d4rl/config/kitchen_complete_ibc_config.py +++ b/dizoo/d4rl/config/kitchen_complete_ibc_config.py @@ -8,7 +8,7 @@ env=dict( env_id='kitchen-complete-v0', norm_obs=dict( - use_norm=True, + use_norm=True, offline_stats=dict(use_offline_stats=True, ), ), evaluator_env_num=8, @@ -18,23 +18,19 @@ ), policy=dict( cuda=cuda, - model=dict( - obs_shape=60, - action_shape=9, - stochastic_optim=dict(type='dfo',) - ), + model=dict(obs_shape=60, action_shape=9, stochastic_optim=dict(type='dfo', )), learn=dict( multi_gpu=multi_gpu, train_epoch=1000, batch_size=256, - optim=dict(learning_rate=1e-5,), + optim=dict(learning_rate=1e-5, ), learner=dict(hook=dict(log_show_after_iter=100)), ), collect=dict( data_type='d4rl', data_path=None, ), - eval=dict(evaluator=dict(eval_freq=1000,)), + eval=dict(evaluator=dict(eval_freq=1000, )), ), ) main_config = EasyDict(main_config) @@ -44,7 +40,7 @@ type='d4rl', import_names=['dizoo.d4rl.envs.d4rl_env'], ), - env_manager=dict(type='base',), + env_manager=dict(type='base', ), policy=dict( type='ibc', import_names=['ding.policy.ibc'], diff --git a/dizoo/d4rl/config/kitchen_complete_ibc_mcmc_config.py b/dizoo/d4rl/config/kitchen_complete_ibc_mcmc_config.py index d93c5eb737..14924d5257 100644 --- a/dizoo/d4rl/config/kitchen_complete_ibc_mcmc_config.py +++ 
b/dizoo/d4rl/config/kitchen_complete_ibc_mcmc_config.py @@ -8,7 +8,7 @@ env=dict( env_id='kitchen-complete-v0', norm_obs=dict( - use_norm=True, + use_norm=True, offline_stats=dict(use_offline_stats=True, ), ), evaluator_env_num=8, @@ -18,23 +18,19 @@ ), policy=dict( cuda=cuda, - model=dict( - obs_shape=60, - action_shape=9, - stochastic_optim=dict(type='mcmc',) - ), + model=dict(obs_shape=60, action_shape=9, stochastic_optim=dict(type='mcmc', )), learn=dict( multi_gpu=multi_gpu, train_epoch=1000, batch_size=256, - optim=dict(learning_rate=1e-5,), + optim=dict(learning_rate=1e-5, ), learner=dict(hook=dict(log_show_after_iter=100)), ), collect=dict( data_type='d4rl', data_path=None, ), - eval=dict(evaluator=dict(eval_freq=1000,)), + eval=dict(evaluator=dict(eval_freq=1000, )), ), ) main_config = EasyDict(main_config) @@ -44,7 +40,7 @@ type='d4rl', import_names=['dizoo.d4rl.envs.d4rl_env'], ), - env_manager=dict(type='base',), + env_manager=dict(type='base', ), policy=dict( type='ibc', import_names=['ding.policy.ibc'], diff --git a/dizoo/d4rl/config/pen_human_bc_config.py b/dizoo/d4rl/config/pen_human_bc_config.py index 6779ffd934..215b706ffc 100644 --- a/dizoo/d4rl/config/pen_human_bc_config.py +++ b/dizoo/d4rl/config/pen_human_bc_config.py @@ -8,7 +8,7 @@ env=dict( env_id='pen-human-v0', norm_obs=dict( - use_norm=True, + use_norm=True, offline_stats=dict(use_offline_stats=True, ), ), evaluator_env_num=8, @@ -38,7 +38,7 @@ data_type='d4rl', data_path=None, ), - eval=dict(evaluator=dict(eval_freq=1000,)), + eval=dict(evaluator=dict(eval_freq=1000, )), ), ) main_config = EasyDict(main_config) @@ -48,7 +48,7 @@ type='d4rl', import_names=['dizoo.d4rl.envs.d4rl_env'], ), - env_manager=dict(type='base',), + env_manager=dict(type='base', ), policy=dict( type='bc', import_names=['ding.policy.bc'], diff --git a/dizoo/d4rl/config/pen_human_ibc_ar_config.py b/dizoo/d4rl/config/pen_human_ibc_ar_config.py index b75e3b9f11..4f59733fd5 100644 --- a/dizoo/d4rl/config/pen_human_ibc_ar_config.py +++ b/dizoo/d4rl/config/pen_human_ibc_ar_config.py @@ -8,7 +8,7 @@ env=dict( env_id='pen-human-v0', norm_obs=dict( - use_norm=True, + use_norm=True, offline_stats=dict(use_offline_stats=True, ), ), evaluator_env_num=8, @@ -19,24 +19,20 @@ policy=dict( cuda=cuda, model=dict( - obs_shape=45, - action_shape=24, - hidden_size=128, - hidden_layer_num=4, - stochastic_optim=dict(type='ardfo',) + obs_shape=45, action_shape=24, hidden_size=128, hidden_layer_num=4, stochastic_optim=dict(type='ardfo', ) ), learn=dict( multi_gpu=multi_gpu, train_epoch=1000, batch_size=256, - optim=dict(learning_rate=1e-5,), + optim=dict(learning_rate=1e-5, ), learner=dict(hook=dict(log_show_after_iter=100)), ), collect=dict( data_type='d4rl', data_path=None, ), - eval=dict(evaluator=dict(eval_freq=1000,)), + eval=dict(evaluator=dict(eval_freq=1000, )), ), ) main_config = EasyDict(main_config) @@ -46,7 +42,7 @@ type='d4rl', import_names=['dizoo.d4rl.envs.d4rl_env'], ), - env_manager=dict(type='base',), + env_manager=dict(type='base', ), policy=dict( type='ibc', import_names=['ding.policy.ibc'], diff --git a/dizoo/d4rl/config/pen_human_ibc_config.py b/dizoo/d4rl/config/pen_human_ibc_config.py index 207487d921..9ed4f6d17b 100644 --- a/dizoo/d4rl/config/pen_human_ibc_config.py +++ b/dizoo/d4rl/config/pen_human_ibc_config.py @@ -8,7 +8,7 @@ env=dict( env_id='pen-human-v0', norm_obs=dict( - use_norm=True, + use_norm=True, offline_stats=dict(use_offline_stats=True, ), ), evaluator_env_num=8, @@ -18,23 +18,19 @@ ), policy=dict( cuda=cuda, - model=dict( - 
obs_shape=45, - action_shape=24, - stochastic_optim=dict(type='dfo',) - ), + model=dict(obs_shape=45, action_shape=24, stochastic_optim=dict(type='dfo', )), learn=dict( multi_gpu=multi_gpu, train_epoch=1000, batch_size=256, - optim=dict(learning_rate=1e-5,), + optim=dict(learning_rate=1e-5, ), learner=dict(hook=dict(log_show_after_iter=100)), ), collect=dict( data_type='d4rl', data_path=None, ), - eval=dict(evaluator=dict(eval_freq=1000,)), + eval=dict(evaluator=dict(eval_freq=1000, )), ), ) main_config = EasyDict(main_config) @@ -44,7 +40,7 @@ type='d4rl', import_names=['dizoo.d4rl.envs.d4rl_env'], ), - env_manager=dict(type='base',), + env_manager=dict(type='base', ), policy=dict( type='ibc', import_names=['ding.policy.ibc'], diff --git a/dizoo/d4rl/config/pen_human_ibc_mcmc_config.py b/dizoo/d4rl/config/pen_human_ibc_mcmc_config.py index cee0f631fd..4dd6b37f90 100644 --- a/dizoo/d4rl/config/pen_human_ibc_mcmc_config.py +++ b/dizoo/d4rl/config/pen_human_ibc_mcmc_config.py @@ -8,7 +8,7 @@ env=dict( env_id='pen-human-v0', norm_obs=dict( - use_norm=True, + use_norm=True, offline_stats=dict(use_offline_stats=True, ), ), evaluator_env_num=8, @@ -18,23 +18,19 @@ ), policy=dict( cuda=cuda, - model=dict( - obs_shape=45, - action_shape=24, - stochastic_optim=dict(type='mcmc',) - ), + model=dict(obs_shape=45, action_shape=24, stochastic_optim=dict(type='mcmc', )), learn=dict( multi_gpu=multi_gpu, train_epoch=1000, batch_size=256, - optim=dict(learning_rate=1e-5,), + optim=dict(learning_rate=1e-5, ), learner=dict(hook=dict(log_show_after_iter=100)), ), collect=dict( data_type='d4rl', data_path=None, ), - eval=dict(evaluator=dict(eval_freq=1000,)), + eval=dict(evaluator=dict(eval_freq=1000, )), ), ) main_config = EasyDict(main_config) @@ -44,7 +40,7 @@ type='d4rl', import_names=['dizoo.d4rl.envs.d4rl_env'], ), - env_manager=dict(type='base',), + env_manager=dict(type='base', ), policy=dict( type='ibc', import_names=['ding.policy.ibc'], diff --git a/dizoo/d4rl/entry/d4rl_cql_main.py b/dizoo/d4rl/entry/d4rl_cql_main.py index 9315a3644d..7a8934a90a 100644 --- a/dizoo/d4rl/entry/d4rl_cql_main.py +++ b/dizoo/d4rl/entry/d4rl_cql_main.py @@ -5,7 +5,7 @@ def train(args): # launch from anywhere - config = Path(__file__).absolute().parent.parent / 'config' / args.config + config = Path(__file__).absolute().parent.parent / 'config' / args.config config = read_config(str(config)) config[0].exp_name = config[0].exp_name.replace('0', str(args.seed)) serial_pipeline_offline(config, seed=args.seed) diff --git a/dizoo/d4rl/entry/d4rl_dt_mujoco.py b/dizoo/d4rl/entry/d4rl_dt_mujoco.py index 9f176b353b..937e8987bf 100644 --- a/dizoo/d4rl/entry/d4rl_dt_mujoco.py +++ b/dizoo/d4rl/entry/d4rl_dt_mujoco.py @@ -24,7 +24,8 @@ def main(): # ding_init(cfg) with task.start(async_mode=False, ctx=OfflineRLContext()): evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: AllinObsWrapper(D4RLEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + env_fn=[lambda: AllinObsWrapper(D4RLEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)], + cfg=cfg.env.manager ) set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) @@ -32,7 +33,8 @@ def main(): dataset = create_dataset(cfg) # env_data_stats = dataset.get_d4rl_dataset_stats(cfg.policy.dataset_name) env_data_stats = dataset.get_state_stats() - cfg.policy.state_mean, cfg.policy.state_std = np.array(env_data_stats['state_mean']), np.array(env_data_stats['state_std']) + cfg.policy.state_mean, cfg.policy.state_std = np.array(env_data_stats['state_mean'] + ), 
np.array(env_data_stats['state_std']) model = DecisionTransformer(**cfg.policy.model) policy = DTPolicy(cfg.policy, model=model) task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) @@ -46,4 +48,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/dizoo/d4rl/entry/d4rl_td3_bc_main.py b/dizoo/d4rl/entry/d4rl_td3_bc_main.py index bdf945978f..b25bf904a5 100644 --- a/dizoo/d4rl/entry/d4rl_td3_bc_main.py +++ b/dizoo/d4rl/entry/d4rl_td3_bc_main.py @@ -5,7 +5,7 @@ def train(args): # launch from anywhere - config = Path(__file__).absolute().parent.parent / 'config' / args.config + config = Path(__file__).absolute().parent.parent / 'config' / args.config config = read_config(str(config)) config[0].exp_name = config[0].exp_name.replace('0', str(args.seed)) serial_pipeline_offline(config, seed=args.seed) diff --git a/dizoo/dmc2gym/config/dmc2gym_ppo_config.py b/dizoo/dmc2gym/config/dmc2gym_ppo_config.py index 4f48633c5f..207b398e63 100644 --- a/dizoo/dmc2gym/config/dmc2gym_ppo_config.py +++ b/dizoo/dmc2gym/config/dmc2gym_ppo_config.py @@ -1,6 +1,5 @@ from easydict import EasyDict - cartpole_balance_ppo_config = dict( exp_name='dmc2gym_cartpole_balance_ppo', env=dict( diff --git a/dizoo/dmc2gym/entry/dmc2gym_sac_pixel_main.py b/dizoo/dmc2gym/entry/dmc2gym_sac_pixel_main.py index 60a83921ef..1f6eb2abb5 100644 --- a/dizoo/dmc2gym/entry/dmc2gym_sac_pixel_main.py +++ b/dizoo/dmc2gym/entry/dmc2gym_sac_pixel_main.py @@ -15,6 +15,7 @@ from dizoo.dmc2gym.envs.dmc2gym_env import DMC2GymEnv from dizoo.dmc2gym.config.dmc2gym_sac_pixel_config import main_config, create_config + def main(): logging.getLogger().setLevel(logging.INFO) main_config.exp_name = 'dmc2gym_sac_pixel_seed0' @@ -23,8 +24,8 @@ def main(): num_seed = 1 for seed_i in range(num_seed): - tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'seed'+str(seed_i))) - + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'seed' + str(seed_i))) + with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( env_fn=[lambda: DMC2GymEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager @@ -42,16 +43,20 @@ def main(): def _add_scalar(ctx): if ctx.eval_value != -np.inf: - tb_logger.add_scalar('evaluator_step/reward', ctx.eval_value, global_step= ctx.env_step) + tb_logger.add_scalar('evaluator_step/reward', ctx.eval_value, global_step=ctx.env_step) collector_rewards = [ctx.trajectories[i]['reward'] for i in range(len(ctx.trajectories))] collector_mean_reward = sum(collector_rewards) / len(ctx.trajectories) # collector_max_reward = max(collector_rewards) # collector_min_reward = min(collector_rewards) - tb_logger.add_scalar('collecter_step/mean_reward', collector_mean_reward, global_step= ctx.env_step) + tb_logger.add_scalar('collecter_step/mean_reward', collector_mean_reward, global_step=ctx.env_step) # tb_logger.add_scalar('collecter_step/max_reward', collector_max_reward, global_step= ctx.env_step) # tb_logger.add_scalar('collecter_step/min_reward', collector_min_reward, global_step= ctx.env_step) - tb_logger.add_scalar('collecter_step/avg_env_step_per_episode', ctx.env_step/ctx.env_episode, global_step= ctx.env_step) - + tb_logger.add_scalar( + 'collecter_step/avg_env_step_per_episode', + ctx.env_step / ctx.env_episode, + global_step=ctx.env_step + ) + def _add_train_scalar(ctx): len_train = len(ctx.train_output) cur_lr_q_avg = sum([ctx.train_output[i]['cur_lr_q'] for i in range(len_train)]) / len_train @@ -59,15 +64,17 @@ def 
_add_train_scalar(ctx): critic_loss_avg = sum([ctx.train_output[i]['critic_loss'] for i in range(len_train)]) / len_train policy_loss_avg = sum([ctx.train_output[i]['policy_loss'] for i in range(len_train)]) / len_train total_loss_avg = sum([ctx.train_output[i]['total_loss'] for i in range(len_train)]) / len_train - tb_logger.add_scalar('learner_step/cur_lr_q_avg', cur_lr_q_avg, global_step= ctx.env_step) - tb_logger.add_scalar('learner_step/cur_lr_p_avg', cur_lr_p_avg, global_step= ctx.env_step) - tb_logger.add_scalar('learner_step/critic_loss_avg', critic_loss_avg, global_step= ctx.env_step) - tb_logger.add_scalar('learner_step/policy_loss_avg', policy_loss_avg, global_step= ctx.env_step) - tb_logger.add_scalar('learner_step/total_loss_avg', total_loss_avg, global_step= ctx.env_step) - + tb_logger.add_scalar('learner_step/cur_lr_q_avg', cur_lr_q_avg, global_step=ctx.env_step) + tb_logger.add_scalar('learner_step/cur_lr_p_avg', cur_lr_p_avg, global_step=ctx.env_step) + tb_logger.add_scalar('learner_step/critic_loss_avg', critic_loss_avg, global_step=ctx.env_step) + tb_logger.add_scalar('learner_step/policy_loss_avg', policy_loss_avg, global_step=ctx.env_step) + tb_logger.add_scalar('learner_step/total_loss_avg', total_loss_avg, global_step=ctx.env_step) + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) task.use( - StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size) + StepCollector( + cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size + ) ) task.use(_add_scalar) task.use(data_pusher(cfg, buffer_)) diff --git a/dizoo/dmc2gym/entry/dmc2gym_sac_state_main.py b/dizoo/dmc2gym/entry/dmc2gym_sac_state_main.py index 6bc7036352..7e6cf920f5 100644 --- a/dizoo/dmc2gym/entry/dmc2gym_sac_state_main.py +++ b/dizoo/dmc2gym/entry/dmc2gym_sac_state_main.py @@ -15,6 +15,7 @@ from tensorboardX import SummaryWriter import os + def main(): logging.getLogger().setLevel(logging.INFO) main_config.exp_name = 'dmc2gym_sac_state_nseed_5M' @@ -23,8 +24,8 @@ def main(): num_seed = 4 for seed_i in range(num_seed): - tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'seed'+str(seed_i))) - + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'seed' + str(seed_i))) + with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( env_fn=[lambda: DMC2GymEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager @@ -41,16 +42,20 @@ def main(): def _add_scalar(ctx): if ctx.eval_value != -np.inf: - tb_logger.add_scalar('evaluator_step/reward', ctx.eval_value, global_step= ctx.env_step) + tb_logger.add_scalar('evaluator_step/reward', ctx.eval_value, global_step=ctx.env_step) collector_rewards = [ctx.trajectories[i]['reward'] for i in range(len(ctx.trajectories))] collector_mean_reward = sum(collector_rewards) / len(ctx.trajectories) # collector_max_reward = max(collector_rewards) # collector_min_reward = min(collector_rewards) - tb_logger.add_scalar('collecter_step/mean_reward', collector_mean_reward, global_step= ctx.env_step) + tb_logger.add_scalar('collecter_step/mean_reward', collector_mean_reward, global_step=ctx.env_step) # tb_logger.add_scalar('collecter_step/max_reward', collector_max_reward, global_step= ctx.env_step) # tb_logger.add_scalar('collecter_step/min_reward', collector_min_reward, global_step= ctx.env_step) - tb_logger.add_scalar('collecter_step/avg_env_step_per_episode', ctx.env_step/ctx.env_episode, 
global_step= ctx.env_step) - + tb_logger.add_scalar( + 'collecter_step/avg_env_step_per_episode', + ctx.env_step / ctx.env_episode, + global_step=ctx.env_step + ) + def _add_train_scalar(ctx): len_train = len(ctx.train_output) cur_lr_q_avg = sum([ctx.train_output[i]['cur_lr_q'] for i in range(len_train)]) / len_train @@ -58,15 +63,17 @@ def _add_train_scalar(ctx): critic_loss_avg = sum([ctx.train_output[i]['critic_loss'] for i in range(len_train)]) / len_train policy_loss_avg = sum([ctx.train_output[i]['policy_loss'] for i in range(len_train)]) / len_train total_loss_avg = sum([ctx.train_output[i]['total_loss'] for i in range(len_train)]) / len_train - tb_logger.add_scalar('learner_step/cur_lr_q_avg', cur_lr_q_avg, global_step= ctx.env_step) - tb_logger.add_scalar('learner_step/cur_lr_p_avg', cur_lr_p_avg, global_step= ctx.env_step) - tb_logger.add_scalar('learner_step/critic_loss_avg', critic_loss_avg, global_step= ctx.env_step) - tb_logger.add_scalar('learner_step/policy_loss_avg', policy_loss_avg, global_step= ctx.env_step) - tb_logger.add_scalar('learner_step/total_loss_avg', total_loss_avg, global_step= ctx.env_step) - + tb_logger.add_scalar('learner_step/cur_lr_q_avg', cur_lr_q_avg, global_step=ctx.env_step) + tb_logger.add_scalar('learner_step/cur_lr_p_avg', cur_lr_p_avg, global_step=ctx.env_step) + tb_logger.add_scalar('learner_step/critic_loss_avg', critic_loss_avg, global_step=ctx.env_step) + tb_logger.add_scalar('learner_step/policy_loss_avg', policy_loss_avg, global_step=ctx.env_step) + tb_logger.add_scalar('learner_step/total_loss_avg', total_loss_avg, global_step=ctx.env_step) + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) task.use( - StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size) + StepCollector( + cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size + ) ) task.use(_add_scalar) task.use(data_pusher(cfg, buffer_)) diff --git a/dizoo/dmc2gym/envs/dmc2gym_env.py b/dizoo/dmc2gym/envs/dmc2gym_env.py index 9e97629897..14c70b6f44 100644 --- a/dizoo/dmc2gym/envs/dmc2gym_env.py +++ b/dizoo/dmc2gym/envs/dmc2gym_env.py @@ -10,6 +10,7 @@ def dmc2gym_observation_space(dim, minimum=-np.inf, maximum=np.inf, dtype=np.float32) -> Callable: + def observation_space(from_pixels=True, height=84, width=84, channels_first=True) -> Box: if from_pixels: shape = [3, height, width] if channels_first else [height, width, 3] @@ -29,6 +30,7 @@ def dmc2gym_action_space(dim, minimum=-1, maximum=1, dtype=np.float32) -> Box: def dmc2gym_reward_space(minimum=0, maximum=1, dtype=np.float32) -> Callable: + def reward_space(frame_skip=1) -> Box: return Box( np.repeat(minimum * frame_skip, 1).astype(dtype), diff --git a/dizoo/dmc2gym/envs/test_dmc2gym_env.py b/dizoo/dmc2gym/envs/test_dmc2gym_env.py index 94e6d9e9a7..5245a7a86a 100644 --- a/dizoo/dmc2gym/envs/test_dmc2gym_env.py +++ b/dizoo/dmc2gym/envs/test_dmc2gym_env.py @@ -47,4 +47,3 @@ def test_naive(self): assert timestep.reward <= env.reward_space.high print(env.observation_space, env.action_space, env.reward_space) env.close() - diff --git a/dizoo/evogym/envs/test/visualize_simple_env.py b/dizoo/evogym/envs/test/visualize_simple_env.py index cde80b725c..2203209fbe 100644 --- a/dizoo/evogym/envs/test/visualize_simple_env.py +++ b/dizoo/evogym/envs/test/visualize_simple_env.py @@ -7,7 +7,6 @@ from dizoo.evogym.envs.viewer import DingEvoViewer from evogym.sim import EvoSim - if __name__ == '__main__': 
gym.logger.set_level(gym.logger.DEBUG) # create a random robot diff --git a/dizoo/gym_anytrading/config/stocks_dqn_config.py b/dizoo/gym_anytrading/config/stocks_dqn_config.py index c16ab0a5a5..c05a1f5974 100644 --- a/dizoo/gym_anytrading/config/stocks_dqn_config.py +++ b/dizoo/gym_anytrading/config/stocks_dqn_config.py @@ -78,13 +78,11 @@ import_names=['dizoo.gym_anytrading.envs.stocks_env'], ), env_manager=dict(type='base'), - policy=dict( - type='dqn', - ), + policy=dict(type='dqn', ), evaluator=dict( type='trading_interaction', import_names=['dizoo.gym_anytrading.worker'], - ), + ), ) stocks_dqn_create_config = EasyDict(stocks_dqn_create_config) create_config = stocks_dqn_create_config diff --git a/dizoo/gym_anytrading/worker/trading_serial_evaluator.py b/dizoo/gym_anytrading/worker/trading_serial_evaluator.py index 9c7749f722..d2fa4d22d1 100644 --- a/dizoo/gym_anytrading/worker/trading_serial_evaluator.py +++ b/dizoo/gym_anytrading/worker/trading_serial_evaluator.py @@ -32,13 +32,13 @@ class TradingSerialEvaluator(InteractionSerialEvaluator): ) def __init__( - self, - cfg: dict, - env: BaseEnvManager = None, - policy: namedtuple = None, - tb_logger: 'SummaryWriter' = None, # noqa - exp_name: Optional[str] = 'default_experiment', - instance_name: Optional[str] = 'evaluator', + self, + cfg: dict, + env: BaseEnvManager = None, + policy: namedtuple = None, + tb_logger: 'SummaryWriter' = None, # noqa + exp_name: Optional[str] = 'default_experiment', + instance_name: Optional[str] = 'evaluator', ) -> None: """ Overview: @@ -49,12 +49,12 @@ def __init__( super().__init__(cfg, env, policy, tb_logger, exp_name, instance_name) def eval( - self, - save_ckpt_fn: Callable = None, - train_iter: int = -1, - envstep: int = -1, - n_episode: Optional[int] = None, - force_render: bool = False, + self, + save_ckpt_fn: Callable = None, + train_iter: int = -1, + envstep: int = -1, + n_episode: Optional[int] = None, + force_render: bool = False, ) -> Tuple[bool, dict]: ''' Overview: diff --git a/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/__init__.py b/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/__init__.py index aa9f5bdf37..89cb5d7764 100644 --- a/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/__init__.py +++ b/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/__init__.py @@ -3,7 +3,6 @@ from gym_hybrid.environments import SlidingEnv from gym_hybrid.environments import HardMoveEnv - register( id='Moving-v0', entry_point='gym_hybrid:MovingEnv', @@ -15,4 +14,4 @@ register( id='HardMove-v0', entry_point='gym_hybrid:HardMoveEnv', -) \ No newline at end of file +) diff --git a/dizoo/gym_hybrid/envs/gym-hybrid/setup.py b/dizoo/gym_hybrid/envs/gym-hybrid/setup.py index af82deb670..248ccb4535 100644 --- a/dizoo/gym_hybrid/envs/gym-hybrid/setup.py +++ b/dizoo/gym_hybrid/envs/gym-hybrid/setup.py @@ -1,7 +1,8 @@ from setuptools import setup -setup(name='gym_hybrid', - version='0.0.2', # original gym_hybrid version='0.0.1' - packages=['gym_hybrid'], - install_requires=['gym', 'numpy'], +setup( + name='gym_hybrid', + version='0.0.2', # original gym_hybrid version='0.0.1' + packages=['gym_hybrid'], + install_requires=['gym', 'numpy'], ) diff --git a/dizoo/gym_hybrid/envs/gym-hybrid/tests/moving.py b/dizoo/gym_hybrid/envs/gym-hybrid/tests/moving.py index dbc230c0d7..52315decd9 100644 --- a/dizoo/gym_hybrid/envs/gym-hybrid/tests/moving.py +++ b/dizoo/gym_hybrid/envs/gym-hybrid/tests/moving.py @@ -2,7 +2,6 @@ import gym import gym_hybrid - if __name__ == '__main__': env = gym.make('Moving-v0') env.reset() diff --git 
a/dizoo/gym_hybrid/envs/test_gym_hybrid_env.py b/dizoo/gym_hybrid/envs/test_gym_hybrid_env.py index 7a7bc10006..896987f33f 100644 --- a/dizoo/gym_hybrid/envs/test_gym_hybrid_env.py +++ b/dizoo/gym_hybrid/envs/test_gym_hybrid_env.py @@ -8,7 +8,17 @@ class TestGymHybridEnv: def test_naive(self): - env = GymHybridEnv(EasyDict({'env_id': 'Moving-v0', 'act_scale': False, 'save_replay_gif': False, 'replay_path_gif': None, 'replay_path': None})) + env = GymHybridEnv( + EasyDict( + { + 'env_id': 'Moving-v0', + 'act_scale': False, + 'save_replay_gif': False, + 'replay_path_gif': None, + 'replay_path': None + } + ) + ) env.enable_save_replay('./video') env.seed(314, dynamic_seed=False) assert env._seed == 314 diff --git a/dizoo/image_classification/entry/imagenet_res18_config.py b/dizoo/image_classification/entry/imagenet_res18_config.py index 970ea4f2fd..bd4f473dd6 100644 --- a/dizoo/image_classification/entry/imagenet_res18_config.py +++ b/dizoo/image_classification/entry/imagenet_res18_config.py @@ -27,9 +27,7 @@ learn_data_path='/mnt/lustre/share/images/train', eval_data_path='/mnt/lustre/share/images/val', ), - eval=dict( - batch_size=32, evaluator=dict(eval_freq=1, stop_value=dict(loss=0.5, acc1=75.0, acc5=95.0)) - ), + eval=dict(batch_size=32, evaluator=dict(eval_freq=1, stop_value=dict(loss=0.5, acc1=75.0, acc5=95.0))), ), env=dict(), ) diff --git a/dizoo/league_demo/league_demo_collector.py b/dizoo/league_demo/league_demo_collector.py index 211e15b5e8..ce7985a6dc 100644 --- a/dizoo/league_demo/league_demo_collector.py +++ b/dizoo/league_demo/league_demo_collector.py @@ -25,13 +25,13 @@ class LeagueDemoCollector(ISerialCollector): config = dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100, get_train_sample=False) def __init__( - self, - cfg: EasyDict, - env: BaseEnvManager = None, - policy: List[namedtuple] = None, - tb_logger: 'SummaryWriter' = None, # noqa - exp_name: Optional[str] = 'default_experiment', - instance_name: Optional[str] = 'collector' + self, + cfg: EasyDict, + env: BaseEnvManager = None, + policy: List[namedtuple] = None, + tb_logger: 'SummaryWriter' = None, # noqa + exp_name: Optional[str] = 'default_experiment', + instance_name: Optional[str] = 'collector' ) -> None: """ Overview: diff --git a/dizoo/maze/entry/maze_bc_main.py b/dizoo/maze/entry/maze_bc_main.py index efd9b6d2a8..3a42d4e921 100644 --- a/dizoo/maze/entry/maze_bc_main.py +++ b/dizoo/maze/entry/maze_bc_main.py @@ -61,9 +61,7 @@ def get_vi_sequence(env, observation): cur_x, cur_y = start_x, start_y while cur_x != target_location[0] or cur_y != target_location[1]: act = vi_sequence[-1][cur_x, cur_y] - track_back.append(( - torch.FloatTensor(env.process_states([cur_x, cur_y], env.get_maze_map())), - act)) + track_back.append((torch.FloatTensor(env.process_states([cur_x, cur_y], env.get_maze_map())), act)) if act == 0: cur_x += 1 elif act == 1: @@ -89,6 +87,7 @@ def __len__(self): def load_bc_dataset(train_seeds=1, test_seeds=1, batch_size=32): + def load_env(seed): ccc = easydict.EasyDict({'size': 16}) e = Maze(ccc) @@ -111,13 +110,8 @@ def load_env(seed): data += track_back - - train_data = BCDataset( - data_train - ) - test_data = BCDataset( - data_test - ) + train_data = BCDataset(data_train) + test_data = BCDataset(data_test) train_dataset = DataLoader(train_data, batch_size=batch_size, shuffle=True) test_dataset = DataLoader(test_data, batch_size=batch_size, shuffle=True) diff --git a/dizoo/minigrid/utils/eval.py b/dizoo/minigrid/utils/eval.py index e8e4f728fa..e3c6acb9fb 100644 --- 
a/dizoo/minigrid/utils/eval.py +++ b/dizoo/minigrid/utils/eval.py @@ -8,11 +8,11 @@ def eval( - input_cfg: Union[str, Tuple[dict, dict]], - seed: int = 0, - model: Optional[torch.nn.Module] = None, - state_dict: Optional[dict] = None, - replay_path: Optional[str] = './video', + input_cfg: Union[str, Tuple[dict, dict]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + state_dict: Optional[dict] = None, + replay_path: Optional[str] = './video', ) -> float: r""" Overview: diff --git a/dizoo/mujoco/config/halfcheetah_bdq_config.py b/dizoo/mujoco/config/halfcheetah_bdq_config.py index 145bf8062e..25fb65ba35 100644 --- a/dizoo/mujoco/config/halfcheetah_bdq_config.py +++ b/dizoo/mujoco/config/halfcheetah_bdq_config.py @@ -22,7 +22,6 @@ action_bins_per_branch=2, # mean the action shape is 6, 2 discrete actions for each action dimension encoder_hidden_size_list=[256, 256, 128], ), - learn=dict( batch_size=512, learning_rate=3e-4, @@ -65,4 +64,8 @@ if __name__ == "__main__": # or you can enter `ding -m serial_onpolicy -c halfcheetah_onbdq_config.py -s 0` from ding.entry import serial_pipeline - serial_pipeline((main_config, create_config), seed=0, max_env_step=10000000,) \ No newline at end of file + serial_pipeline( + (main_config, create_config), + seed=0, + max_env_step=10000000, + ) diff --git a/dizoo/mujoco/config/hopper_bdq_config.py b/dizoo/mujoco/config/hopper_bdq_config.py index de08da2a7a..34dbe21664 100644 --- a/dizoo/mujoco/config/hopper_bdq_config.py +++ b/dizoo/mujoco/config/hopper_bdq_config.py @@ -68,4 +68,8 @@ if __name__ == "__main__": # or you can enter `ding -m serial_onpolicy -c hopper_bdq_config.py -s 0` from ding.entry import serial_pipeline - serial_pipeline([main_config, create_config], seed=0, max_env_step=10000000,) + serial_pipeline( + [main_config, create_config], + seed=0, + max_env_step=10000000, + ) diff --git a/dizoo/mujoco/envs/mujoco_wrappers.py b/dizoo/mujoco/envs/mujoco_wrappers.py index 8fc19cd503..d99819783c 100644 --- a/dizoo/mujoco/envs/mujoco_wrappers.py +++ b/dizoo/mujoco/envs/mujoco_wrappers.py @@ -6,10 +6,10 @@ def wrap_mujoco( - env_id, - norm_obs: Dict = dict(use_norm=False, ), - norm_reward: Dict = dict(use_norm=False, ), - delay_reward_step: int = 1 + env_id, + norm_obs: Dict = dict(use_norm=False, ), + norm_reward: Dict = dict(use_norm=False, ), + delay_reward_step: int = 1 ) -> gym.Env: r""" Overview: diff --git a/dizoo/multiagent_mujoco/config/ant_mappo_config.py b/dizoo/multiagent_mujoco/config/ant_mappo_config.py index f221fa7c0f..d11c31be8d 100644 --- a/dizoo/multiagent_mujoco/config/ant_mappo_config.py +++ b/dizoo/multiagent_mujoco/config/ant_mappo_config.py @@ -75,7 +75,6 @@ ) create_config = EasyDict(create_config) - if __name__ == '__main__': from ding.entry import serial_pipeline_onpolicy serial_pipeline_onpolicy((main_config, create_config), seed=0, max_env_step=int(1e7)) diff --git a/dizoo/multiagent_mujoco/config/ant_masac_config.py b/dizoo/multiagent_mujoco/config/ant_masac_config.py index 1f04efe8b7..9316b095c0 100644 --- a/dizoo/multiagent_mujoco/config/ant_masac_config.py +++ b/dizoo/multiagent_mujoco/config/ant_masac_config.py @@ -34,9 +34,7 @@ target_theta=0.005, discount_factor=0.99, ), - collect=dict( - n_sample=400, - ), + collect=dict(n_sample=400, ), eval=dict(evaluator=dict(eval_freq=500, )), other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ), ), diff --git a/dizoo/petting_zoo/config/ptz_simple_spread_madqn_config.py b/dizoo/petting_zoo/config/ptz_simple_spread_madqn_config.py index 
b7db69abbe..8ddb636abf 100644 --- a/dizoo/petting_zoo/config/ptz_simple_spread_madqn_config.py +++ b/dizoo/petting_zoo/config/ptz_simple_spread_madqn_config.py @@ -41,9 +41,7 @@ discount_factor=0.95, ), collect=dict( - collector=dict( - get_train_sample=True, - ), + collector=dict(get_train_sample=True, ), n_episode=32, unroll_len=10, env_num=collector_env_num, @@ -60,9 +58,7 @@ end=0.05, decay=10000, ), - replay_buffer=dict( - replay_buffer_size=15000, - ), + replay_buffer=dict(replay_buffer_size=15000, ), ), ), ) diff --git a/dizoo/rocket/entry/rocket_hover_ppo_main.py b/dizoo/rocket/entry/rocket_hover_ppo_main.py index 2539ff12d3..13f5714483 100644 --- a/dizoo/rocket/entry/rocket_hover_ppo_main.py +++ b/dizoo/rocket/entry/rocket_hover_ppo_main.py @@ -30,12 +30,10 @@ def main(): tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'seed' + str(seed_i))) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: RocketEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], - cfg=cfg.env.manager + env_fn=[lambda: RocketEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: RocketEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], - cfg=cfg.env.manager + env_fn=[lambda: RocketEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) # evaluator_env.enable_save_replay() diff --git a/dizoo/rocket/entry/rocket_landing_ppo_main.py b/dizoo/rocket/entry/rocket_landing_ppo_main.py index cc83242ce5..bf8ebb5162 100644 --- a/dizoo/rocket/entry/rocket_landing_ppo_main.py +++ b/dizoo/rocket/entry/rocket_landing_ppo_main.py @@ -27,15 +27,13 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) num_seed = 4 for seed_i in range(num_seed): - tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'seed'+str(seed_i))) + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'seed' + str(seed_i))) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: RocketEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], - cfg=cfg.env.manager + env_fn=[lambda: RocketEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: RocketEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], - cfg=cfg.env.manager + env_fn=[lambda: RocketEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) # evaluator_env.enable_save_replay() diff --git a/dizoo/rocket/envs/test_rocket_env.py b/dizoo/rocket/envs/test_rocket_env.py index e19d2879c1..a8bf030fe7 100644 --- a/dizoo/rocket/envs/test_rocket_env.py +++ b/dizoo/rocket/envs/test_rocket_env.py @@ -12,7 +12,7 @@ def test_hover(self): env.seed(314, dynamic_seed=False) assert env._seed == 314 obs = env.reset() - assert obs.shape == (8,) + assert obs.shape == (8, ) for _ in range(5): env.reset() np.random.seed(314) @@ -28,8 +28,8 @@ def test_hover(self): print('timestep', timestep, '\n') assert isinstance(timestep.obs, np.ndarray) assert isinstance(timestep.done, bool) - assert timestep.obs.shape == (8,) - assert timestep.reward.shape == (1,) + assert timestep.obs.shape == (8, ) + assert timestep.reward.shape == (1, ) assert timestep.reward >= env.reward_space.low assert timestep.reward <= env.reward_space.high print(env.observation_space, env.action_space, env.reward_space) diff --git 
a/dizoo/smac/config/smac_3s5z_madqn_config.py b/dizoo/smac/config/smac_3s5z_madqn_config.py index c15dfcd655..5e771baf09 100644 --- a/dizoo/smac/config/smac_3s5z_madqn_config.py +++ b/dizoo/smac/config/smac_3s5z_madqn_config.py @@ -18,9 +18,7 @@ stop_value=0.999, n_evaluator_episode=32, special_global_state=True, - manager=dict( - shared_memory=False, - ), + manager=dict(shared_memory=False, ), ), policy=dict( nstep=1, @@ -41,9 +39,7 @@ discount_factor=0.95, ), collect=dict( - collector=dict( - get_train_sample=True, - ), + collector=dict(get_train_sample=True, ), n_episode=32, unroll_len=10, env_num=collector_env_num, @@ -56,9 +52,7 @@ end=0.05, decay=10000, ), - replay_buffer=dict( - replay_buffer_size=15000, - ), + replay_buffer=dict(replay_buffer_size=15000, ), ), ), ) diff --git a/dizoo/smac/config/smac_3s5zvs3s6z_madqn_config.py b/dizoo/smac/config/smac_3s5zvs3s6z_madqn_config.py index 23c215b63c..438025241f 100644 --- a/dizoo/smac/config/smac_3s5zvs3s6z_madqn_config.py +++ b/dizoo/smac/config/smac_3s5zvs3s6z_madqn_config.py @@ -18,9 +18,7 @@ stop_value=0.999, n_evaluator_episode=32, special_global_state=True, - manager=dict( - shared_memory=False, - ), + manager=dict(shared_memory=False, ), ), policy=dict( nstep=3, @@ -41,9 +39,7 @@ discount_factor=0.95, ), collect=dict( - collector=dict( - get_train_sample=True, - ), + collector=dict(get_train_sample=True, ), n_episode=32, unroll_len=10, env_num=collector_env_num, @@ -56,9 +52,7 @@ end=0.05, decay=100000, ), - replay_buffer=dict( - replay_buffer_size=30000, - ), + replay_buffer=dict(replay_buffer_size=30000, ), ), ), ) diff --git a/dizoo/smac/config/smac_5m6m_madqn_config.py b/dizoo/smac/config/smac_5m6m_madqn_config.py index 0aa0497712..d05bb23dcb 100644 --- a/dizoo/smac/config/smac_5m6m_madqn_config.py +++ b/dizoo/smac/config/smac_5m6m_madqn_config.py @@ -27,7 +27,7 @@ obs_shape=72, global_obs_shape=152, action_shape=12, - hidden_size_list=[256,256], + hidden_size_list=[256, 256], ), learn=dict( update_per_collect=40, @@ -38,9 +38,7 @@ discount_factor=0.95, ), collect=dict( - collector=dict( - get_train_sample=True, - ), + collector=dict(get_train_sample=True, ), n_episode=32, unroll_len=10, env_num=collector_env_num, @@ -53,9 +51,7 @@ end=0.05, decay=50000, ), - replay_buffer=dict( - replay_buffer_size=50000, - ), + replay_buffer=dict(replay_buffer_size=50000, ), ), ), ) @@ -87,7 +83,6 @@ def train(args): train(args) - def train(args): config = [main_config, create_config] serial_pipeline(config, seed=args.seed, max_env_step=1e7) diff --git a/dizoo/smac/config/smac_8m9m_madqn_config.py b/dizoo/smac/config/smac_8m9m_madqn_config.py index ccf9153a14..672330df24 100644 --- a/dizoo/smac/config/smac_8m9m_madqn_config.py +++ b/dizoo/smac/config/smac_8m9m_madqn_config.py @@ -27,7 +27,7 @@ obs_shape=108, global_obs_shape=263, action_shape=15, - hidden_size_list=[256,256], + hidden_size_list=[256, 256], ), learn=dict( update_per_collect=40, @@ -38,9 +38,7 @@ discount_factor=0.95, ), collect=dict( - collector=dict( - get_train_sample=True, - ), + collector=dict(get_train_sample=True, ), n_episode=32, unroll_len=20, env_num=collector_env_num, @@ -53,9 +51,7 @@ end=0.05, decay=50000, ), - replay_buffer=dict( - replay_buffer_size=20000, - ), + replay_buffer=dict(replay_buffer_size=20000, ), ), ), ) @@ -87,7 +83,6 @@ def train(args): train(args) - def train(args): config = [main_config, create_config] serial_pipeline(config, seed=args.seed, max_env_step=1e7) diff --git a/dizoo/smac/config/smac_MMM2_madqn_config.py 
b/dizoo/smac/config/smac_MMM2_madqn_config.py index 60e3123dc4..fe8e96501c 100644 --- a/dizoo/smac/config/smac_MMM2_madqn_config.py +++ b/dizoo/smac/config/smac_MMM2_madqn_config.py @@ -18,9 +18,7 @@ stop_value=0.999, n_evaluator_episode=32, special_global_state=True, - manager=dict( - shared_memory=False, - ), + manager=dict(shared_memory=False, ), ), policy=dict( nstep=1, @@ -41,9 +39,7 @@ discount_factor=0.95, ), collect=dict( - collector=dict( - get_train_sample=True, - ), + collector=dict(get_train_sample=True, ), n_episode=32, unroll_len=20, env_num=collector_env_num, @@ -56,9 +52,7 @@ end=0.05, decay=100000, ), - replay_buffer=dict( - replay_buffer_size=30000, - ), + replay_buffer=dict(replay_buffer_size=30000, ), ), ), ) diff --git a/dizoo/smac/config/smac_MMM_madqn_config.py b/dizoo/smac/config/smac_MMM_madqn_config.py index 1d9a6abeaf..892f1f5217 100644 --- a/dizoo/smac/config/smac_MMM_madqn_config.py +++ b/dizoo/smac/config/smac_MMM_madqn_config.py @@ -18,9 +18,7 @@ stop_value=0.999, n_evaluator_episode=32, special_global_state=True, - manager=dict( - shared_memory=False, - ), + manager=dict(shared_memory=False, ), ), policy=dict( nstep=1, @@ -41,9 +39,7 @@ discount_factor=0.95, ), collect=dict( - collector=dict( - get_train_sample=True, - ), + collector=dict(get_train_sample=True, ), n_episode=32, unroll_len=10, env_num=collector_env_num, @@ -56,9 +52,7 @@ end=0.05, decay=10000, ), - replay_buffer=dict( - replay_buffer_size=15000, - ), + replay_buffer=dict(replay_buffer_size=15000, ), ), ), ) diff --git a/dizoo/smac/utils/eval.py b/dizoo/smac/utils/eval.py index 6d683a8ace..1e112e84a7 100644 --- a/dizoo/smac/utils/eval.py +++ b/dizoo/smac/utils/eval.py @@ -10,11 +10,11 @@ def eval( - input_cfg: Union[str, Tuple[dict, dict]], - seed: int = 0, - env_setting: Optional[List[Any]] = None, - model: Optional[torch.nn.Module] = None, - state_dict: Optional[dict] = None, + input_cfg: Union[str, Tuple[dict, dict]], + seed: int = 0, + env_setting: Optional[List[Any]] = None, + model: Optional[torch.nn.Module] = None, + state_dict: Optional[dict] = None, ) -> float: r""" Overview: From 9bd641f7508b885bf2d0c47bb90b19c2182e9a9e Mon Sep 17 00:00:00 2001 From: luyudong Date: Mon, 7 Aug 2023 15:46:29 +0800 Subject: [PATCH 11/25] Reformat --- ding/framework/middleware/functional/__init__.py | 4 ++-- ding/framework/middleware/functional/data_processor.py | 7 ++++--- ding/model/template/dt.py | 6 +++--- ding/utils/data/dataset.py | 9 ++------- 4 files changed, 11 insertions(+), 15 deletions(-) diff --git a/ding/framework/middleware/functional/__init__.py b/ding/framework/middleware/functional/__init__.py index 6e3eedce3e..4b0817ee48 100644 --- a/ding/framework/middleware/functional/__init__.py +++ b/ding/framework/middleware/functional/__init__.py @@ -1,6 +1,6 @@ from .trainer import trainer, multistep_trainer -from .data_processor import offpolicy_data_fetcher, data_pusher, offline_data_fetcher, offline_data_saver, offline_data_fetcher_from_mem, \ - sqil_data_pusher, buffer_saver +from .data_processor import offpolicy_data_fetcher, data_pusher, offline_data_fetcher, offline_data_saver, \ + offline_data_fetcher_from_mem, sqil_data_pusher, buffer_saver from .collector import inferencer, rolloutor, TransitionList from .evaluator import interaction_evaluator, interaction_evaluator_ttorch from .termination_checker import termination_checker, ddp_termination_checker diff --git a/ding/framework/middleware/functional/data_processor.py b/ding/framework/middleware/functional/data_processor.py index 
0c008370c0..7c31ff2166 100644 --- a/ding/framework/middleware/functional/data_processor.py +++ b/ding/framework/middleware/functional/data_processor.py @@ -6,6 +6,7 @@ from ding.data import Buffer, Dataset, DataLoader, offline_data_save_type from ding.data.buffer.middleware import PriorityExperienceReplay from ding.framework import task +from ding.utils import get_rank if TYPE_CHECKING: from ding.framework import OnlineRLContext, OfflineRLContext @@ -208,9 +209,10 @@ def producer(queue, dataset, batch_size, device): queue.put(data) queue = Queue(maxsize=50) + device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu' producer_thread = Thread( target=producer, - args=(queue, dataset, cfg.policy.batch_size, 'cuda:0,1' if cfg.policy.cuda else 'cpu'), + args=(queue, dataset, cfg.policy.batch_size, device), name='cuda_fetcher_producer' ) @@ -256,11 +258,10 @@ def _fetch(ctx: "OfflineRLContext"): ctx.train_epoch += 1 del dataloader dataloader = iter( - DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x) + DataLoader(dataset, batch_size=cfg.policy.batch_size, shuffle=True, collate_fn=lambda x: x) ) ctx.train_data = next(dataloader) # TODO apply data update (e.g. priority) in offline setting when necessary - return _fetch diff --git a/ding/model/template/dt.py b/ding/model/template/dt.py index 8b898dcc89..b8440aed13 100644 --- a/ding/model/template/dt.py +++ b/ding/model/template/dt.py @@ -113,12 +113,12 @@ def __init__( self.act_dim = act_dim self.h_dim = h_dim - ### transformer blocks + # transformer blocks input_seq_len = 3 * context_len blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)] self.transformer = nn.Sequential(*blocks) - ### projection heads (project to embedding) + # projection heads (project to embedding) self.embed_ln = nn.LayerNorm(h_dim) self.embed_timestep = nn.Embedding(max_timestep, h_dim) self.embed_rtg = torch.nn.Linear(1, h_dim) @@ -142,7 +142,7 @@ def __init__( self.embed_action = torch.nn.Embedding(act_dim, h_dim) use_action_tanh = False # False for discrete actions - ### prediction heads + # prediction heads self.predict_action = nn.Sequential(*([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else []))) def forward(self, timesteps, states, actions, returns_to_go): diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index bfbc777f4d..2570328afa 100644 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -522,9 +522,8 @@ def __init__(self, cfg: dict) -> None: curr_num_transitions = len(obss) trajectories_to_load = cfg.dataset.trajectories_per_buffer while not done: - states, ac, ret, next_states, next_action, next_reward, terminal, indices = frb.sample_transition_batch( - batch_size=1, indices=[i] - ) + states, ac, ret, next_states, next_action, next_reward, terminal, indices = \ + frb.sample_transition_batch( batch_size=1, indices=[i]) states = states.transpose((0, 3, 1, 2))[0] # (1, 84, 84, 4) --> (4, 84, 84) obss.append(states) actions.append(ac[0]) @@ -547,10 +546,6 @@ def __init__(self, cfg: dict) -> None: done = True num_trajectories += (cfg.dataset.trajectories_per_buffer - trajectories_to_load) transitions_per_buffer[buffer_num] = i - print( - 'this buffer has %d loaded transitions and there are now %d transitions total divided into %d trajectories' - % (i, len(obss), num_trajectories) - ) actions = np.array(actions) returns = np.array(returns) From 126c1e3fcd23a90cd296710c862f66581a921f6e Mon Sep 17 
00:00:00 2001 From: luyudong Date: Tue, 8 Aug 2023 19:08:55 +0800 Subject: [PATCH 12/25] Change data fatcher func to class --- ding/framework/middleware/__init__.py | 1 + ding/framework/middleware/data_fetcher.py | 62 +++++++++++++++++++++++ ding/torch_utils/network/transformer.py | 51 ------------------- ding/utils/data/dataset.py | 2 - 4 files changed, 63 insertions(+), 53 deletions(-) create mode 100644 ding/framework/middleware/data_fetcher.py diff --git a/ding/framework/middleware/__init__.py b/ding/framework/middleware/__init__.py index 6ff67d8301..74ee25950f 100644 --- a/ding/framework/middleware/__init__.py +++ b/ding/framework/middleware/__init__.py @@ -4,3 +4,4 @@ from .ckpt_handler import CkptSaver from .distributer import ContextExchanger, ModelExchanger from .barrier import Barrier, BarrierRuntime +from .data_fetcher import offline_data_fetcher_from_mem_c \ No newline at end of file diff --git a/ding/framework/middleware/data_fetcher.py b/ding/framework/middleware/data_fetcher.py new file mode 100644 index 0000000000..02ddc20aa9 --- /dev/null +++ b/ding/framework/middleware/data_fetcher.py @@ -0,0 +1,62 @@ +from typing import TYPE_CHECKING +from threading import Thread +from queue import Queue +import time +import torch +from easydict import EasyDict +from ding.data import Dataset, DataLoader +from ding.utils import get_rank +import numpy as np + +if TYPE_CHECKING: + from ding.framework import OfflineRLContext + + +class offline_data_fetcher_from_mem_c: + + def __init__(self, cfg: EasyDict, dataset: Dataset): + stream = torch.cuda.Stream() + def producer(queue, dataset, batch_size, device): + torch.set_num_threads(4) + nonlocal stream + idx_iter = iter(np.random.permutation(len(dataset))) + + with torch.cuda.stream(stream): + while True: + if queue.full(): + time.sleep(0.1) + else: + try: + start_idx = next(idx_iter) + except StopIteration: + del idx_iter + idx_iter = iter(np.random.permutation(len(dataset))) + start_idx = next(idx_iter) + + data = [dataset.__getitem__(idx) for idx in range(start_idx, start_idx + batch_size)] + data = [[i[j] for i in data] for j in range(len(data[0]))] + try: + data = [torch.stack(x).to(device) for x in data] + except RuntimeError: + print(len(data)) + for i in range(len(data)): + print(len(data[i])) + print(data[i]) + queue.put(data) + + self.queue = Queue(maxsize=50) + device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu' + print('prepare sample data in device', device) + self.producer_thread = Thread( + target=producer, + args=(self.queue, dataset, cfg.policy.batch_size, device), + name='cuda_fetcher_producer' + ) + + def __call__(self,ctx: "OfflineRLContext"): + if not self.producer_thread.is_alive(): + time.sleep(5) + self.producer_thread.start() + while self.queue.empty(): + time.sleep(0.001) + ctx.train_data = self.queue.get() \ No newline at end of file diff --git a/ding/torch_utils/network/transformer.py b/ding/torch_utils/network/transformer.py index a5da9ac6ce..e707134a3f 100644 --- a/ding/torch_utils/network/transformer.py +++ b/ding/torch_utils/network/transformer.py @@ -81,57 +81,6 @@ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch return attention -class MaskedCausalAttention(nn.Module): - - def __init__(self, h_dim, max_T, n_heads, drop_p): - super().__init__() - - self.n_heads = n_heads - self.max_T = max_T - - self.q_net = nn.Linear(h_dim, h_dim) - self.k_net = nn.Linear(h_dim, h_dim) - self.v_net = nn.Linear(h_dim, h_dim) - - self.proj_net = 
nn.Linear(h_dim, h_dim) - - self.att_drop = nn.Dropout(drop_p) - self.proj_drop = nn.Dropout(drop_p) - - ones = torch.ones((max_T, max_T)) - mask = torch.tril(ones).view(1, 1, max_T, max_T) - - # register buffer makes sure mask does not get updated - # during backpropagation - self.register_buffer('mask', mask) - - def forward(self, x): - B, T, C = x.shape # batch size, seq length, h_dim * n_heads - - N, D = self.n_heads, C // self.n_heads # N = num heads, D = attention dim - - # rearrange q, k, v as (B, N, T, D) - q = self.q_net(x).view(B, T, N, D).transpose(1, 2) - k = self.k_net(x).view(B, T, N, D).transpose(1, 2) - v = self.v_net(x).view(B, T, N, D).transpose(1, 2) - - # weights (B, N, T, T) - weights = q @ k.transpose(2, 3) / math.sqrt(D) - # causal mask applied to weights - weights = weights.masked_fill(self.mask[..., :T, :T] == 0, float('-inf')) - # normalize weights, all -inf -> 0 after softmax - normalized_weights = F.softmax(weights, dim=-1) - - # attention (B, N, T, D) - attention = self.att_drop(normalized_weights @ v) - - # gather heads and project (B, N, T, D) -> (B, T, N*D) - attention = attention.transpose(1, 2).contiguous().view(B, T, N * D) - - out = self.proj_drop(self.proj_net(attention)) - return out - - class TransformerLayer(nn.Module): r""" Overview: diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index 2570328afa..abed2a8fdb 100644 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -505,7 +505,6 @@ def __init__(self, cfg: dict) -> None: while len(obss) < cfg.dataset.num_steps: buffer_num = np.random.choice(np.arange(50 - cfg.dataset.num_buffers, 50), 1)[0] i = transitions_per_buffer[buffer_num] - print('loading from buffer %d which has %d already loaded' % (buffer_num, i)) frb = FixedReplayBuffer( data_dir=cfg.dataset.data_dir_prefix + '/1/replay_logs', replay_suffix=buffer_num, @@ -702,7 +701,6 @@ def _load_buffer(self, suffix): # pytype: disable=attribute-error replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer(*self._args, **self._kwargs) replay_buffer.load(self._data_dir, suffix) - print('Loaded replay buffer ckpt {} from {}'.format(suffix, self._data_dir)) # pytype: enable=attribute-error return replay_buffer # except tf.errors.NotFoundError: From 1835242f23705630991609c15ba1c51ead84b2db Mon Sep 17 00:00:00 2001 From: luyudong Date: Wed, 9 Aug 2023 10:08:01 +0800 Subject: [PATCH 13/25] Add threading shift data to gpu --- ding/framework/middleware/data_fetcher.py | 13 +++---------- .../middleware/functional/data_processor.py | 9 +++++---- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/ding/framework/middleware/data_fetcher.py b/ding/framework/middleware/data_fetcher.py index 02ddc20aa9..022f7fc833 100644 --- a/ding/framework/middleware/data_fetcher.py +++ b/ding/framework/middleware/data_fetcher.py @@ -19,7 +19,7 @@ def __init__(self, cfg: EasyDict, dataset: Dataset): def producer(queue, dataset, batch_size, device): torch.set_num_threads(4) nonlocal stream - idx_iter = iter(np.random.permutation(len(dataset))) + idx_iter = iter(np.random.permutation(len(dataset.obss)-batch_size)) with torch.cuda.stream(stream): while True: @@ -30,23 +30,16 @@ def producer(queue, dataset, batch_size, device): start_idx = next(idx_iter) except StopIteration: del idx_iter - idx_iter = iter(np.random.permutation(len(dataset))) + idx_iter = iter(np.random.permutation(len(dataset.obss)-batch_size)) start_idx = next(idx_iter) data = [dataset.__getitem__(idx) for idx in range(start_idx, start_idx + batch_size)] data = 
[[i[j] for i in data] for j in range(len(data[0]))] - try: - data = [torch.stack(x).to(device) for x in data] - except RuntimeError: - print(len(data)) - for i in range(len(data)): - print(len(data[i])) - print(data[i]) + data = [torch.stack(x).to(device) for x in data] queue.put(data) self.queue = Queue(maxsize=50) device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu' - print('prepare sample data in device', device) self.producer_thread = Thread( target=producer, args=(self.queue, dataset, cfg.policy.batch_size, device), diff --git a/ding/framework/middleware/functional/data_processor.py b/ding/framework/middleware/functional/data_processor.py index 7c31ff2166..295123a8bc 100644 --- a/ding/framework/middleware/functional/data_processor.py +++ b/ding/framework/middleware/functional/data_processor.py @@ -189,7 +189,7 @@ def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable: stream = torch.cuda.Stream() def producer(queue, dataset, batch_size, device): - torch.set_num_threads(8) + torch.set_num_threads(4) nonlocal stream idx_iter = iter(range(len(dataset))) with torch.cuda.stream(stream): @@ -216,10 +216,11 @@ def producer(queue, dataset, batch_size, device): name='cuda_fetcher_producer' ) - producer_thread.start() - def _fetch(ctx: "OfflineRLContext"): - nonlocal queue + nonlocal queue, producer_thread + if not producer_thread.is_alive(): + time.sleep(5) + producer_thread.start() while queue.empty(): time.sleep(0.001) ctx.train_data = queue.get() From 3406ab1bb7ea84e82fb7fb603e5aa342c7fb3264 Mon Sep 17 00:00:00 2001 From: luyudong Date: Wed, 9 Aug 2023 10:09:04 +0800 Subject: [PATCH 14/25] Change action sample func --- ding/policy/dt.py | 39 ++++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/ding/policy/dt.py b/ding/policy/dt.py index f478a1b78d..ab9933e991 100644 --- a/ding/policy/dt.py +++ b/ding/policy/dt.py @@ -69,7 +69,11 @@ def _init_learn(self) -> None: self.act_dim = self._cfg.model.act_dim self._learn_model = self._model - self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay) + + if 'state_mean' not in self._cfg: + self._optimizer = self._learn_model.configure_optimizers(wt_decay, lr) + else: + self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay) self._scheduler = torch.optim.lr_scheduler.LambdaLR( self._optimizer, lambda steps: min((steps + 1) / warmup_steps, 1) @@ -86,6 +90,7 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: Returns: - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. 
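A minimal, framework-free sketch of the producer-thread prefetch pattern that these data-fetcher patches converge on (a background thread fills a bounded Queue and is started lazily on the first fetch call); make_batch_fetcher, the plain __getitem__ dataset and the batch size below are illustrative assumptions, not DI-engine names.

import time
from queue import Queue
from threading import Thread

def make_batch_fetcher(dataset, batch_size=64, maxsize=50):
    # dataset is anything supporting __len__ / __getitem__
    queue = Queue(maxsize=maxsize)

    def producer():
        idx = 0
        while True:
            if queue.full():        # back off when the consumer falls behind
                time.sleep(0.1)
                continue
            batch = [dataset[(idx + j) % len(dataset)] for j in range(batch_size)]
            idx = (idx + batch_size) % len(dataset)
            queue.put(batch)

    thread = Thread(target=producer, name='fetcher_producer', daemon=True)

    def _fetch(ctx):
        if not thread.is_alive():   # lazy start on the first call, as in the patched fetcher
            thread.start()
        while queue.empty():
            time.sleep(0.001)
        ctx.train_data = queue.get()

    return _fetch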
""" + self._learn_model.train() timesteps, states, actions, returns_to_go, traj_mask = data action_target = torch.clone(actions).detach().to(self._device) @@ -99,9 +104,14 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: if not self._cfg.model.continuous and 'state_mean' in self._cfg: actions = one_hot(actions.squeeze(-1), num=self.act_dim) - state_preds, action_preds, return_preds = self._learn_model.forward( - timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go - ) + if 'state_mean' not in self._cfg: + state_preds, action_preds, return_preds = self._learn_model.forward( + timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, tar=1 + ) + else: + state_preds, action_preds, return_preds = self._learn_model.forward( + timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go + ) if 'state_mean' not in self._cfg: action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1)) @@ -136,7 +146,6 @@ def _init_eval(self) -> None: Evaluate mode init method. Called by ``self.__init__``, initialize eval_model. """ self._eval_model = self._model - # self._eval_model.reset() # init data self._device = torch.device(self._device) self.rtg_scale = self._cfg.rtg_scale # normalize returns to go @@ -223,8 +232,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: if self.t[i] <= self.context_len: if 'state_mean' not in self._cfg: timesteps[i] = min(self.t[i], - self._cfg.max_timestep) * torch.ones((1, 1), - dtype=torch.int64).to(self._device) + self._cfg.model.max_timestep) * torch.ones((1, 1), dtype=torch.int64).to(self._device) else: timesteps[i] = self.timesteps[i, :self.context_len] states[i] = self.states[i, :self.context_len] @@ -233,23 +241,28 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: else: if 'state_mean' not in self._cfg: timesteps[i] = min(self.t[i], - self._cfg.max_timestep) * torch.ones((1, 1), - dtype=torch.int64).to(self._device) + self._cfg.model.max_timestep) * torch.ones((1, 1), dtype=torch.int64).to(self._device) else: timesteps[i] = self.timesteps[i, self.t[i] - self.context_len + 1:self.t[i] + 1] states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1] actions[i] = self.actions[i, self.t[i] - self.context_len + 1:self.t[i] + 1] rewards_to_go[i] = self.rewards_to_go[i, self.t[i] - self.context_len + 1:self.t[i] + 1] - # if not self._cfg.model.continuous: - # actions = one_hot(actions.squeeze(-1), num=self.act_dim) + if not self._cfg.model.continuous and 'state_mean' in self._cfg: + actions = one_hot(actions.squeeze(-1), num=self.act_dim) _, act_preds, _ = self._eval_model.forward(timesteps, states, actions, rewards_to_go) del timesteps, states, actions, rewards_to_go logits = act_preds[:, -1, :] if not self._cfg.model.continuous: - act = torch.argmax(logits, axis=1).unsqueeze(1) + if 'state_mean' not in self._cfg: + probs = F.softmax(logits, dim=-1) + act = torch.zeros((self.eval_batch_size, 1), dtype=torch.long, device=self._device) + for i in data_id: + act[i] = torch.multinomial(probs[i], num_samples=1) + else: + act = torch.argmax(logits, axis=1).unsqueeze(1) for i in data_id: - self.actions[i, self.t[i]] = act[i] + self.actions[i, self.t[i]] = act[i] # TODO: self.actions[i] should be a queue when exceed max_t self.t[i] += 1 if self._cuda: From 49bfb3cc57acaab0b16fb1b5aaf8a901a59e5e1b Mon Sep 17 00:00:00 2001 From: luyudong Date: Wed, 9 Aug 2023 10:09:46 +0800 Subject: [PATCH 15/25] Add configure 
optimizers --- ding/model/template/dt.py | 150 ++++++++++++++++++++++++++------------ 1 file changed, 104 insertions(+), 46 deletions(-) diff --git a/ding/model/template/dt.py b/ding/model/template/dt.py index b8440aed13..5869bbae22 100644 --- a/ding/model/template/dt.py +++ b/ding/model/template/dt.py @@ -89,11 +89,12 @@ def forward(self, x): x = self.ln1(x) x = x + self.mlp(x) # residual x = self.ln2(x) + # x = x + self.attention(self.ln1(x)) + # x = x + self.mlp(self.ln2(x)) return x class DecisionTransformer(nn.Module): - def __init__( self, state_dim, @@ -115,37 +116,39 @@ def __init__( # transformer blocks input_seq_len = 3 * context_len - blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)] - self.transformer = nn.Sequential(*blocks) # projection heads (project to embedding) self.embed_ln = nn.LayerNorm(h_dim) self.embed_timestep = nn.Embedding(max_timestep, h_dim) - self.embed_rtg = torch.nn.Linear(1, h_dim) - + self.drop = nn.Dropout(drop_p) + self.pos_emb = nn.Parameter(torch.zeros(1, input_seq_len + 1, self.h_dim)) self.global_pos_emb = nn.Parameter(torch.zeros(1, max_timestep + 1, self.h_dim)) if state_encoder == None: + blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)] + self.embed_rtg = torch.nn.Linear(1, h_dim) self.embed_state = torch.nn.Linear(state_dim, h_dim) self.predict_rtg = torch.nn.Linear(h_dim, 1) self.predict_state = torch.nn.Linear(h_dim, state_dim) + self.predict_action = nn.Sequential(*([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else []))) + if continuous: + # continuous actions + self.embed_action = torch.nn.Linear(act_dim, h_dim) + use_action_tanh = True # True for continuous actions + else: + # discrete actions + self.embed_action = torch.nn.Embedding(act_dim, h_dim) + use_action_tanh = False # False for discrete actions else: + blocks = [Block(h_dim, input_seq_len+1, n_heads, drop_p) for _ in range(n_blocks)] self.state_encoder = state_encoder + self.embed_rtg = nn.Sequential(nn.Linear(1, h_dim), nn.Tanh()) + self.head = nn.Linear(h_dim, act_dim, bias=False) + self.embed_action = nn.Sequential(nn.Embedding(act_dim, h_dim), nn.Tanh()) + self.transformer = nn.Sequential(*blocks) - if continuous: - # continuous actions - self.embed_action = torch.nn.Linear(act_dim, h_dim) - use_action_tanh = True # True for continuous actions - else: - # discrete actions - self.embed_action = torch.nn.Embedding(act_dim, h_dim) - use_action_tanh = False # False for discrete actions - - # prediction heads - self.predict_action = nn.Sequential(*([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else []))) - - def forward(self, timesteps, states, actions, returns_to_go): + def forward(self, timesteps, states, actions, returns_to_go, tar=None): B, T = states.shape[0], states.shape[1] if self.state_encoder == None: time_embeddings = self.embed_timestep(timesteps) @@ -154,48 +157,103 @@ def forward(self, timesteps, states, actions, returns_to_go): state_embeddings = self.embed_state(states) + time_embeddings action_embeddings = self.embed_action(actions) + time_embeddings returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings + + # stack rtg, states and actions and reshape sequence as + # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2 ...) 
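As a side note before the stacked form below: a tiny self-contained sketch of this three-tokens-per-step interleaving, using toy shapes and stand-in tensors (returns_emb, state_emb, action_emb are assumptions, not the module's attributes). It also checks that the slice-assignment form used later in the state-encoder branch yields exactly the same (r_0, s_0, a_0, r_1, s_1, a_1, ...) ordering as stack/permute/reshape.

import torch

B, T, H = 2, 4, 8  # toy batch size, context length, hidden dim
returns_emb, state_emb, action_emb = (torch.randn(B, T, H) for _ in range(3))

# stack/permute/reshape form
seq_a = torch.stack((returns_emb, state_emb, action_emb),
                    dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, H)

# slice-assignment form (as in the state-encoder branch)
seq_b = torch.zeros(B, 3 * T, H)
seq_b[:, 0::3] = returns_emb
seq_b[:, 1::3] = state_emb
seq_b[:, 2::3] = action_emb

assert torch.allclose(seq_a, seq_b)  # identical token ordering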
+ t_p = torch.stack( + (returns_embeddings, state_embeddings, action_embeddings), dim=1 + ).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim) + h = self.embed_ln(t_p) + # transformer and prediction + h = self.transformer(h) + # get h reshaped such that its size = (B x 3 x T x h_dim) and + # h[:, 0, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t + # h[:, 1, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t + # h[:, 2, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t, a_t + # that is, for each timestep (t) we have 3 output embeddings from the transformer, + # each conditioned on all previous timesteps plus + # the 3 input variables at that timestep (r_t, s_t, a_t) in sequence. + h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3) + + return_preds = self.predict_rtg(h[:, 2]) # predict next rtg given r, s, a + state_preds = self.predict_state(h[:, 2]) # predict next state given r, s, a + action_preds = self.predict_action(h[:, 1]) # predict action given r, s else: state_embeddings = self.state_encoder( states.reshape(-1, 4, 84, 84).type(torch.float32).contiguous() ) # (batch * block_size, h_dim) state_embeddings = state_embeddings.reshape( - states.shape[0], states.shape[1], self.h_dim + B, T, self.h_dim ) # (batch, block_size, h_dim) returns_embeddings = self.embed_rtg(returns_to_go.type(torch.float32)) action_embeddings = self.embed_action(actions.type(torch.long).squeeze(-1)) # (batch, block_size, h_dim) - # stack rtg, states and actions and reshape sequence as - # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2 ...) - h = torch.stack((returns_embeddings, state_embeddings, action_embeddings), - dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim) + token_embeddings = torch.zeros((B, T*3 - int(tar is None), self.h_dim), dtype=torch.float32, device=state_embeddings.device) + token_embeddings[:,::3,:] = returns_embeddings + token_embeddings[:,1::3,:] = state_embeddings + token_embeddings[:,2::3,:] = action_embeddings[:,-T + int(tar is None):,:] + + all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, B, dim=0) # batch_size, traj_length, h_dim - all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, B, dim=0) # batch_size, traj_length, h_dim - position_embeddings = torch.gather( - all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.h_dim, dim=-1) - ) - position_embeddings = position_embeddings + self.pos_emb[:, :h.shape[1], :] + position_embeddings = torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.h_dim, dim=-1)) + self.pos_emb[:, :token_embeddings.shape[1], :] - h = self.embed_ln(h + position_embeddings) + t_p = token_embeddings + position_embeddings - # transformer and prediction - h = self.transformer(h) + h = self.drop(t_p) + h = self.transformer(h) + h = self.embed_ln(h) + logits = self.head(h) - # get h reshaped such that its size = (B x 3 x T x h_dim) and - # h[:, 0, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t - # h[:, 1, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t - # h[:, 2, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t, a_t - # that is, for each timestep (t) we have 3 output embeddings from the transformer, - # each conditioned on all previous timesteps plus - # the 3 input variables at that timestep (r_t, s_t, a_t) in sequence. 
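Continuing the toy sketch above: recovering per-timestep predictions from the flattened transformer output. Reshaping to (B, T, 3, H) and permuting groups the three output embeddings per step, and action logits are read from the state-token positions, i.e. h[:, 1] in the grouped view or out[:, 1::3] on the flat sequence. The shapes, out and predict_action below are illustrative assumptions, not the model's real tensors.

import torch
import torch.nn as nn

B, T, H, ACT_DIM = 2, 4, 8, 6   # toy sizes
out = torch.randn(B, 3 * T, H)  # stands in for the transformer output on the (r, s, a) sequence

# grouped view: h[:, 0] ~ rtg tokens, h[:, 1] ~ state tokens, h[:, 2] ~ action tokens
h = out.reshape(B, T, 3, H).permute(0, 2, 1, 3)

predict_action = nn.Linear(H, ACT_DIM)
action_preds_grouped = predict_action(h[:, 1])        # predict a_t from the s_t token
action_preds_strided = predict_action(out[:, 1::3])   # same positions on the flat sequence
assert torch.allclose(action_preds_grouped, action_preds_strided)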
- h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3) - - # get predictions - if self.state_encoder == None: - return_preds = self.predict_rtg(h[:, 2]) # predict next rtg given r, s, a - state_preds = self.predict_state(h[:, 2]) # predict next state given r, s, a - else: return_preds = None state_preds = None - action_preds = self.predict_action(h[:, 1]) # predict action given r, s + action_preds = logits[:, 1::3, :] # only keep predictions from state_embeddings return state_preds, action_preds, return_preds + + def configure_optimizers(self, weight_decay, learning_rate, betas=(0.9, 0.95)): + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + # whitelist_weight_modules = (torch.nn.Linear, ) + whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + + if pn.endswith('bias'): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # special case the position embedding parameter in the root GPT module as not decayed + no_decay.add('pos_emb') + no_decay.add('global_pos_emb') + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) + assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" 
\ + % (str(param_dict.keys() - union_params), ) + + # create the pytorch optimizer object + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) + return optimizer \ No newline at end of file From bb158451a2b4235bb528c9303d47f767d927cf8a Mon Sep 17 00:00:00 2001 From: luyudong Date: Wed, 9 Aug 2023 10:10:23 +0800 Subject: [PATCH 16/25] Add multi gpu support --- ding/utils/pytorch_ddp_dist_helper.py | 4 +- .../config/serial/pong/pong_dt_config.py | 6 +- dizoo/atari/entry/atari_dt_main.py | 62 +++++++++---------- 3 files changed, 35 insertions(+), 37 deletions(-) diff --git a/ding/utils/pytorch_ddp_dist_helper.py b/ding/utils/pytorch_ddp_dist_helper.py index 312f3005a9..60092c5c42 100644 --- a/ding/utils/pytorch_ddp_dist_helper.py +++ b/ding/utils/pytorch_ddp_dist_helper.py @@ -181,6 +181,6 @@ def simple_group_split(world_size: int, rank: int, num_groups: int) -> List: def to_ddp_config(cfg: EasyDict) -> EasyDict: w = get_world_size() - cfg.policy.learn.batch_size = int(np.ceil(cfg.policy.learn.batch_size / w)) - cfg.policy.collect.n_sample = int(np.ceil(cfg.policy.collect.n_sample) / w) + cfg.policy.batch_size = int(np.ceil(cfg.policy.batch_size / w)) + # cfg.policy.collect.n_sample = int(np.ceil(cfg.policy.collect.n_sample) / w) return cfg diff --git a/dizoo/atari/config/serial/pong/pong_dt_config.py b/dizoo/atari/config/serial/pong/pong_dt_config.py index d4e8222434..d7d020d523 100644 --- a/dizoo/atari/config/serial/pong/pong_dt_config.py +++ b/dizoo/atari/config/serial/pong/pong_dt_config.py @@ -7,7 +7,6 @@ env_id='PongNoFrameskip-v4', collector_env_num=1, evaluator_env_num=8, - use_act_scale=True, n_evaluator_episode=8, stop_value=20, frame_stack=4, @@ -16,8 +15,8 @@ ), dataset=dict( env_type='atari', - # num_steps=500000, - num_steps=500, + num_steps=500000, + # num_steps=50, num_buffers=50, rtg_scale=None, context_len=30, @@ -26,6 +25,7 @@ ), policy=dict( cuda=True, + multi_gpu=True, stop_value=20, evaluator_env_num=8, env_name='PongNoFrameskip-v4', diff --git a/dizoo/atari/entry/atari_dt_main.py b/dizoo/atari/entry/atari_dt_main.py index ccf919488b..a538ca234e 100644 --- a/dizoo/atari/entry/atari_dt_main.py +++ b/dizoo/atari/entry/atari_dt_main.py @@ -1,55 +1,53 @@ -import gym -import torch -import numpy as np import torch.nn as nn from ditk import logging from ding.model.template.dt import DecisionTransformer from ding.policy import DTPolicy -from ding.envs import BaseEnvManagerV2, SyncSubprocessEnvManager, SubprocessEnvManagerV2 +from ding.envs import SubprocessEnvManagerV2 from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper from ding.data import create_dataset from ding.config import compile_config from ding.framework import task, ding_init from ding.framework.context import OfflineRLContext -from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger, termination_checker, offline_data_fetcher_from_mem -from ding.utils import set_pkg_seed +from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_logger, termination_checker, offline_data_fetcher_from_mem_c +from ding.utils import set_pkg_seed, DDPContext, to_ddp_config from dizoo.atari.envs import AtariEnv from dizoo.atari.config.serial.pong.pong_dt_config import main_config, create_config -import os -from 
functools import partial -os.environ['CUDA_LAUNCH_BLOCKING'] = "1" + def main(): # If you don't have offline data, you need to prepare if first and set the data_path in config # For demostration, we also can train a RL policy (e.g. SAC) and collect some data logging.getLogger().setLevel(logging.INFO) - cfg = compile_config(main_config, create_cfg=create_config, auto=True) + cmain_config = to_ddp_config(main_config) + cfg = compile_config(cmain_config, create_cfg=create_config, auto=True) ding_init(cfg) - with task.start(async_mode=False, ctx=OfflineRLContext()): - evaluator_env = SubprocessEnvManagerV2( - env_fn=[lambda: AllinObsWrapper(AtariEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager - ) - - set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + with DDPContext(): + with task.start(async_mode=False, ctx=OfflineRLContext()): + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AllinObsWrapper(AtariEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)], + cfg=cfg.env.manager + ) - dataset = create_dataset(cfg) - cfg.policy.max_timestep = dataset.get_max_timestep() - # dataset = get_data_source(dataset) - state_encoder = nn.Sequential(nn.Conv2d(4, 32, 8, stride=4, padding=0), nn.ReLU(), - nn.Conv2d(32, 64, 4, stride=2, padding=0), nn.ReLU(), - nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU(), - nn.Flatten(), nn.Linear(3136, cfg.policy.model.h_dim), nn.Tanh()) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + dataset = create_dataset(cfg) + cfg.policy.model.max_timestep = dataset.get_max_timestep() + state_encoder = nn.Sequential( + nn.Conv2d(4, 32, 8, stride=4, padding=0), nn.ReLU(), nn.Conv2d(32, 64, 4, stride=2, padding=0), + nn.ReLU(), nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU(), nn.Flatten(), + nn.Linear(3136, cfg.policy.model.h_dim), nn.Tanh() + ) - model = DecisionTransformer(**cfg.policy.model, state_encoder=state_encoder) - policy = DTPolicy(cfg.policy, model=model) + model = DecisionTransformer(**cfg.policy.model, state_encoder=state_encoder) + # model.parallelize() + policy = DTPolicy(cfg.policy, model=model) - task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) - task.use(offline_data_fetcher_from_mem(cfg, dataset)) - task.use(trainer(cfg, policy.learn_mode)) - task.use(termination_checker(max_train_iter=1e5)) - task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) - task.use(offline_logger(cfg.exp_name)) - task.run() + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(offline_data_fetcher_from_mem_c(cfg, dataset)) + task.use(trainer(cfg, policy.learn_mode)) + task.use(termination_checker(max_train_iter=3e4)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) + task.use(offline_logger(cfg.exp_name)) + task.run() if __name__ == "__main__": From 9712a29539a4631b41b4497a4adffd72e2337b12 Mon Sep 17 00:00:00 2001 From: luyudong Date: Wed, 9 Aug 2023 14:10:23 +0800 Subject: [PATCH 17/25] Add dt policy test serial --- ding/entry/tests/test_serial_entry.py | 64 ++++++++++++++++++ ding/framework/middleware/data_fetcher.py | 25 +++++-- ding/framework/task.py | 1 + ding/model/template/dt.py | 3 +- ding/policy/dt.py | 13 ++-- ding/utils/data/dataset.py | 18 ++++- .../config/serial/pong/pong_dt_config.py | 1 - dizoo/atari/entry/atari_dt_main.py | 2 +- .../cartpole/config/cartpole_dt_config.py | 65 +++++++++++++++++++ 9 files changed, 174 insertions(+), 18 deletions(-) create mode 100644 dizoo/classic_control/cartpole/config/cartpole_dt_config.py diff --git 
a/ding/entry/tests/test_serial_entry.py b/ding/entry/tests/test_serial_entry.py index ecccb8a98e..ce1b08feec 100644 --- a/ding/entry/tests/test_serial_entry.py +++ b/ding/entry/tests/test_serial_entry.py @@ -45,6 +45,7 @@ from dizoo.classic_control.pendulum.config.pendulum_cql_config import pendulum_cql_config, pendulum_cql_create_config # noqa from dizoo.classic_control.cartpole.config.cartpole_qrdqn_generation_data_config import cartpole_qrdqn_generation_data_config, cartpole_qrdqn_generation_data_create_config # noqa from dizoo.classic_control.cartpole.config.cartpole_cql_config import cartpole_discrete_cql_config, cartpole_discrete_cql_create_config # noqa +from dizoo.classic_control.cartpole.config.cartpole_dt_config import cartpole_discrete_dt_config, cartpole_discrete_dt_create_config # noqa from dizoo.classic_control.pendulum.config.pendulum_td3_data_generation_config import pendulum_td3_generation_config, pendulum_td3_generation_create_config # noqa from dizoo.classic_control.pendulum.config.pendulum_td3_bc_config import pendulum_td3_bc_config, pendulum_td3_bc_create_config # noqa from dizoo.classic_control.pendulum.config.pendulum_ibc_config import pendulum_ibc_config, pendulum_ibc_create_config @@ -621,6 +622,69 @@ def test_discrete_cql(): os.popen('rm -rf cartpole cartpole_cql') +@pytest.mark.platformtest +@pytest.mark.unittest +def test_discrete_dt(): + # train expert + config = [deepcopy(cartpole_qrdqn_config), deepcopy(cartpole_qrdqn_create_config)] + config[0].policy.learn.update_per_collect = 1 + config[0].exp_name = 'dt_cartpole' + try: + serial_pipeline(config, seed=0, max_train_iter=1) + except Exception: + assert False, "pipeline fail" + # collect expert data + import torch + config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)] + state_dict = torch.load('./dt_cartpole/ckpt/iteration_0.pth.tar', map_location='cpu') + try: + collect_demo_data(config, seed=0, collect_count=1000, state_dict=state_dict) + except Exception as e: + assert False, "pipeline fail" + print(repr(e)) + + # train dt + config = [deepcopy(cartpole_discrete_dt_config), deepcopy(cartpole_discrete_dt_create_config)] + config[0].policy.eval.evaluator.eval_freq = 5 + try: + from ding.framework import task + from ding.framework.context import OfflineRLContext + from ding.envs import SubprocessEnvManagerV2, BaseEnvManagerV2 + from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper + from dizoo.classic_control.cartpole.envs import CartPoleEnv + from ding.utils import set_pkg_seed + from ding.data import create_dataset + from ding.config import compile_config + from ding.model.template.dt import DecisionTransformer + from ding.policy import DTPolicy + from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher_from_mem_c, offline_logger, termination_checker + config = compile_config(config[0], create_cfg=config[1], auto=True) + with task.start(async_mode=False, ctx=OfflineRLContext()): + evaluator_env = BaseEnvManagerV2( + env_fn=[lambda: AllinObsWrapper(CartPoleEnv(config.env)) for _ in range(config.env.evaluator_env_num)], + cfg=config.env.manager + ) + + set_pkg_seed(config.seed, use_cuda=config.policy.cuda) + + dataset = create_dataset(config) + + model = DecisionTransformer(**config.policy.model) + policy = DTPolicy(config.policy, model=model) + + task.use(termination_checker(max_train_iter=1)) + task.use(interaction_evaluator(config, policy.eval_mode, evaluator_env)) + 
task.use(offline_data_fetcher_from_mem_c(config, dataset)) + task.use(trainer(config, policy.learn_mode)) + task.use(CkptSaver(policy, config.exp_name, train_freq=100)) + task.use(offline_logger(config.exp_name)) + task.run() + except Exception: + assert False, "pipeline fail" + finally: + os.popen('rm -rf cartpole cartpole_dt') +test_discrete_dt() + @pytest.mark.platformtest @pytest.mark.unittest def test_td3_bc(): diff --git a/ding/framework/middleware/data_fetcher.py b/ding/framework/middleware/data_fetcher.py index 022f7fc833..c82ee8f1d6 100644 --- a/ding/framework/middleware/data_fetcher.py +++ b/ding/framework/middleware/data_fetcher.py @@ -1,9 +1,10 @@ from typing import TYPE_CHECKING -from threading import Thread +from threading import Thread, Event from queue import Queue import time import torch from easydict import EasyDict +from ding.framework import task from ding.data import Dataset, DataLoader from ding.utils import get_rank import numpy as np @@ -14,12 +15,17 @@ class offline_data_fetcher_from_mem_c: + def __new__(cls, *args, **kwargs): + if task.router.is_active and not task.has_role(task.role.FETCHER): + return task.void() + return super(offline_data_fetcher_from_mem_c, cls).__new__(cls) + def __init__(self, cfg: EasyDict, dataset: Dataset): stream = torch.cuda.Stream() - def producer(queue, dataset, batch_size, device): + def producer(queue, dataset, batch_size, device, event): torch.set_num_threads(4) nonlocal stream - idx_iter = iter(np.random.permutation(len(dataset.obss)-batch_size)) + idx_iter = iter(np.random.permutation(len(dataset)-batch_size)) with torch.cuda.stream(stream): while True: @@ -30,19 +36,21 @@ def producer(queue, dataset, batch_size, device): start_idx = next(idx_iter) except StopIteration: del idx_iter - idx_iter = iter(np.random.permutation(len(dataset.obss)-batch_size)) + idx_iter = iter(np.random.permutation(len(dataset)-batch_size)) start_idx = next(idx_iter) - data = [dataset.__getitem__(idx) for idx in range(start_idx, start_idx + batch_size)] data = [[i[j] for i in data] for j in range(len(data[0]))] data = [torch.stack(x).to(device) for x in data] queue.put(data) + if event.is_set(): + break self.queue = Queue(maxsize=50) + self.event = Event() device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu' self.producer_thread = Thread( target=producer, - args=(self.queue, dataset, cfg.policy.batch_size, device), + args=(self.queue, dataset, cfg.policy.batch_size, device, self.event), name='cuda_fetcher_producer' ) @@ -52,4 +60,7 @@ def __call__(self,ctx: "OfflineRLContext"): self.producer_thread.start() while self.queue.empty(): time.sleep(0.001) - ctx.train_data = self.queue.get() \ No newline at end of file + ctx.train_data = self.queue.get() + if task.finish: + self.event.set() + del self.queue diff --git a/ding/framework/task.py b/ding/framework/task.py index 081eb8cdae..e5f8e7d9f7 100644 --- a/ding/framework/task.py +++ b/ding/framework/task.py @@ -57,6 +57,7 @@ class Role(str, enum.Enum): LEARNER = "learner" COLLECTOR = "collector" EVALUATOR = "evaluator" + FETCHER = 'fetcher' class VoidMiddleware: diff --git a/ding/model/template/dt.py b/ding/model/template/dt.py index 5869bbae22..1e7eee0de2 100644 --- a/ding/model/template/dt.py +++ b/ding/model/template/dt.py @@ -126,12 +126,12 @@ def __init__( self.global_pos_emb = nn.Parameter(torch.zeros(1, max_timestep + 1, self.h_dim)) if state_encoder == None: + self.state_encoder = None blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in 
range(n_blocks)] self.embed_rtg = torch.nn.Linear(1, h_dim) self.embed_state = torch.nn.Linear(state_dim, h_dim) self.predict_rtg = torch.nn.Linear(h_dim, 1) self.predict_state = torch.nn.Linear(h_dim, state_dim) - self.predict_action = nn.Sequential(*([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else []))) if continuous: # continuous actions self.embed_action = torch.nn.Linear(act_dim, h_dim) @@ -140,6 +140,7 @@ def __init__( # discrete actions self.embed_action = torch.nn.Embedding(act_dim, h_dim) use_action_tanh = False # False for discrete actions + self.predict_action = nn.Sequential(*([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else []))) else: blocks = [Block(h_dim, input_seq_len+1, n_heads, drop_p) for _ in range(n_blocks)] self.state_encoder = state_encoder diff --git a/ding/policy/dt.py b/ding/policy/dt.py index ab9933e991..b559ee2dca 100644 --- a/ding/policy/dt.py +++ b/ding/policy/dt.py @@ -5,6 +5,7 @@ from collections import namedtuple import torch.nn.functional as F import torch +import numpy as np from ding.torch_utils import to_device from ding.utils import POLICY_REGISTRY from ding.utils.data import default_decollate @@ -49,7 +50,6 @@ def _init_learn(self) -> None: Learn mode init method. Called by ``self.__init__``. Init the optimizer, algorithm config, main and target models. """ - self.env_name = self._cfg.env_name # rtg_scale: scale of `return to go` # rtg_target: max target of `return to go` # Our goal is normalize `return to go` to (0, 1), which will favour the covergence. @@ -102,7 +102,9 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: # if discrete if not self._cfg.model.continuous and 'state_mean' in self._cfg: - actions = one_hot(actions.squeeze(-1), num=self.act_dim) + # actions = one_hot(actions.squeeze(-1), num=self.act_dim) + actions = actions.squeeze(-1) + action_target = torch.clone(actions).detach().to(self._device) if 'state_mean' not in self._cfg: state_preds, action_preds, return_preds = self._learn_model.forward( @@ -180,8 +182,8 @@ def _init_eval(self) -> None: self.states = torch.zeros( (self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device ) - self.state_mean = torch.from_numpy(self._cfg.state_mean).to(self._device) - self.state_std = torch.from_numpy(self._cfg.state_std).to(self._device) + self.state_mean = torch.from_numpy(np.array(self._cfg.state_mean)).to(self._device) + self.state_std = torch.from_numpy(np.array(self._cfg.state_std)).to(self._device) self.timesteps = torch.arange( start=0, end=self.max_eval_ep_len, step=1 ).repeat(self.eval_batch_size, 1).to(self._device) @@ -248,7 +250,8 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: actions[i] = self.actions[i, self.t[i] - self.context_len + 1:self.t[i] + 1] rewards_to_go[i] = self.rewards_to_go[i, self.t[i] - self.context_len + 1:self.t[i] + 1] if not self._cfg.model.continuous and 'state_mean' in self._cfg: - actions = one_hot(actions.squeeze(-1), num=self.act_dim) + # actions = one_hot(actions.squeeze(-1), num=self.act_dim) + actions = actions.squeeze(-1) _, act_preds, _ = self._eval_model.forward(timesteps, states, actions, rewards_to_go) del timesteps, states, actions, rewards_to_go diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index abed2a8fdb..ea6e67bc67 100644 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -132,6 +132,7 @@ def __init__(self, cfg: dict) -> None: logging.warning("not found h5py package, please install it trough 
`pip install h5py ") sys.exit(1) data_path = cfg.policy.collect.get('data_path', None) + self.context_len = cfg.dataset.context_len data = h5py.File(data_path, 'r') self._load_data(data) self._cal_statistics() @@ -143,10 +144,21 @@ def __init__(self, cfg: dict) -> None: pass def __len__(self) -> int: - return len(self._data['obs']) + return len(self._data['obs']) - self.context_len def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: - return {k: self._data[k][idx] for k in self._data.keys()} + # return {k: self._data[k][idx] for k in self._data.keys()} + block_size = self.context_len + done_idx = idx + block_size + idx = done_idx - block_size + states = torch.as_tensor( + np.array(self._data['obs'][idx:done_idx]), dtype=torch.float32 + ).view(block_size, -1) + actions = torch.as_tensor(self._data['action'][idx:done_idx], dtype=torch.long) + rtgs = torch.as_tensor(self._data['reward'][idx:done_idx, 0], dtype=torch.float32) + timesteps = torch.as_tensor(range(idx, done_idx), dtype=torch.int64) + traj_mask = torch.ones(self.context_len, dtype=torch.long) + return timesteps, states, actions, rtgs, traj_mask def _load_data(self, dataset: Dict[str, np.ndarray]) -> None: self._data = {} @@ -590,7 +602,7 @@ def __len__(self) -> int: if self.env_type != 'atari': return len(self.trajectories) else: - return len(self.obss) - self.context_len * 3 + return len(self.obss) - self.context_len def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: if self.env_type != 'atari': diff --git a/dizoo/atari/config/serial/pong/pong_dt_config.py b/dizoo/atari/config/serial/pong/pong_dt_config.py index d7d020d523..5957db557e 100644 --- a/dizoo/atari/config/serial/pong/pong_dt_config.py +++ b/dizoo/atari/config/serial/pong/pong_dt_config.py @@ -28,7 +28,6 @@ multi_gpu=True, stop_value=20, evaluator_env_num=8, - env_name='PongNoFrameskip-v4', rtg_target=20, # max target return to go max_eval_ep_len=10000, # max lenght of one episode wt_decay=1e-4, diff --git a/dizoo/atari/entry/atari_dt_main.py b/dizoo/atari/entry/atari_dt_main.py index a538ca234e..f56ecebcaa 100644 --- a/dizoo/atari/entry/atari_dt_main.py +++ b/dizoo/atari/entry/atari_dt_main.py @@ -41,10 +41,10 @@ def main(): # model.parallelize() policy = DTPolicy(cfg.policy, model=model) + task.use(termination_checker(max_train_iter=3e4)) task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) task.use(offline_data_fetcher_from_mem_c(cfg, dataset)) task.use(trainer(cfg, policy.learn_mode)) - task.use(termination_checker(max_train_iter=3e4)) task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.use(offline_logger(cfg.exp_name)) task.run() diff --git a/dizoo/classic_control/cartpole/config/cartpole_dt_config.py b/dizoo/classic_control/cartpole/config/cartpole_dt_config.py new file mode 100644 index 0000000000..4fe5536270 --- /dev/null +++ b/dizoo/classic_control/cartpole/config/cartpole_dt_config.py @@ -0,0 +1,65 @@ +from easydict import EasyDict + +cartpole_discrete_dt_config = dict( + exp_name='cartpole_dt_seed0', + env=dict( + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + stop_value=195, + ), + dataset=dict( + data_dir_prefix='./cartpole_qrdqn_generation_data_seed0/expert_demos.hdf5', + rtg_scale=None, + context_len=20, + env_type='classic', + ), + policy=dict( + cuda=False, + rtg_target=10, + evaluator_env_num=5, + clip_grad_norm_p=1.0, + state_mean=1, + state_std=0, + model=dict( + state_dim=4, + act_dim=2, + n_blocks=6, + h_dim=128, + 
context_len=20, + n_heads=8, + drop_p=0.1, + continuous=False, + ), + max_timestep=1000, + discount_factor=0.97, + nstep=3, + batch_size=64, + learning_rate=0.001, + target_update_freq=100, + kappa=1.0, + min_q_weight=4.0, + collect=dict( + data_type='hdf5', + data_path='./cartpole_qrdqn_generation_data_seed0/expert_demos.hdf5', + ), + eval=dict(evaluator=dict(eval_freq=100, )), + ), +) +cartpole_discrete_dt_config = EasyDict(cartpole_discrete_dt_config) +main_config = cartpole_discrete_dt_config +cartpole_discrete_dt_create_config = dict( + env=dict( + type='cartpole', + import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='dt'), +) +cartpole_discrete_dt_create_config = EasyDict(cartpole_discrete_dt_create_config) +create_config = cartpole_discrete_dt_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_offline -c cartpole_dt_config.py -s 0` + from ding.entry import serial_pipeline_offline + serial_pipeline_offline((main_config, create_config), seed=0) From b8981846b23dee85cfc62dc0184e2f1ff642046e Mon Sep 17 00:00:00 2001 From: luyudong Date: Thu, 10 Aug 2023 13:40:24 +0800 Subject: [PATCH 18/25] Fix multi gpu support and data fetcher --- ding/framework/middleware/data_fetcher.py | 29 +++++++++++++------ .../middleware/functional/data_processor.py | 2 +- ding/policy/dt.py | 2 ++ dizoo/atari/entry/atari_dt_main.py | 11 +++---- 4 files changed, 29 insertions(+), 15 deletions(-) diff --git a/ding/framework/middleware/data_fetcher.py b/ding/framework/middleware/data_fetcher.py index c82ee8f1d6..fbceb69c53 100644 --- a/ding/framework/middleware/data_fetcher.py +++ b/ding/framework/middleware/data_fetcher.py @@ -3,6 +3,7 @@ from queue import Queue import time import torch +import torch.distributed as dist from easydict import EasyDict from ding.framework import task from ding.data import Dataset, DataLoader @@ -25,20 +26,28 @@ def __init__(self, cfg: EasyDict, dataset: Dataset): def producer(queue, dataset, batch_size, device, event): torch.set_num_threads(4) nonlocal stream - idx_iter = iter(np.random.permutation(len(dataset)-batch_size)) + num_gpu = dist.get_world_size() + rank = get_rank() + idx_list = np.random.permutation(len(dataset)) + temp_idx_list = [] + for i in range(len(dataset)//(batch_size*num_gpu)): + temp_idx_list.extend(idx_list[i+rank*batch_size:i+(rank+1)*batch_size]) + idx_iter = iter(temp_idx_list) with torch.cuda.stream(stream): while True: if queue.full(): time.sleep(0.1) else: - try: - start_idx = next(idx_iter) - except StopIteration: - del idx_iter - idx_iter = iter(np.random.permutation(len(dataset)-batch_size)) - start_idx = next(idx_iter) - data = [dataset.__getitem__(idx) for idx in range(start_idx, start_idx + batch_size)] + data = [] + for _ in range(batch_size): + try: + data.append(dataset.__getitem__(next(idx_iter))) + except StopIteration: + del idx_iter + idx_list = np.random.permutation(len(dataset)) + idx_iter = iter(idx_list) + data.append(dataset.__getitem__(next(idx_iter))) data = [[i[j] for i in data] for j in range(len(data[0]))] data = [torch.stack(x).to(device) for x in data] queue.put(data) @@ -61,6 +70,8 @@ def __call__(self,ctx: "OfflineRLContext"): while self.queue.empty(): time.sleep(0.001) ctx.train_data = self.queue.get() - if task.finish: + + def __del__(self): + if self.producer_thread.is_alive(): self.event.set() del self.queue diff --git a/ding/framework/middleware/functional/data_processor.py 
b/ding/framework/middleware/functional/data_processor.py index 295123a8bc..64ee724135 100644 --- a/ding/framework/middleware/functional/data_processor.py +++ b/ding/framework/middleware/functional/data_processor.py @@ -240,7 +240,7 @@ def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable: - dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data. """ # collate_fn is executed in policy now - dataloader = iter(DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x)) + dataloader = iter(DataLoader(dataset, batch_size=cfg.policy.batch_size, shuffle=True, collate_fn=lambda x: x)) def _fetch(ctx: "OfflineRLContext"): """ diff --git a/ding/policy/dt.py b/ding/policy/dt.py index b559ee2dca..f769495a25 100644 --- a/ding/policy/dt.py +++ b/ding/policy/dt.py @@ -132,6 +132,8 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: self._optimizer.zero_grad() action_loss.backward() + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), self.clip_grad_norm_p) self._optimizer.step() self._scheduler.step() diff --git a/dizoo/atari/entry/atari_dt_main.py b/dizoo/atari/entry/atari_dt_main.py index f56ecebcaa..db549065d4 100644 --- a/dizoo/atari/entry/atari_dt_main.py +++ b/dizoo/atari/entry/atari_dt_main.py @@ -8,20 +8,21 @@ from ding.config import compile_config from ding.framework import task, ding_init from ding.framework.context import OfflineRLContext -from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_logger, termination_checker, offline_data_fetcher_from_mem_c +from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_logger, termination_checker, offline_data_fetcher_from_mem_c, offline_data_fetcher from ding.utils import set_pkg_seed, DDPContext, to_ddp_config from dizoo.atari.envs import AtariEnv from dizoo.atari.config.serial.pong.pong_dt_config import main_config, create_config +import torch.distributed as dist def main(): # If you don't have offline data, you need to prepare if first and set the data_path in config # For demostration, we also can train a RL policy (e.g. 
SAC) and collect some data logging.getLogger().setLevel(logging.INFO) - cmain_config = to_ddp_config(main_config) - cfg = compile_config(cmain_config, create_cfg=create_config, auto=True) - ding_init(cfg) with DDPContext(): + cmain_config = to_ddp_config(main_config) + cfg = compile_config(cmain_config, create_cfg=create_config, auto=True) + ding_init(cfg) with task.start(async_mode=False, ctx=OfflineRLContext()): evaluator_env = SubprocessEnvManagerV2( env_fn=[lambda: AllinObsWrapper(AtariEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)], @@ -41,10 +42,10 @@ def main(): # model.parallelize() policy = DTPolicy(cfg.policy, model=model) - task.use(termination_checker(max_train_iter=3e4)) task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) task.use(offline_data_fetcher_from_mem_c(cfg, dataset)) task.use(trainer(cfg, policy.learn_mode)) + task.use(termination_checker(max_train_iter=3e4)) task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.use(offline_logger(cfg.exp_name)) task.run() From 2d7c4062803a8b10b72500b15e4cccb22352f8ee Mon Sep 17 00:00:00 2001 From: luyudong Date: Thu, 10 Aug 2023 14:44:03 +0800 Subject: [PATCH 19/25] Reformat --- ding/entry/tests/test_serial_entry.py | 3 ++ ding/framework/middleware/data_fetcher.py | 7 +-- .../middleware/functional/data_processor.py | 5 +- ding/model/template/dt.py | 54 +++++++++++-------- ding/policy/dt.py | 12 +++-- ding/utils/data/dataset.py | 4 +- 6 files changed, 50 insertions(+), 35 deletions(-) diff --git a/ding/entry/tests/test_serial_entry.py b/ding/entry/tests/test_serial_entry.py index ce1b08feec..d146262856 100644 --- a/ding/entry/tests/test_serial_entry.py +++ b/ding/entry/tests/test_serial_entry.py @@ -683,8 +683,11 @@ def test_discrete_dt(): assert False, "pipeline fail" finally: os.popen('rm -rf cartpole cartpole_dt') + + test_discrete_dt() + @pytest.mark.platformtest @pytest.mark.unittest def test_td3_bc(): diff --git a/ding/framework/middleware/data_fetcher.py b/ding/framework/middleware/data_fetcher.py index fbceb69c53..ac13931082 100644 --- a/ding/framework/middleware/data_fetcher.py +++ b/ding/framework/middleware/data_fetcher.py @@ -23,6 +23,7 @@ def __new__(cls, *args, **kwargs): def __init__(self, cfg: EasyDict, dataset: Dataset): stream = torch.cuda.Stream() + def producer(queue, dataset, batch_size, device, event): torch.set_num_threads(4) nonlocal stream @@ -30,8 +31,8 @@ def producer(queue, dataset, batch_size, device, event): rank = get_rank() idx_list = np.random.permutation(len(dataset)) temp_idx_list = [] - for i in range(len(dataset)//(batch_size*num_gpu)): - temp_idx_list.extend(idx_list[i+rank*batch_size:i+(rank+1)*batch_size]) + for i in range(len(dataset) // (batch_size * num_gpu)): + temp_idx_list.extend(idx_list[i + rank * batch_size:i + (rank + 1) * batch_size]) idx_iter = iter(temp_idx_list) with torch.cuda.stream(stream): @@ -63,7 +64,7 @@ def producer(queue, dataset, batch_size, device, event): name='cuda_fetcher_producer' ) - def __call__(self,ctx: "OfflineRLContext"): + def __call__(self, ctx: "OfflineRLContext"): if not self.producer_thread.is_alive(): time.sleep(5) self.producer_thread.start() diff --git a/ding/framework/middleware/functional/data_processor.py b/ding/framework/middleware/functional/data_processor.py index 64ee724135..ba942967bc 100644 --- a/ding/framework/middleware/functional/data_processor.py +++ b/ding/framework/middleware/functional/data_processor.py @@ -211,9 +211,7 @@ def producer(queue, dataset, batch_size, device): queue = Queue(maxsize=50) 
device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu' producer_thread = Thread( - target=producer, - args=(queue, dataset, cfg.policy.batch_size, device), - name='cuda_fetcher_producer' + target=producer, args=(queue, dataset, cfg.policy.batch_size, device), name='cuda_fetcher_producer' ) def _fetch(ctx: "OfflineRLContext"): @@ -263,6 +261,7 @@ def _fetch(ctx: "OfflineRLContext"): ) ctx.train_data = next(dataloader) # TODO apply data update (e.g. priority) in offline setting when necessary + return _fetch diff --git a/ding/model/template/dt.py b/ding/model/template/dt.py index 1e7eee0de2..5335a77529 100644 --- a/ding/model/template/dt.py +++ b/ding/model/template/dt.py @@ -95,6 +95,7 @@ def forward(self, x): class DecisionTransformer(nn.Module): + def __init__( self, state_dim, @@ -121,7 +122,7 @@ def __init__( self.embed_ln = nn.LayerNorm(h_dim) self.embed_timestep = nn.Embedding(max_timestep, h_dim) self.drop = nn.Dropout(drop_p) - + self.pos_emb = nn.Parameter(torch.zeros(1, input_seq_len + 1, self.h_dim)) self.global_pos_emb = nn.Parameter(torch.zeros(1, max_timestep + 1, self.h_dim)) @@ -140,9 +141,11 @@ def __init__( # discrete actions self.embed_action = torch.nn.Embedding(act_dim, h_dim) use_action_tanh = False # False for discrete actions - self.predict_action = nn.Sequential(*([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else []))) + self.predict_action = nn.Sequential( + *([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else [])) + ) else: - blocks = [Block(h_dim, input_seq_len+1, n_heads, drop_p) for _ in range(n_blocks)] + blocks = [Block(h_dim, input_seq_len + 1, n_heads, drop_p) for _ in range(n_blocks)] self.state_encoder = state_encoder self.embed_rtg = nn.Sequential(nn.Linear(1, h_dim), nn.Tanh()) self.head = nn.Linear(h_dim, act_dim, bias=False) @@ -161,9 +164,8 @@ def forward(self, timesteps, states, actions, returns_to_go, tar=None): # stack rtg, states and actions and reshape sequence as # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2 ...) 
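Before the reformatted stacking code continues below, a short hedged sketch of the multi-GPU data path that the surrounding patches set up: each process derives its device from its rank and draws a disjoint slice of shuffled indices per step, so the per-GPU batch size (cfg.policy.batch_size after to_ddp_config) jointly covers the global batch. rank_local_batches and the toy sizes are assumptions for illustration, not DI-engine's implementation.

import numpy as np
import torch

def rank_local_batches(dataset_len, batch_size, rank, world_size, seed=0):
    # shuffle once, then hand each rank a disjoint contiguous slice of every global batch
    idx = np.random.default_rng(seed).permutation(dataset_len)
    global_batch = batch_size * world_size
    for start in range(0, dataset_len - global_batch + 1, global_batch):
        chunk = idx[start:start + global_batch]
        yield chunk[rank * batch_size:(rank + 1) * batch_size]

# rank / world_size would normally come from torch.distributed (get_rank / get_world_size)
rank, world_size = 0, 2
device = 'cuda:{}'.format(rank % torch.cuda.device_count()) if torch.cuda.is_available() else 'cpu'
for local_idx in rank_local_batches(dataset_len=1000, batch_size=64, rank=rank, world_size=world_size):
    pass  # dataset[i] for i in local_idx would be stacked and moved to `device` here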
- t_p = torch.stack( - (returns_embeddings, state_embeddings, action_embeddings), dim=1 - ).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim) + t_p = torch.stack((returns_embeddings, state_embeddings, action_embeddings), + dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim) h = self.embed_ln(t_p) # transformer and prediction h = self.transformer(h) @@ -183,20 +185,24 @@ def forward(self, timesteps, states, actions, returns_to_go, tar=None): state_embeddings = self.state_encoder( states.reshape(-1, 4, 84, 84).type(torch.float32).contiguous() ) # (batch * block_size, h_dim) - state_embeddings = state_embeddings.reshape( - B, T, self.h_dim - ) # (batch, block_size, h_dim) + state_embeddings = state_embeddings.reshape(B, T, self.h_dim) # (batch, block_size, h_dim) returns_embeddings = self.embed_rtg(returns_to_go.type(torch.float32)) action_embeddings = self.embed_action(actions.type(torch.long).squeeze(-1)) # (batch, block_size, h_dim) - token_embeddings = torch.zeros((B, T*3 - int(tar is None), self.h_dim), dtype=torch.float32, device=state_embeddings.device) - token_embeddings[:,::3,:] = returns_embeddings - token_embeddings[:,1::3,:] = state_embeddings - token_embeddings[:,2::3,:] = action_embeddings[:,-T + int(tar is None):,:] - - all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, B, dim=0) # batch_size, traj_length, h_dim + token_embeddings = torch.zeros( + (B, T * 3 - int(tar is None), self.h_dim), dtype=torch.float32, device=state_embeddings.device + ) + token_embeddings[:, ::3, :] = returns_embeddings + token_embeddings[:, 1::3, :] = state_embeddings + token_embeddings[:, 2::3, :] = action_embeddings[:, -T + int(tar is None):, :] + + all_global_pos_emb = torch.repeat_interleave( + self.global_pos_emb, B, dim=0 + ) # batch_size, traj_length, h_dim - position_embeddings = torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.h_dim, dim=-1)) + self.pos_emb[:, :token_embeddings.shape[1], :] + position_embeddings = torch.gather( + all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.h_dim, dim=-1) + ) + self.pos_emb[:, :token_embeddings.shape[1], :] t_p = token_embeddings + position_embeddings @@ -207,7 +213,7 @@ def forward(self, timesteps, states, actions, returns_to_go, tar=None): return_preds = None state_preds = None - action_preds = logits[:, 1::3, :] # only keep predictions from state_embeddings + action_preds = logits[:, 1::3, :] # only keep predictions from state_embeddings return state_preds, action_preds, return_preds @@ -227,7 +233,7 @@ def configure_optimizers(self, weight_decay, learning_rate, betas=(0.9, 0.95)): blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) for mn, m in self.named_modules(): for pn, p in m.named_parameters(): - fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name if pn.endswith('bias'): # all biases will not be decayed @@ -253,8 +259,14 @@ def configure_optimizers(self, weight_decay, learning_rate, betas=(0.9, 0.95)): # create the pytorch optimizer object optim_groups = [ - {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, - {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + { + "params": [param_dict[pn] for pn in sorted(list(decay))], + "weight_decay": weight_decay + }, + { + "params": [param_dict[pn] for pn in sorted(list(no_decay))], + "weight_decay": 0.0 + }, ] optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) - return 
optimizer \ No newline at end of file + return optimizer diff --git a/ding/policy/dt.py b/ding/policy/dt.py index f769495a25..5e59378584 100644 --- a/ding/policy/dt.py +++ b/ding/policy/dt.py @@ -235,8 +235,9 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: if self.t[i] <= self.context_len: if 'state_mean' not in self._cfg: - timesteps[i] = min(self.t[i], - self._cfg.model.max_timestep) * torch.ones((1, 1), dtype=torch.int64).to(self._device) + timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones( + (1, 1), dtype=torch.int64 + ).to(self._device) else: timesteps[i] = self.timesteps[i, :self.context_len] states[i] = self.states[i, :self.context_len] @@ -244,8 +245,9 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: rewards_to_go[i] = self.rewards_to_go[i, :self.context_len] else: if 'state_mean' not in self._cfg: - timesteps[i] = min(self.t[i], - self._cfg.model.max_timestep) * torch.ones((1, 1), dtype=torch.int64).to(self._device) + timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones( + (1, 1), dtype=torch.int64 + ).to(self._device) else: timesteps[i] = self.timesteps[i, self.t[i] - self.context_len + 1:self.t[i] + 1] states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1] @@ -267,7 +269,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: else: act = torch.argmax(logits, axis=1).unsqueeze(1) for i in data_id: - self.actions[i, self.t[i]] = act[i] # TODO: self.actions[i] should be a queue when exceed max_t + self.actions[i, self.t[i]] = act[i] # TODO: self.actions[i] should be a queue when exceed max_t self.t[i] += 1 if self._cuda: diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index ea6e67bc67..def9809a66 100644 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -151,9 +151,7 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: block_size = self.context_len done_idx = idx + block_size idx = done_idx - block_size - states = torch.as_tensor( - np.array(self._data['obs'][idx:done_idx]), dtype=torch.float32 - ).view(block_size, -1) + states = torch.as_tensor(np.array(self._data['obs'][idx:done_idx]), dtype=torch.float32).view(block_size, -1) actions = torch.as_tensor(self._data['action'][idx:done_idx], dtype=torch.long) rtgs = torch.as_tensor(self._data['reward'][idx:done_idx, 0], dtype=torch.float32) timesteps = torch.as_tensor(range(idx, done_idx), dtype=torch.int64) From 550aa5854690667f649bf9b72079907ceb10f362 Mon Sep 17 00:00:00 2001 From: luyudong Date: Mon, 14 Aug 2023 11:46:09 +0800 Subject: [PATCH 20/25] Reformat --- ding/entry/tests/test_serial_entry.py | 6 +- ding/example/dt.py | 3 +- ding/framework/middleware/__init__.py | 2 +- ding/framework/middleware/data_fetcher.py | 32 ++++++-- ding/model/template/dt.py | 9 ++- .../lunarlander_decision_transformer.py | 78 ------------------- 6 files changed, 37 insertions(+), 93 deletions(-) delete mode 100644 dizoo/box2d/lunarlander/config/lunarlander_decision_transformer.py diff --git a/ding/entry/tests/test_serial_entry.py b/ding/entry/tests/test_serial_entry.py index d146262856..72840c565e 100644 --- a/ding/entry/tests/test_serial_entry.py +++ b/ding/entry/tests/test_serial_entry.py @@ -657,7 +657,8 @@ def test_discrete_dt(): from ding.config import compile_config from ding.model.template.dt import DecisionTransformer from ding.policy import DTPolicy - from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher_from_mem_c, 
offline_logger, termination_checker + from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \ + offline_data_fetcher_from_mem_c, offline_logger, termination_checker config = compile_config(config[0], create_cfg=config[1], auto=True) with task.start(async_mode=False, ctx=OfflineRLContext()): evaluator_env = BaseEnvManagerV2( @@ -685,9 +686,6 @@ def test_discrete_dt(): os.popen('rm -rf cartpole cartpole_dt') -test_discrete_dt() - - @pytest.mark.platformtest @pytest.mark.unittest def test_td3_bc(): diff --git a/ding/example/dt.py b/ding/example/dt.py index 10884c3ec7..74ea1525de 100644 --- a/ding/example/dt.py +++ b/ding/example/dt.py @@ -8,7 +8,8 @@ from ding.config import compile_config from ding.framework import task, ding_init from ding.framework.context import OfflineRLContext -from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger, termination_checker, final_ctx_saver +from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \ + offline_data_fetcher, offline_logger, termination_checker, final_ctx_saver from ding.utils import set_pkg_seed from dizoo.box2d.lunarlander.envs.lunarlander_env import LunarLanderEnv from dizoo.box2d.lunarlander.config.lunarlander_dt_config import main_config, create_config diff --git a/ding/framework/middleware/__init__.py b/ding/framework/middleware/__init__.py index 74ee25950f..7f4e21a1df 100644 --- a/ding/framework/middleware/__init__.py +++ b/ding/framework/middleware/__init__.py @@ -4,4 +4,4 @@ from .ckpt_handler import CkptSaver from .distributer import ContextExchanger, ModelExchanger from .barrier import Barrier, BarrierRuntime -from .data_fetcher import offline_data_fetcher_from_mem_c \ No newline at end of file +from .data_fetcher import offline_data_fetcher_from_mem_c diff --git a/ding/framework/middleware/data_fetcher.py b/ding/framework/middleware/data_fetcher.py index ac13931082..08f7085663 100644 --- a/ding/framework/middleware/data_fetcher.py +++ b/ding/framework/middleware/data_fetcher.py @@ -22,11 +22,14 @@ def __new__(cls, *args, **kwargs): return super(offline_data_fetcher_from_mem_c, cls).__new__(cls) def __init__(self, cfg: EasyDict, dataset: Dataset): - stream = torch.cuda.Stream() + device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu' + if device is not 'cpu': + stream = torch.cuda.Stream() def producer(queue, dataset, batch_size, device, event): torch.set_num_threads(4) - nonlocal stream + if device is not 'cpu': + nonlocal stream num_gpu = dist.get_world_size() rank = get_rank() idx_list = np.random.permutation(len(dataset)) @@ -35,7 +38,27 @@ def producer(queue, dataset, batch_size, device, event): temp_idx_list.extend(idx_list[i + rank * batch_size:i + (rank + 1) * batch_size]) idx_iter = iter(temp_idx_list) - with torch.cuda.stream(stream): + if device is not 'cpu': + with torch.cuda.stream(stream): + while True: + if queue.full(): + time.sleep(0.1) + else: + data = [] + for _ in range(batch_size): + try: + data.append(dataset.__getitem__(next(idx_iter))) + except StopIteration: + del idx_iter + idx_list = np.random.permutation(len(dataset)) + idx_iter = iter(idx_list) + data.append(dataset.__getitem__(next(idx_iter))) + data = [[i[j] for i in data] for j in range(len(data[0]))] + data = [torch.stack(x).to(device) for x in data] + queue.put(data) + if event.is_set(): + break + else: while True: if queue.full(): time.sleep(0.1) @@ -50,14 +73,13 @@ def producer(queue, dataset, 
batch_size, device, event): idx_iter = iter(idx_list) data.append(dataset.__getitem__(next(idx_iter))) data = [[i[j] for i in data] for j in range(len(data[0]))] - data = [torch.stack(x).to(device) for x in data] + data = [torch.stack(x) for x in data] queue.put(data) if event.is_set(): break self.queue = Queue(maxsize=50) self.event = Event() - device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu' self.producer_thread = Thread( target=producer, args=(self.queue, dataset, cfg.policy.batch_size, device, self.event), diff --git a/ding/model/template/dt.py b/ding/model/template/dt.py index 5335a77529..df3ad64524 100644 --- a/ding/model/template/dt.py +++ b/ding/model/template/dt.py @@ -126,7 +126,7 @@ def __init__( self.pos_emb = nn.Parameter(torch.zeros(1, input_seq_len + 1, self.h_dim)) self.global_pos_emb = nn.Parameter(torch.zeros(1, max_timestep + 1, self.h_dim)) - if state_encoder == None: + if state_encoder is None: self.state_encoder = None blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)] self.embed_rtg = torch.nn.Linear(1, h_dim) @@ -154,7 +154,7 @@ def __init__( def forward(self, timesteps, states, actions, returns_to_go, tar=None): B, T = states.shape[0], states.shape[1] - if self.state_encoder == None: + if self.state_encoder is None: time_embeddings = self.embed_timestep(timesteps) # time embeddings are treated similar to positional embeddings @@ -254,8 +254,9 @@ def configure_optimizers(self, weight_decay, learning_rate, betas=(0.9, 0.95)): inter_params = decay & no_decay union_params = decay | no_decay assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) - assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ - % (str(param_dict.keys() - union_params), ) + assert len(param_dict.keys() - union_params) == 0,\ + "parameters %s were not separated into either decay/no_decay set!" 
\ + % (str(param_dict.keys() - union_params), ) # create the pytorch optimizer object optim_groups = [ diff --git a/dizoo/box2d/lunarlander/config/lunarlander_decision_transformer.py b/dizoo/box2d/lunarlander/config/lunarlander_decision_transformer.py deleted file mode 100644 index cd3b2884e6..0000000000 --- a/dizoo/box2d/lunarlander/config/lunarlander_decision_transformer.py +++ /dev/null @@ -1,78 +0,0 @@ -from easydict import EasyDict -import torch -from copy import deepcopy - -lunarlander_dt_config = dict( - exp_name='data_dt/lunarlander_dt_1000eps_rtgt300_meel1000_seed0_debug', - env=dict( - env_name='LunarLander-v2', - collector_env_num=8, - evaluator_env_num=8, - n_evaluator_episode=8, - stop_value=200, - ), - policy=dict( - stop_value=200, - device='cuda', - env_name='LunarLander-v2', - rtg_target=300, # max target reward_to_go - max_eval_ep_len=1000, # max len of one episode # TODO - num_eval_ep=10, # num of evaluation episodes - batch_size=64, # training batch size - wt_decay=1e-4, - warmup_steps=10000, - num_updates_per_iter=100, - context_len=20, # TODO - n_blocks=3, - embed_dim=128, - n_heads=1, - dropout_p=0.1, - log_dir='DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps', - model=dict( - state_dim=8, - act_dim=4, - n_blocks=3, - h_dim=128, - context_len=20, - n_heads=1, - drop_p=0.1, - continuous=False, # TODO - ), - discount_factor=0.999, - nstep=3, - learn=dict( - dataset_path='DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl', # TODO - learning_rate=1e-4, - target_update_freq=100, - kappa=1.0, - min_q_weight=4.0, - ), - collect=dict(unroll_len=1, ), - eval=dict(evaluator=dict(eval_freq=100, )), - other=dict( - eps=dict( - type='exp', - start=0.95, - end=0.1, - decay=10000, - ), replay_buffer=dict(replay_buffer_size=1000, ) - ), - ), -) -lunarlander_dt_config = EasyDict(lunarlander_dt_config) -main_config = lunarlander_dt_config -lunarlander_dt_create_config = dict( - env=dict( - type='lunarlander', - import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], - ), - env_manager=dict(type='subprocess'), - policy=dict(type='dt'), -) -lunarlander_dt_create_config = EasyDict(lunarlander_dt_create_config) -create_config = lunarlander_dt_create_config - -if __name__ == "__main__": - from ding.entry import serial_pipeline_dt, collect_demo_data, eval, serial_pipeline - config = deepcopy([main_config, create_config]) - serial_pipeline_dt(config, seed=0, max_train_iter=1000) From 1f8db9c7426c8b93d7c268bc3f9f91a313a72727 Mon Sep 17 00:00:00 2001 From: luyudong Date: Mon, 14 Aug 2023 19:03:04 +0800 Subject: [PATCH 21/25] Reformat --- ding/entry/tests/test_serial_entry.py | 2 +- ding/framework/middleware/data_fetcher.py | 6 +++--- ding/model/template/dt.py | 4 ++-- ding/utils/data/dataset.py | 2 +- .../cartpole_balance_dreamer_config.py | 12 +++++++----- .../config/cheetah_run/cheetah_run_dreamer_config.py | 8 ++++---- dizoo/dmc2gym/config/dmc2gym_dreamer_config.py | 12 +++++++----- .../config/walker_walk/walker_walk_dreamer_config.py | 6 +++--- 8 files changed, 28 insertions(+), 24 deletions(-) diff --git a/ding/entry/tests/test_serial_entry.py b/ding/entry/tests/test_serial_entry.py index 72840c565e..8edbd3ee49 100644 --- a/ding/entry/tests/test_serial_entry.py +++ b/ding/entry/tests/test_serial_entry.py @@ -658,7 +658,7 @@ def test_discrete_dt(): from ding.model.template.dt import DecisionTransformer from ding.policy import DTPolicy from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \ - offline_data_fetcher_from_mem_c, 
offline_logger, termination_checker + offline_data_fetcher_from_mem_c, offline_logger, termination_checker config = compile_config(config[0], create_cfg=config[1], auto=True) with task.start(async_mode=False, ctx=OfflineRLContext()): evaluator_env = BaseEnvManagerV2( diff --git a/ding/framework/middleware/data_fetcher.py b/ding/framework/middleware/data_fetcher.py index 08f7085663..5eb97065e8 100644 --- a/ding/framework/middleware/data_fetcher.py +++ b/ding/framework/middleware/data_fetcher.py @@ -23,12 +23,12 @@ def __new__(cls, *args, **kwargs): def __init__(self, cfg: EasyDict, dataset: Dataset): device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu' - if device is not 'cpu': + if device == 'cpu': stream = torch.cuda.Stream() def producer(queue, dataset, batch_size, device, event): torch.set_num_threads(4) - if device is not 'cpu': + if device == 'cpu': nonlocal stream num_gpu = dist.get_world_size() rank = get_rank() @@ -38,7 +38,7 @@ def producer(queue, dataset, batch_size, device, event): temp_idx_list.extend(idx_list[i + rank * batch_size:i + (rank + 1) * batch_size]) idx_iter = iter(temp_idx_list) - if device is not 'cpu': + if device == 'cpu': with torch.cuda.stream(stream): while True: if queue.full(): diff --git a/ding/model/template/dt.py b/ding/model/template/dt.py index df3ad64524..221dabf973 100644 --- a/ding/model/template/dt.py +++ b/ding/model/template/dt.py @@ -255,8 +255,8 @@ def configure_optimizers(self, weight_decay, learning_rate, betas=(0.9, 0.95)): union_params = decay | no_decay assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) assert len(param_dict.keys() - union_params) == 0,\ - "parameters %s were not separated into either decay/no_decay set!" \ - % (str(param_dict.keys() - union_params), ) + "parameters %s were not separated into either decay/no_decay set!" 
\ + % (str(param_dict.keys() - union_params), ) # create the pytorch optimizer object optim_groups = [ diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index def9809a66..79813fb0aa 100644 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -532,7 +532,7 @@ def __init__(self, cfg: dict) -> None: trajectories_to_load = cfg.dataset.trajectories_per_buffer while not done: states, ac, ret, next_states, next_action, next_reward, terminal, indices = \ - frb.sample_transition_batch( batch_size=1, indices=[i]) + frb.sample_transition_batch(batch_size=1, indices=[i]) states = states.transpose((0, 3, 1, 2))[0] # (1, 84, 84, 4) --> (4, 84, 84) obss.append(states) actions.append(ac[0]) diff --git a/dizoo/dmc2gym/config/cartpole_balance/cartpole_balance_dreamer_config.py b/dizoo/dmc2gym/config/cartpole_balance/cartpole_balance_dreamer_config.py index 1304c62352..66f7c7e2a4 100644 --- a/dizoo/dmc2gym/config/cartpole_balance/cartpole_balance_dreamer_config.py +++ b/dizoo/dmc2gym/config/cartpole_balance/cartpole_balance_dreamer_config.py @@ -26,11 +26,11 @@ policy=dict( cuda=cuda, # it is better to put random_collect_size in policy.other - random_collect_size=2500, + random_collect_size=2500, model=dict( obs_shape=(3, 64, 64), action_shape=1, - actor_dist = 'normal', + actor_dist='normal', ), learn=dict( lambda_=0.95, @@ -48,7 +48,7 @@ collect_dyn_sample=True, ), command=dict(), - eval=dict(evaluator=dict(eval_freq=5000, )), + eval=dict(evaluator=dict(eval_freq=5000, )), other=dict( # environment buffer replay_buffer=dict(replay_buffer_size=500000, periodic_thruput_seconds=60), @@ -56,7 +56,7 @@ ), world_model=dict( pretrain=100, - train_freq=2, + train_freq=2, cuda=cuda, model=dict( state_size=(3, 64, 64), # has to be specified @@ -88,4 +88,6 @@ cartpole_balance_create_config = EasyDict(cartpole_balance_create_config) if __name__ == '__main__': - serial_pipeline_dreamer((cartpole_balance_dreamer_config, cartpole_balance_create_config), seed=0, max_env_step=500000) \ No newline at end of file + serial_pipeline_dreamer( + (cartpole_balance_dreamer_config, cartpole_balance_create_config), seed=0, max_env_step=500000 + ) diff --git a/dizoo/dmc2gym/config/cheetah_run/cheetah_run_dreamer_config.py b/dizoo/dmc2gym/config/cheetah_run/cheetah_run_dreamer_config.py index 39c28bed93..32a43463e7 100644 --- a/dizoo/dmc2gym/config/cheetah_run/cheetah_run_dreamer_config.py +++ b/dizoo/dmc2gym/config/cheetah_run/cheetah_run_dreamer_config.py @@ -30,7 +30,7 @@ model=dict( obs_shape=(3, 64, 64), action_shape=6, - actor_dist = 'normal', + actor_dist='normal', ), learn=dict( lambda_=0.95, @@ -48,7 +48,7 @@ collect_dyn_sample=True, ), command=dict(), - eval=dict(evaluator=dict(eval_freq=5000, )), + eval=dict(evaluator=dict(eval_freq=5000, )), other=dict( # environment buffer replay_buffer=dict(replay_buffer_size=500000, periodic_thruput_seconds=60), @@ -56,7 +56,7 @@ ), world_model=dict( pretrain=100, - train_freq=2, + train_freq=2, cuda=cuda, model=dict( state_size=(3, 64, 64), # has to be specified @@ -88,4 +88,4 @@ cheetah_run_create_config = EasyDict(cheetah_run_create_config) if __name__ == '__main__': - serial_pipeline_dreamer((cheetah_run_dreamer_config, cheetah_run_create_config), seed=0, max_env_step=500000) \ No newline at end of file + serial_pipeline_dreamer((cheetah_run_dreamer_config, cheetah_run_create_config), seed=0, max_env_step=500000) diff --git a/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py index b548ebfdd2..de8e09e3d8 
100644 --- a/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py +++ b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py @@ -26,11 +26,11 @@ policy=dict( cuda=cuda, # it is better to put random_collect_size in policy.other - random_collect_size=2500, + random_collect_size=2500, model=dict( obs_shape=(3, 64, 64), action_shape=1, - actor_dist = 'normal', + actor_dist='normal', ), learn=dict( lambda_=0.95, @@ -48,7 +48,7 @@ collect_dyn_sample=True, ), command=dict(), - eval=dict(evaluator=dict(eval_freq=5000, )), + eval=dict(evaluator=dict(eval_freq=5000, )), other=dict( # environment buffer replay_buffer=dict(replay_buffer_size=500000, periodic_thruput_seconds=60), @@ -56,7 +56,7 @@ ), world_model=dict( pretrain=100, - train_freq=2, + train_freq=2, cuda=cuda, model=dict( state_size=(3, 64, 64), # has to be specified @@ -88,4 +88,6 @@ cartpole_balance_create_config = EasyDict(cartpole_balance_create_config) if __name__ == '__main__': - serial_pipeline_dreamer((cartpole_balance_dreamer_config, cartpole_balance_create_config), seed=0, max_env_step=1000000) \ No newline at end of file + serial_pipeline_dreamer( + (cartpole_balance_dreamer_config, cartpole_balance_create_config), seed=0, max_env_step=1000000 + ) diff --git a/dizoo/dmc2gym/config/walker_walk/walker_walk_dreamer_config.py b/dizoo/dmc2gym/config/walker_walk/walker_walk_dreamer_config.py index ee8f350b51..16e76eac39 100644 --- a/dizoo/dmc2gym/config/walker_walk/walker_walk_dreamer_config.py +++ b/dizoo/dmc2gym/config/walker_walk/walker_walk_dreamer_config.py @@ -30,7 +30,7 @@ model=dict( obs_shape=(3, 64, 64), action_shape=6, - actor_dist = 'normal', + actor_dist='normal', ), learn=dict( lambda_=0.95, @@ -56,7 +56,7 @@ ), world_model=dict( pretrain=100, - train_freq=2, + train_freq=2, cuda=cuda, model=dict( state_size=(3, 64, 64), # has to be specified @@ -88,4 +88,4 @@ walker_walk_create_config = EasyDict(walker_walk_create_config) if __name__ == '__main__': - serial_pipeline_dreamer((walker_walk_dreamer_config, walker_walk_create_config), seed=0, max_env_step=500000) \ No newline at end of file + serial_pipeline_dreamer((walker_walk_dreamer_config, walker_walk_create_config), seed=0, max_env_step=500000) From 477640f4b031475beb9ced4ff3c8dbe9ef6d17ae Mon Sep 17 00:00:00 2001 From: luyudong Date: Thu, 17 Aug 2023 17:32:09 +0800 Subject: [PATCH 22/25] Reformat --- .../middleware/functional/data_processor.py | 16 +++++----------- ding/policy/dt.py | 2 ++ ding/utils/data/dataset.py | 3 ++- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/ding/framework/middleware/functional/data_processor.py b/ding/framework/middleware/functional/data_processor.py index ba942967bc..7c17481e4e 100644 --- a/ding/framework/middleware/functional/data_processor.py +++ b/ding/framework/middleware/functional/data_processor.py @@ -238,7 +238,7 @@ def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable: - dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data. """ # collate_fn is executed in policy now - dataloader = iter(DataLoader(dataset, batch_size=cfg.policy.batch_size, shuffle=True, collate_fn=lambda x: x)) + dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x) def _fetch(ctx: "OfflineRLContext"): """ @@ -250,16 +250,10 @@ def _fetch(ctx: "OfflineRLContext"): Output of ctx: - train_data (:obj:`List[Tensor]`): The fetched data batch. 
""" - nonlocal dataloader - try: - ctx.train_data = next(dataloader) - except StopIteration: - ctx.train_epoch += 1 - del dataloader - dataloader = iter( - DataLoader(dataset, batch_size=cfg.policy.batch_size, shuffle=True, collate_fn=lambda x: x) - ) - ctx.train_data = next(dataloader) + while True: + for i, data in enumerate(dataloader): + ctx.train_data = data + yield # TODO apply data update (e.g. priority) in offline setting when necessary return _fetch diff --git a/ding/policy/dt.py b/ding/policy/dt.py index 5e59378584..9b5ec00083 100644 --- a/ding/policy/dt.py +++ b/ding/policy/dt.py @@ -93,6 +93,8 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: self._learn_model.train() timesteps, states, actions, returns_to_go, traj_mask = data + if actions.dtype is not torch.long: + actions = actions.to(torch.long) action_target = torch.clone(actions).detach().to(self._device) # The shape of `returns_to_go` may differ with different dataset (B x T or B x T x 1), diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index 79813fb0aa..2006eee3bb 100644 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -132,7 +132,8 @@ def __init__(self, cfg: dict) -> None: logging.warning("not found h5py package, please install it trough `pip install h5py ") sys.exit(1) data_path = cfg.policy.collect.get('data_path', None) - self.context_len = cfg.dataset.context_len + if 'dataset' in cfg: + self.context_len = cfg.dataset.context_len data = h5py.File(data_path, 'r') self._load_data(data) self._cal_statistics() From 22fc34fba4a1e05ad19fda96e1d2f1ff0c51cf9c Mon Sep 17 00:00:00 2001 From: luyudong Date: Fri, 18 Aug 2023 17:22:13 +0800 Subject: [PATCH 23/25] Reformat --- ding/framework/middleware/data_fetcher.py | 6 +++--- ding/utils/data/dataset.py | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/ding/framework/middleware/data_fetcher.py b/ding/framework/middleware/data_fetcher.py index 5eb97065e8..76c9354e72 100644 --- a/ding/framework/middleware/data_fetcher.py +++ b/ding/framework/middleware/data_fetcher.py @@ -23,12 +23,12 @@ def __new__(cls, *args, **kwargs): def __init__(self, cfg: EasyDict, dataset: Dataset): device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu' - if device == 'cpu': + if device != 'cpu': stream = torch.cuda.Stream() def producer(queue, dataset, batch_size, device, event): torch.set_num_threads(4) - if device == 'cpu': + if device != 'cpu': nonlocal stream num_gpu = dist.get_world_size() rank = get_rank() @@ -38,7 +38,7 @@ def producer(queue, dataset, batch_size, device, event): temp_idx_list.extend(idx_list[i + rank * batch_size:i + (rank + 1) * batch_size]) idx_iter = iter(temp_idx_list) - if device == 'cpu': + if device != 'cpu': with torch.cuda.stream(stream): while True: if queue.full(): diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index 2006eee3bb..ccd6ba08f4 100644 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -134,6 +134,8 @@ def __init__(self, cfg: dict) -> None: data_path = cfg.policy.collect.get('data_path', None) if 'dataset' in cfg: self.context_len = cfg.dataset.context_len + else: + self.context_len = 0 data = h5py.File(data_path, 'r') self._load_data(data) self._cal_statistics() From 4f349b4c345fc2172f3b5e627def0960b03e3adb Mon Sep 17 00:00:00 2001 From: luyudong Date: Fri, 18 Aug 2023 19:45:43 +0800 Subject: [PATCH 24/25] Reformat --- ding/framework/middleware/data_fetcher.py | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/ding/framework/middleware/data_fetcher.py b/ding/framework/middleware/data_fetcher.py index 76c9354e72..83595df4f0 100644 --- a/ding/framework/middleware/data_fetcher.py +++ b/ding/framework/middleware/data_fetcher.py @@ -30,11 +30,11 @@ def producer(queue, dataset, batch_size, device, event): torch.set_num_threads(4) if device != 'cpu': nonlocal stream - num_gpu = dist.get_world_size() + sbatch_size = batch_size * dist.get_world_size() rank = get_rank() idx_list = np.random.permutation(len(dataset)) temp_idx_list = [] - for i in range(len(dataset) // (batch_size * num_gpu)): + for i in range(len(dataset) // sbatch_size): temp_idx_list.extend(idx_list[i + rank * batch_size:i + (rank + 1) * batch_size]) idx_iter = iter(temp_idx_list) From 0af7684abc3dac3fffde04ec2ba0ea93f1cc8633 Mon Sep 17 00:00:00 2001 From: luyudong Date: Fri, 18 Aug 2023 20:21:18 +0800 Subject: [PATCH 25/25] Reformat --- ding/model/template/dt.py | 2 +- ding/policy/dt.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ding/model/template/dt.py b/ding/model/template/dt.py index 221dabf973..da1e72f7d6 100644 --- a/ding/model/template/dt.py +++ b/ding/model/template/dt.py @@ -8,7 +8,7 @@ and its corresponding notebook: https://colab.research.google.com/drive/1NUBqyboDcGte5qAJKOl8gaJC28V_73Iv?usp=sharing -** the above colab notebook has a bug while applying masked_fill +** the above colab notebook has a bug while applying masked_fill which is fixed in the following code """ diff --git a/ding/policy/dt.py b/ding/policy/dt.py index 9b5ec00083..771f383bc5 100644 --- a/ding/policy/dt.py +++ b/ding/policy/dt.py @@ -37,7 +37,7 @@ class DTPolicy(Policy): batch_size=64, # training batch size wt_decay=1e-4, # decay weight in optimizer warmup_steps=10000, # steps for learning rate warmup - context_len=20, # length of transformer input + context_len=20, # length of transformer input learning_rate=1e-4, )