diff --git a/ding/entry/__init__.py b/ding/entry/__init__.py index 1e90351ee4..87d06d2d68 100644 --- a/ding/entry/__init__.py +++ b/ding/entry/__init__.py @@ -27,3 +27,4 @@ from .application_entry_drex_collect_data import drex_collecting_data from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream from .serial_entry_bco import serial_pipeline_bco +from .serial_entry_pc_mcts import serial_pipeline_pc_mcts diff --git a/ding/entry/serial_entry_pc_mcts.py b/ding/entry/serial_entry_pc_mcts.py new file mode 100644 index 0000000000..671d45c76b --- /dev/null +++ b/ding/entry/serial_entry_pc_mcts.py @@ -0,0 +1,173 @@ +from typing import Union, Optional, Tuple +import os +import torch +from functools import partial +from tensorboardX import SummaryWriter +from copy import deepcopy +from torch.utils.data import DataLoader, Dataset +import pickle + +from ding.envs import get_vec_env_setting, create_env_manager +from ding.worker import BaseLearner, InteractionSerialEvaluator +from ding.config import read_config, compile_config +from ding.policy import create_policy +from ding.utils import set_pkg_seed + + +class MCTSPCDataset(Dataset): + + def __init__(self, data_dic, seq_len=4, hidden_state_noise=0): + self.observations = data_dic['obs'] + self.actions = data_dic['actions'] + self.hidden_states = data_dic['hidden_state'] + self.seq_len = seq_len + self.length = len(self.observations) - seq_len - 1 + self.hidden_state_noise = hidden_state_noise + + def __getitem__(self, idx): + """ + Assume the trajectory is: o1, h2, h3, h4 + """ + hidden_states = list(reversed(self.hidden_states[idx + 1:idx + self.seq_len + 1])) + actions = torch.tensor(list(reversed(self.actions[idx: idx + self.seq_len]))) + if self.hidden_state_noise > 0: + for i in range(len(hidden_states)): + hidden_states[i] += self.hidden_state_noise * torch.randn_like(hidden_states[i]) + return { + 'obs': self.observations[idx], + 'hidden_states': hidden_states, + 'action': actions + } + + def 
__len__(self): + return self.length + + +def load_mcts_datasets(path, seq_len, batch_size=32, hidden_state_noise=0): + with open(path, 'rb') as f: + dic = pickle.load(f) + tot_len = len(dic['obs']) + train_dic = {k: v[:-tot_len // 10] for k, v in dic.items()} + test_dic = {k: v[-tot_len // 10:] for k, v in dic.items()} + return DataLoader(MCTSPCDataset(train_dic, seq_len=seq_len, hidden_state_noise=hidden_state_noise), shuffle=True + , batch_size=batch_size), \ + DataLoader(MCTSPCDataset(test_dic, seq_len=seq_len, hidden_state_noise=hidden_state_noise), shuffle=True, + batch_size=batch_size) + + +def serial_pipeline_pc_mcts( + input_cfg: Union[str, Tuple[dict, dict]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + max_iter=int(1e6), +) -> Union['Policy', bool]: # noqa + r""" + Overview: + Serial pipeline entry of procedure cloning with MCTS. + Arguments: + - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ + ``str`` type means config file path. \ + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + Returns: + - policy (:obj:`Policy`): Converged policy. 
def serial_pipeline_pc_mcts(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        model: Optional[torch.nn.Module] = None,
        max_iter=int(1e6),
) -> Union['Policy', bool]:  # noqa
    r"""
    Overview:
        Serial pipeline entry of procedure cloning with MCTS.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - max_iter (:obj:`int`): Max number of learner iterations before stopping.
    Returns:
        - policy (:obj:`Policy`): Converged policy.
        - convergence (:obj:`bool`): Whether IL training has converged.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
    cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg)

    # Env, Policy
    env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
    # Random seed
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'eval'])

    # Main components
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    dataloader, test_dataloader = load_mcts_datasets(
        cfg.policy.expert_data_path,
        seq_len=cfg.policy.seq_len,
        batch_size=cfg.policy.learn.batch_size,
        hidden_state_noise=cfg.policy.learn.hidden_state_noise
    )
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )

    # ==========
    # Main loop
    # ==========
    learner.call_hook('before_run')
    stop = False
    epoch_per_test = 10
    lr_decay_interval = 69  # empirically chosen step-decay schedule
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in range(cfg.policy.learn.train_epoch):
        # train
        for i, train_data in enumerate(dataloader):
            # Dataset stores HWC uint8 frames; convert to normalized CHW floats.
            train_data['obs'] = train_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.
            learner.train(train_data)
            if learner.train_iter >= max_iter:
                stop = True
                break
        # FIX: guard ``epoch > 0`` so the learning rate is not divided by 10
        # immediately at the very first epoch (0 % 69 == 0).
        if epoch > 0 and epoch % lr_decay_interval == 0:
            policy._optimizer.param_groups[0]['lr'] /= 10
        if stop:
            break

        if epoch % epoch_per_test == 0:
            # Periodic last-action accuracy/loss on the held-out split.
            losses = []
            acces = []
            for _, test_data in enumerate(test_dataloader):
                logits = policy._model.forward_eval(test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.)
                loss = criterion(logits, test_data['action'][:, -1].cuda()).item()
                preds = torch.argmax(logits, dim=-1)
                acc = torch.sum((preds == test_data['action'][:, -1].cuda())).item() / preds.shape[0]
                losses.append(loss)
                acces.append(acc)
            tb_logger.add_scalar('learner_iter/recurrent_test_loss', sum(losses) / len(losses), learner.train_iter)
            tb_logger.add_scalar('learner_iter/recurrent_test_acc', sum(acces) / len(acces), learner.train_iter)

            # Same metrics on the training split (to monitor over-fitting).
            losses = []
            acces = []
            for _, test_data in enumerate(dataloader):
                logits = policy._model.forward_eval(test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.)
                loss = criterion(logits, test_data['action'][:, -1].cuda()).item()
                preds = torch.argmax(logits, dim=-1)
                acc = torch.sum((preds == test_data['action'][:, -1].cuda())).item() / preds.shape[0]
                losses.append(loss)
                acces.append(acc)
            tb_logger.add_scalar('learner_iter/recurrent_train_loss', sum(losses) / len(losses), learner.train_iter)
            tb_logger.add_scalar('learner_iter/recurrent_train_acc', sum(acces) / len(acces), learner.train_iter)

            # NOTE(review): in the collapsed source the indentation of this call is
            # ambiguous; periodic evaluation (every ``epoch_per_test`` epochs) is
            # assumed here — confirm against the original file.
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter)
    learner.call_hook('after_run')
    print('final reward is: {}'.format(reward))
    return policy, stop
class PreNorm(nn.Module):
    """Apply LayerNorm to the input before calling the wrapped module ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    """Position-wise two-layer MLP (Linear-GELU-Linear) with dropout."""

    def __init__(self, dim, hidden_dim, drop_p=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(drop_p)
        )

    def forward(self, x):
        return self.net(x)


class Transformer(nn.Module):
    """Pre-norm causal transformer encoder with residual connections."""

    def __init__(self, n_layer: int, n_attn: int, n_head: int, drop_p: float, max_T: int, n_ffn: int):
        super().__init__()
        self.layers = nn.ModuleList([])
        assert n_attn % n_head == 0
        dim_head = n_attn // n_head
        for _ in range(n_layer):
            self.layers.append(nn.ModuleList([
                PreNorm(n_attn, Attention(n_attn, dim_head, n_attn, n_head, nn.Dropout(drop_p))),
                PreNorm(n_attn, FeedForward(n_attn, n_ffn, drop_p=drop_p))
            ]))
        # FIX: the causal mask is constant state, not a learnable weight, so register
        # it as a buffer instead of a (requires_grad=False) ``nn.Parameter``. It keeps
        # the same ``state_dict`` key ('mask') and is still moved by ``.to()``/``.cuda()``,
        # but no longer appears in ``parameters()`` handed to the optimizer.
        self.register_buffer(
            'mask', torch.tril(torch.ones((max_T, max_T), dtype=torch.bool)).view(1, 1, max_T, max_T)
        )

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x, mask=self.mask) + x
            x = ff(x) + x
        return x


@MODEL_REGISTRY.register('pc_mcts')
class ProcedureCloningMCTS(nn.Module):
    """
    Procedure-cloning model conditioned on MCTS hidden states: one conv encoder
    for the observation, one for the hidden states, and a causal transformer
    that predicts hidden-state embeddings and per-position actions.
    """

    def __init__(
            self,
            obs_shape: SequenceType,
            hidden_shape: SequenceType,
            action_dim: int,
            seq_len: int,
            cnn_hidden_list: SequenceType = [128, 256, 512],
            cnn_kernel_size: SequenceType = [8, 4, 3],
            cnn_stride: SequenceType = [4, 2, 1],
            cnn_padding: Optional[SequenceType] = [0, 0, 0],
            hidden_state_cnn_hidden_list: SequenceType = [128, 256, 512],
            hidden_state_cnn_kernel_size: SequenceType = [3, 3, 3],
            hidden_state_cnn_stride: SequenceType = [1, 1, 1],
            hidden_state_cnn_padding: Optional[SequenceType] = [1, 1, 1],
            cnn_activation: Optional[nn.Module] = nn.ReLU(),
            att_heads: int = 8,
            att_hidden: int = 512,
            n_att_layer: int = 4,
            ffn_hidden: int = 512,
            drop_p: float = 0.,
    ) -> None:
        super().__init__()
        self.obs_shape = obs_shape
        self.hidden_shape = hidden_shape
        self.seq_len = seq_len
        # One observation token plus ``seq_len`` hidden-state tokens.
        max_T = seq_len + 1

        # Conv Encoders for the observation and the hidden states.
        self.embed_state = ConvEncoder(
            obs_shape, cnn_hidden_list, cnn_activation, cnn_kernel_size, cnn_stride, cnn_padding
        )
        self.embed_hidden = ConvEncoder(
            hidden_shape, hidden_state_cnn_hidden_list, cnn_activation, hidden_state_cnn_kernel_size,
            hidden_state_cnn_stride, hidden_state_cnn_padding
        )

        self.cnn_hidden_list = cnn_hidden_list

        assert cnn_hidden_list[-1] == att_hidden
        self.transformer = Transformer(
            n_layer=n_att_layer, n_attn=att_hidden, n_head=att_heads, drop_p=drop_p, max_T=max_T, n_ffn=ffn_hidden
        )

        self.predict_hidden_state = torch.nn.Linear(cnn_hidden_list[-1], cnn_hidden_list[-1])
        self.predict_action = torch.nn.Linear(cnn_hidden_list[-1], action_dim)

    def _compute_embeddings(self, states: torch.Tensor, hidden_states: torch.Tensor):
        """Encode observation to (B, 1, h_dim) and hidden states to (B, T, h_dim); T may be 0."""
        B, T, *_ = hidden_states.shape

        # shape: (B, 1, h_dim)
        state_embeddings = self.embed_state(states).reshape(B, 1, self.cnn_hidden_list[-1])
        # shape: (B, T, h_dim)
        if T > 0:
            hidden_state_embeddings = self.embed_hidden(hidden_states.reshape(B * T, *hidden_states.shape[2:])) \
                .reshape(B, T, self.cnn_hidden_list[-1])
        else:
            hidden_state_embeddings = None
        return state_embeddings, hidden_state_embeddings

    def _compute_transformer(self, h):
        """Run the transformer and split its outputs into hidden-state and action heads."""
        B, T, *_ = h.shape
        h = self.transformer(h)
        h = h.reshape(B, T, self.cnn_hidden_list[-1])

        # Position i predicts the hidden state / action for position i + 1.
        hidden_state_preds = self.predict_hidden_state(h[:, 0:-1, ...])
        action_preds = self.predict_action(h[:, 1:, :])
        return hidden_state_preds, action_preds

    def forward(self, states: torch.Tensor, hidden_states: torch.Tensor) \
            -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # State is current observation.
        # Hidden states are a sequence including [L, R, ...].
        # The shape of state and hidden state may be different.
        B, T, *_ = hidden_states.shape
        assert T == self.seq_len
        state_embeddings, hidden_state_embeddings = self._compute_embeddings(states, hidden_states)
        if hidden_state_embeddings is not None:
            h = torch.cat((state_embeddings, hidden_state_embeddings), dim=1)
        else:
            h = state_embeddings
        hidden_state_preds, action_preds = self._compute_transformer(h)

        # Target embeddings are detached: the embedding branch is not trained
        # through the hidden-state regression loss.
        return hidden_state_preds, action_preds, hidden_state_embeddings.detach() \
            if hidden_state_embeddings is not None else None

    def forward_eval(self, states: torch.Tensor) -> torch.Tensor:
        """Autoregressively roll out predicted hidden-state embeddings from the
        observation alone, then return the final-position action logits (B, action_dim)."""
        with torch.no_grad():
            batch_size = states.shape[0]
            hidden_states = torch.zeros(batch_size, self.seq_len, *self.hidden_shape, dtype=states.dtype).to(
                states.device)
            # embedding_mask[0, i, 0] == 1 once position i has been generated.
            embedding_mask = torch.zeros(1, self.seq_len, 1).to(states.device)

            state_embeddings, hidden_state_embeddings = self._compute_embeddings(states, hidden_states)

            for i in range(self.seq_len):
                h = torch.cat((state_embeddings, hidden_state_embeddings * embedding_mask), dim=1)
                hidden_state_embeddings, action_pred = self._compute_transformer(h)
                embedding_mask[0, i, 0] = 1

            if self.seq_len > 0:
                h = torch.cat((state_embeddings, hidden_state_embeddings * embedding_mask), dim=1)
            else:
                h = state_embeddings
            hidden_state_embeddings, action_pred = self._compute_transformer(h)

            return action_pred[:, -1, :]

    def test_forward_eval(self, states: torch.Tensor, hidden_states: torch.Tensor) -> Tuple:
        # Action pred in this function is supposed to be identical to the training phase.
        with torch.no_grad():
            embedding_mask = torch.zeros(1, self.seq_len, 1).to(states.device)
            state_embeddings, hidden_state_embeddings = self._compute_embeddings(states, hidden_states)

            for i in range(self.seq_len):
                h = torch.cat((state_embeddings, hidden_state_embeddings * embedding_mask), dim=1)
                _, action_pred = self._compute_transformer(h)
                embedding_mask[0, i, 0] = 1

            if self.seq_len > 0:
                h = torch.cat((state_embeddings, hidden_state_embeddings * embedding_mask), dim=1)
            else:
                h = state_embeddings
            pred_hidden_state_embeddings, action_pred = self._compute_transformer(h)

            return action_pred, pred_hidden_state_embeddings, hidden_state_embeddings
B = 4
T = 15
obs_shape = (3, 64, 64)
hidden_shape = (64, 9, 9)
action_dim = 6
obs_embeddings = 512


@pytest.mark.unittest
def test_procedure_cloning():
    """Shape and causality checks for ``ProcedureCloningMCTS``."""
    inputs = {
        'states': torch.randn(B, *obs_shape),
        'hidden_states': torch.randn(B, T, *hidden_shape),
        'actions': torch.randn(B, action_dim)
    }
    model = ProcedureCloningMCTS(obs_shape=obs_shape, hidden_shape=hidden_shape, seq_len=T, action_dim=action_dim)

    print(model)

    hidden_state_preds, action_preds, target_hidden_state = model(inputs['states'], inputs['hidden_states'])
    assert hidden_state_preds.shape == (B, T, obs_embeddings)
    # FIX: ``forward`` predicts one action per transformer position
    # (``predict_action(h[:, 1:, :])`` over T + 1 tokens), i.e. a
    # (B, T, action_dim) tensor — not a single (B, action_dim) action.
    assert action_preds.shape == (B, T, action_dim)

    action_eval = model.forward_eval(inputs['states'])
    assert action_eval.shape == (B, action_dim)

    # Causal mask: the first hidden-state prediction must depend only on the
    # observation token, so zeroing the hidden states cannot change it.
    hidden_state_preds_new, _, _ = model(inputs['states'], torch.zeros_like(inputs['hidden_states']))
    assert torch.sum(torch.abs(hidden_state_preds_new[:, 0, :] - hidden_state_preds[:, 0, :])).item() < 1e-9
class BatchCELoss(nn.Module):
    """
    Cross-entropy loss over a (B, T, num_class) sequence of action logits.

    Modes:
        - ``seq=False``: only the last timestep is supervised.
        - ``seq=True, mask=False``: every timestep is supervised independently.
        - ``seq=True, mask=True``: timestep ``i`` contributes only for samples
          whose predictions at ALL earlier timesteps were correct (prefix-masked
          CE, mimicking autoregressive rollout accuracy).

    ``forward`` returns ``(loss, action_number)`` where ``action_number`` is the
    (average) number of supervised timesteps per sample.
    """

    def __init__(self, seq, mask):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.nce = nn.CrossEntropyLoss(reduction='none')  # per-sample loss, needed for masking
        self.mask = mask
        self.seq = seq
        self.masked_ratio = 0

    def forward(self, pred_y, target_y):
        if not self.seq:
            return self.ce(pred_y[:, -1, :], target_y[:, -1]), 1
        if not self.mask:
            losses = 0
            for i in range(target_y.shape[1]):
                losses += self.ce(pred_y[:, i, :], target_y[:, i])
            return losses, target_y.shape[1]
        else:
            eqs = []
            losses = 0
            cnt = 0

            # The first timestep is always supervised for every sample.
            cur_loss = self.nce(pred_y[:, 0, :], target_y[:, 0])
            losses += torch.sum(cur_loss)
            cnt += target_y.shape[0]
            eqs.append((torch.argmax(pred_y[:, 0, :], dim=-1) == target_y[:, 0]))

            for i in range(1, target_y.shape[1]):
                # Supervise step i only where all previous predictions were correct.
                cur_loss = self.nce(pred_y[:, i, :], target_y[:, i])
                losses += torch.sum(cur_loss * eqs[-1])
                cnt += torch.sum(eqs[-1])
                # Update the cumulative correct-prefix mask.
                eqs.append((torch.argmax(pred_y[:, i, :], dim=-1) == target_y[:, i]))
                eqs[-1] = eqs[-1] & eqs[-2]
            return losses / cnt, cnt / target_y.shape[0]
@POLICY_REGISTRY.register('pc_mcts')
class ProcedureCloningPolicyMCTS(Policy):
    """
    Behaviour-cloning policy that imitates MCTS expert trajectories: the model
    predicts the expert action together with the sequence of MCTS hidden-state
    embeddings (see ``ProcedureCloningMCTS``).
    """

    config = dict(
        type='pc_mcts',
        cuda=True,
        on_policy=False,
        continuous=False,
        learn=dict(
            multi_gpu=False,
            update_per_collect=1,
            batch_size=32,
            learning_rate=1e-5,
            lr_decay=False,
            decay_epoch=30,
            decay_rate=0.1,
            warmup_lr=1e-4,
            warmup_epoch=3,
            optimizer='SGD',
            momentum=0.9,
            weight_decay=1e-4,
            ce_label_smooth=False,
            show_accuracy=False,
            tanh_mask=False,  # if actions always converge to 1 or -1, use this.
        ),
        collect=dict(
            unroll_len=1,
            noise=False,
            noise_sigma=0.2,
            noise_range=dict(
                min=-0.5,
                max=0.5,
            ),
        ),
        eval=dict(),
        other=dict(replay_buffer=dict(replay_buffer_size=10000, )),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        return 'pc_mcts', ['ding.model.template.procedure_cloning']

    def _init_learn(self):
        """Build the optimizer (SGD/Adam(W)), optional warmup+step-decay scheduler and losses."""
        assert self._cfg.learn.optimizer in ['SGD', 'Adam']
        if self._cfg.learn.optimizer == 'SGD':
            self._optimizer = SGD(
                self._model.parameters(),
                lr=self._cfg.learn.learning_rate,
                weight_decay=self._cfg.learn.weight_decay,
                momentum=self._cfg.learn.momentum
            )
        elif self._cfg.learn.optimizer == 'Adam':
            # AdamW (decoupled weight decay) when weight decay is requested.
            if self._cfg.learn.weight_decay is None:
                self._optimizer = Adam(
                    self._model.parameters(),
                    lr=self._cfg.learn.learning_rate,
                )
            else:
                self._optimizer = AdamW(
                    self._model.parameters(),
                    lr=self._cfg.learn.learning_rate,
                    weight_decay=self._cfg.learn.weight_decay
                )
        if self._cfg.learn.lr_decay:

            def lr_scheduler_fn(epoch):
                # Constant warmup lr, then step decay every ``decay_epoch`` epochs.
                if epoch <= self._cfg.learn.warmup_epoch:
                    return self._cfg.learn.warmup_lr / self._cfg.learn.learning_rate
                else:
                    ratio = (epoch - self._cfg.learn.warmup_epoch) // self._cfg.learn.decay_epoch
                    return math.pow(self._cfg.learn.decay_rate, ratio)

            self._lr_scheduler = LambdaLR(self._optimizer, lr_scheduler_fn)
        self._timer = EasyTimer(cuda=True)
        self._learn_model = model_wrap(self._model, 'base')
        self._learn_model.reset()

        # self._hidden_state_loss = nn.MSELoss()
        self._hidden_state_loss = nn.L1Loss()
        self._action_loss = BatchCELoss(seq=self._cfg.seq_action, mask=self._cfg.mask_seq_action)

    def _forward_learn(self, data):
        """One gradient step on a batch with keys 'obs', 'hidden_states', 'action'."""
        if self._cuda:
            data = to_device(data, self._device)
        self._learn_model.train()
        with self._timer:
            obs, hidden_states, action = data['obs'], data['hidden_states'], data['action']
            zero_hidden_len = len(hidden_states) == 0
            if not zero_hidden_len:
                hidden_states = torch.stack(hidden_states, dim=1).float()
            else:
                hidden_states = to_device(torch.empty(obs.shape[0], 0, *self._learn_model.hidden_shape), self._device)
            pred_hidden_states, pred_action, target_hidden_states = self._learn_model.forward(obs, hidden_states)
            if zero_hidden_len:
                # FIX: create the zero loss on the model-output device; a CPU
                # ``torch.tensor(0.)`` cannot be added to a CUDA ``action_loss``.
                hidden_state_loss = pred_action.new_zeros(())
            else:
                hidden_state_loss = self._hidden_state_loss(pred_hidden_states, target_hidden_states)
            action_loss, action_number = self._action_loss(pred_action, action)
            loss = hidden_state_loss + action_loss
        forward_time = self._timer.value

        with self._timer:
            self._optimizer.zero_grad()
            loss.backward()
        backward_time = self._timer.value

        with self._timer:
            if self._cfg.learn.multi_gpu:
                self.sync_gradients(self._learn_model)
        sync_time = self._timer.value

        self._optimizer.step()
        cur_lr = [param_group['lr'] for param_group in self._optimizer.param_groups]
        cur_lr = sum(cur_lr) / len(cur_lr)
        return {
            'cur_lr': cur_lr,
            'total_loss': loss.item(),
            'hidden_state_loss': hidden_state_loss.item(),
            'action_loss': action_loss.item(),
            'forward_time': forward_time,
            'backward_time': backward_time,
            'action_number': action_number,
            'sync_time': sync_time,
        }

    def _monitor_vars_learn(self):
        return [
            'cur_lr', 'total_loss', 'hidden_state_loss', 'action_loss', 'forward_time', 'backward_time', 'sync_time',
            'action_number'
        ]

    def _init_eval(self):
        self._eval_model = model_wrap(self._model, wrapper_name='base')
        self._eval_model.reset()

    def _forward_eval(self, data):
        """Greedy action selection from observation-only autoregressive rollout."""
        data_id = list(data.keys())
        values = list(data.values())
        data = [{'obs': v['observation']} for v in values]
        data = default_collate(data)

        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward_eval(data['obs'].permute(0, 3, 1, 2) / 255.)
            output = torch.argmax(output, dim=-1)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = {'action': output}
        output = default_decollate(output)
        # TODO(review): decollate leaves 0-dim tensors here, so ``.item()`` is
        # needed to hand plain ints to the env — investigate the root cause.
        output = [{'action': o['action'].item()} for o in output]
        res = {i: d for i, d in zip(data_id, output)}
        return res

    def _init_collect(self) -> None:
        pass

    def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
        pass

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data.
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'action': model_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return EasyDict(transition)

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Overview:
            For a given trajectory (transitions, a list of transition) data, process it into a list of samples that \
            can be used for training directly. A train sample can be a processed transition (DQN with nstep TD) \
            or some continuous transitions (DRQN).
        Arguments:
            - data (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transitions), each element is the \
                same format as the return value of ``self._process_transition`` method.
        Returns:
            - samples (:obj:`dict`): The list of training samples.

        .. note::
            We will vectorize ``process_transition`` and ``get_train_sample`` methods in the following release \
            version. The user can customize this data processing procedure by overriding these two methods and \
            the collector itself.
        """
        data = get_nstep_return_data(data, 1, 1)
        return get_train_sample(data, self._unroll_len)
# --- dizoo/atari/config/serial/pong/pong_pc_mcts_config.py ---
from easydict import EasyDict

seq_len = 4
# FIX: renamed from the copy-pasted ``qbert_pc_mcts_config`` — this file is the
# Pong config (exp_name/env_name are Pong). ``main_config``/``create_config``
# remain the public entry points.
pong_pc_mcts_config = dict(
    exp_name='pong_pc_mcts_seed0',
    env=dict(
        manager=dict(
            episode_num=float('inf'),
            max_retry=5,
            step_timeout=None,
            auto_reset=True,
            reset_timeout=None,
            retry_type='reset',
            retry_waiting_time=0.1,
            shared_memory=False,
            copy_on_get=True,
            context='fork',
            wait_num=float('inf'),
            step_wait_timeout=None,
            connect_timeout=60,
            reset_inplace=False,
            cfg_type='SyncSubprocessEnvManagerDict',
            type='subprocess',
        ),
        dqn_expert_data=False,
        cfg_type='AtariLightZeroEnvDict',
        collector_env_num=8,
        evaluator_env_num=3,
        n_evaluator_episode=3,
        env_name='PongNoFrameskip-v4',
        stop_value=20,
        collect_max_episode_steps=10800,
        eval_max_episode_steps=108000,
        frame_skip=4,
        obs_shape=[12, 96, 96],
        episode_life=True,
        gray_scale=False,
        cvt_string=False,
        game_wrapper=True,
    ),
    policy=dict(
        cuda=True,
        expert_data_path='pong-v4-expert.pkl',
        seq_len=seq_len,
        seq_action=True,
        mask_seq_action=False,
        model=dict(
            obs_shape=[3, 96, 96],
            hidden_shape=[64, 6, 6],
            action_dim=6,
            seq_len=seq_len,
        ),
        learn=dict(
            batch_size=32,
            learning_rate=5e-4,
            learner=dict(hook=dict(save_ckpt_after_iter=1000)),
            train_epoch=100,
            hidden_state_noise=0,
        ),
        eval=dict(evaluator=dict(eval_freq=40, ))
    ),
)
pong_pc_mcts_config = EasyDict(pong_pc_mcts_config)
main_config = pong_pc_mcts_config
pong_pc_mcts_create_config = dict(
    env=dict(
        type='atari_lightzero',
        import_names=['zoo.atari.envs.atari_lightzero_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='pc_mcts'),
)
pong_pc_mcts_create_config = EasyDict(pong_pc_mcts_create_config)
create_config = pong_pc_mcts_create_config

if __name__ == "__main__":
    from ding.entry import serial_pipeline_pc_mcts
    serial_pipeline_pc_mcts([main_config, create_config], seed=0)

# --- dizoo/atari/config/serial/qbert/qbert_pc_mcts_config.py ---
from easydict import EasyDict

seq_len = 4
qbert_pc_mcts_config = dict(
    exp_name='qbert_pc_mcts_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=5,
        n_evaluator_episode=5,
        stop_value=1000000,
        env_id='Qbert-v4',
    ),
    policy=dict(
        cuda=True,
        expert_data_path='pong_expert/ez_pong_seed0.pkl',
        # FIX: ``_init_learn`` reads ``seq_action``/``mask_seq_action`` and the
        # entry reads ``seq_len``/``learn.hidden_state_noise`` — they were missing.
        seq_len=seq_len,
        seq_action=True,
        mask_seq_action=False,
        model=dict(
            obs_shape=[3, 96, 96],
            hidden_shape=[32, 8, 8],
            # FIX: ``ProcedureCloningMCTS.__init__`` takes ``action_dim`` and a
            # required ``seq_len`` — ``action_shape`` was an unknown kwarg.
            action_dim=6,
            seq_len=seq_len,
        ),
        learn=dict(
            batch_size=64,
            learning_rate=0.01,
            learner=dict(hook=dict(save_ckpt_after_iter=1000)),
            train_epoch=20,
            hidden_state_noise=0,
        ),
        eval=dict(evaluator=dict(eval_freq=40, ))
    ),
)
qbert_pc_mcts_config = EasyDict(qbert_pc_mcts_config)
main_config = qbert_pc_mcts_config
qbert_pc_mcts_create_config = dict(
    env=dict(
        type='atari',
        import_names=['dizoo.atari.envs.atari_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='pc_mcts'),
)
qbert_pc_mcts_create_config = EasyDict(qbert_pc_mcts_create_config)
create_config = qbert_pc_mcts_create_config

if __name__ == "__main__":
    from ding.entry import serial_pipeline_pc_mcts
    serial_pipeline_pc_mcts([main_config, create_config], seed=0)