From 061ebc3410fa349fe8630d2caa8dfb28c6f4d994 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Wed, 4 Oct 2023 19:47:58 +0800 Subject: [PATCH 1/3] doc(whl): add code doc for LT,DT,PC,BC models --- ding/model/template/bc.py | 32 ++- ding/model/template/decision_transformer.py | 190 ++++++++++++++++-- ding/model/template/language_transformer.py | 49 ++++- ding/model/template/procedure_cloning.py | 146 ++++++++++++-- .../template/tests/test_procedure_cloning.py | 3 - 5 files changed, 368 insertions(+), 52 deletions(-) diff --git a/ding/model/template/bc.py b/ding/model/template/bc.py index 4568e3ce1c..84499bc717 100644 --- a/ding/model/template/bc.py +++ b/ding/model/template/bc.py @@ -1,9 +1,8 @@ -from typing import Union, Optional, Dict, Callable, List +from typing import Union, Optional, Dict import torch import torch.nn as nn from easydict import EasyDict -from ding.torch_utils import get_lstm from ding.utils import MODEL_REGISTRY, SequenceType, squeeze from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, \ MultiHead, RegressionHead, ReparameterizationHead @@ -11,6 +10,12 @@ @MODEL_REGISTRY.register('discrete_bc') class DiscreteBC(nn.Module): + r""" + Overview: + The DiscreteBC network. + Interfaces: + ``__init__``, ``forward`` + """ def __init__( self, @@ -36,7 +41,7 @@ def __init__( - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network. - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \ - if ``None`` then default set it to ``nn.ReLU()`` + if ``None`` then default set it to ``nn.ReLU()``. - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \ ``ding.torch_utils.fc_block`` for more details. """ @@ -127,7 +132,7 @@ def __init__( ) -> None: """ Overview: - Initailize the ContinuousBC Model according to input arguments. + Initialize the ContinuousBC Model according to input arguments. Arguments: - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ). - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \ @@ -173,14 +178,27 @@ def __init__( ) ) - def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: + def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Dict: """ Overview: - The unique execution (forward) method of ContinuousBC method. + The unique execution (forward) method of ContinuousBC. Arguments: - inputs (:obj:`torch.Tensor`): Observation data, defaults to tensor. Returns: - - output (:obj:`Dict`): Output dict data, including differnet key-values among distinct action_space. + - output (:obj:`Dict`): Output dict data, including different key-values among distinct action_space. + + Examples (Regression): + >>> model = ContinuousBC(32, 6, action_space='regression') + >>> inputs = torch.randn(4, 32) + >>> outputs = model(inputs) + >>> assert isinstance(outputs, dict) and outputs['action'].shape == torch.Size([4, 6]) + + Examples (Reparameterization): + >>> model = ContinuousBC(32, 6, action_space='reparameterization') + >>> inputs = torch.randn(4, 32) + >>> outputs = model(inputs) + >>> assert isinstance(outputs, dict) and outputs['logit'][0].shape == torch.Size([4, 6])\ + and outputs['logit'][1].shape == torch.Size([4, 6]) """ if self.action_space == 'regression': x = self.actor(inputs) diff --git a/ding/model/template/decision_transformer.py b/ding/model/template/decision_transformer.py index 73330da25f..988cedef8a 100644 --- a/ding/model/template/decision_transformer.py +++ b/ding/model/template/decision_transformer.py @@ -13,14 +13,32 @@ """ import math +from typing import Union, Optional, Tuple + import torch import torch.nn as nn import torch.nn.functional as F +from ding.utils import SequenceType class MaskedCausalAttention(nn.Module): - - def __init__(self, h_dim, max_T, n_heads, drop_p): + r""" + Overview: + The implementation of masked causal attention in decision transformer. + Interfaces: + ``__init__``, ``forward`` + """ + + def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None: + """ + Overview: + Initialize the MaskedCausalAttention Model according to input arguments. + Arguments: + - h_dim (:obj:`int`): The dimension of hidden states, such as 128. + - max_T (:obj:`int`): The max context length of the attention, such as 6. + - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8. + - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1. + """ super().__init__() self.n_heads = n_heads @@ -42,7 +60,22 @@ def __init__(self, h_dim, max_T, n_heads, drop_p): # during backpropagation self.register_buffer('mask', mask) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + MaskedCausalAttention forward computation graph, input a sequence tensor \ + and return a tensor with the same shape. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor. + Returns: + - out (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input. + + Examples: + >>> inputs = torch.randn(2, 4, 64) + >>> model = MaskedCausalAttention(64, 5, 4, 0.1) + >>> outputs = model(inputs) + >>> assert outputs.shape == torch.Size([2, 4, 64]) + """ B, T, C = x.shape # batch size, seq length, h_dim * n_heads N, D = self.n_heads, C // self.n_heads # N = num heads, D = attention dim @@ -70,8 +103,23 @@ def forward(self, x): class Block(nn.Module): - - def __init__(self, h_dim, max_T, n_heads, drop_p): + """ + Overview: + The implementation of a transformer block in decision transformer. + Interfaces: + ``__init__``, ``forward`` + """ + + def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None: + """ + Overview: + Initialize the Block Model according to input arguments. + Arguments: + - h_dim (:obj:`int`): The dimension of hidden states, such as 128. + - max_T (:obj:`int`): The max context length of the attention, such as 6. + - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8. + - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1. + """ super().__init__() self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p) self.mlp = nn.Sequential( @@ -83,7 +131,22 @@ def __init__(self, h_dim, max_T, n_heads, drop_p): self.ln1 = nn.LayerNorm(h_dim) self.ln2 = nn.LayerNorm(h_dim) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Forward computation graph of the decision transformer block, input a sequence tensor \ + and return a tensor with the same shape. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor. + Returns: + - output (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input. + + Examples: + >>> inputs = torch.randn(2, 4, 64) + >>> model = Block(64, 5, 4, 0.1) + >>> outputs = model(inputs) + >>> outputs.shape == torch.Size([2, 4, 64]) + """ # Attention -> LayerNorm -> MLP -> LayerNorm x = x + self.attention(x) # residual x = self.ln1(x) @@ -95,20 +158,42 @@ def forward(self, x): class DecisionTransformer(nn.Module): + """ + Overview: + The implementation of decision transformer. + Interfaces: + ``__init__``, ``forward``, ``configure_optimizers`` + """ def __init__( self, - state_dim, - act_dim, - n_blocks, - h_dim, - context_len, - n_heads, - drop_p, - max_timestep=4096, - state_encoder=None, - continuous=False + state_dim: Union[int, SequenceType], + act_dim: int, + n_blocks: int, + h_dim: int, + context_len: int, + n_heads: int, + drop_p: float, + max_timestep: int = 4096, + state_encoder: Optional[nn.Module] = None, + continuous: bool = False ): + """ + Overview: + Initialize the DecisionTransformer Model according to input arguments. + Arguments: + - obs_shape (:obj:`Union[int, SequenceType]`): Dimension of state, such as 128, (4, 84, 84). + - act_dim (:obj:`int`): The dimension of actions, such as 6. + - n_blocks (:obj:`int`): The number of transformer blocks in the decision transformer, such as 3. + - h_dim (:obj:`int`): The dimension of hidden states, such as 128. + - context_len (:obj:`int`): The max context length of the attention, such as 6. + - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8. + - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1. + - max_timestep (:obj:`int`): The max length of the total sequence, defaults to be 4096. + - state_encoder (:obj:`Optional[nn.Module]`): The encoder to pre-process the given input. If it is set to \ + None, the raw state will be pushed into the transformer. + - continuous (:obj:`bool`): Whether the action space is continuous, defaults to be ``False``. + """ super().__init__() self.state_dim = state_dim @@ -152,7 +237,60 @@ def __init__( self.embed_action = nn.Sequential(nn.Embedding(act_dim, h_dim), nn.Tanh()) self.transformer = nn.Sequential(*blocks) - def forward(self, timesteps, states, actions, returns_to_go, tar=None): + def forward( + self, + timesteps: torch.Tensor, + states: torch.Tensor, + actions: torch.Tensor, + returns_to_go: torch.Tensor, + tar: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation graph of the decision transformer, input a sequence tensor \ + and return a tensor with the same shape. + Arguments: + - timesteps (:obj:`torch.Tensor`): The timestep for input sequence. + - states (:obj:`torch.Tensor`): The sequence of states. + - actions (:obj:`torch.Tensor`): The sequence of actions. + - returns_to_go (:obj:`torch.Tensor`): The sequence of return-to-go. + - tar (:obj:`Optional[int]`): The targe index. + Returns: + - output (:obj:`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`): Output contains three tensors, \ + they are correspondingly the predicted states, predicted actions and predicted return-to-go. + + Examples: + >>> B, T = 4, 6 + >>> state_dim = 3 + >>> act_dim = 2 + >>> DT_model = DecisionTransformer(\ + state_dim=state_dim,\ + act_dim=act_dim,\ + n_blocks=3,\ + h_dim=8,\ + context_len=T,\ + n_heads=2,\ + drop_p=0.1,\ + ) + + >>> timesteps = torch.randint(0, 100, [B, 3 * T - 1, 1], dtype=torch.long) # B x T + >>> states = torch.randn([B, T, state_dim]) # B x T x state_dim + + >>> actions = torch.randint(0, act_dim, [B, T, 1]) + >>> action_target = torch.randint(0, act_dim, [B, T, 1]) + >>> returns_to_go_sample = torch.tensor([1, 0.8, 0.6, 0.4, 0.2, 0.]).repeat([B, 1]).unsqueeze(-1).float() + + >>> traj_mask = torch.ones([B, T], dtype=torch.long) # B x T + >>> actions = actions.squeeze(-1) + + >>> state_preds, action_preds, return_preds = DT_model.forward(\ + timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go\ + ) + + >>> assert state_preds.shape == torch.Size([B, T, state_dim]) + >>> assert return_preds.shape == torch.Size([B, T, 1]) + >>> assert action_preds.shape == torch.Size([B, T, act_dim]) + """ B, T = states.shape[0], states.shape[1] if self.state_encoder is None: time_embeddings = self.embed_timestep(timesteps) @@ -217,12 +355,20 @@ def forward(self, timesteps, states, actions, returns_to_go, tar=None): return state_preds, action_preds, return_preds - def configure_optimizers(self, weight_decay, learning_rate, betas=(0.9, 0.95)): + def configure_optimizers( + self, weight_decay: float, learning_rate: float, betas: Tuple[float, float] = (0.9, 0.95) + ) -> torch.optim.optimizer.Optimizer: """ - This long function is unfortunately doing something very simple and is being very defensive: - We are separating out all parameters of the model into two buckets: those that will experience - weight decay for regularization and those that won't (biases, and layernorm/embedding weights). - We are then returning the PyTorch optimizer object. + Overview: + This function returns an optimizer given the input arguments. \ + We are separating out all parameters of the model into two buckets: those that will experience \ + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + Arguments: + - weight_decay (:obj:`float`): The weigh decay of the optimizer. + - learning_rate (:obj:`float`): The learning rate of the optimizer. + - betas (:obj:`Tuple[float, float]`): The betas for Adam optimizer. + Outputs: + - optimizer (:obj:`torch.optim.optimizer.Optimizer`): The desired optimizer. """ # separate out all parameters to those that will and won't experience regularizing weight decay diff --git a/ding/model/template/language_transformer.py b/ding/model/template/language_transformer.py index 521d365376..b9b52f14f5 100644 --- a/ding/model/template/language_transformer.py +++ b/ding/model/template/language_transformer.py @@ -1,16 +1,23 @@ +from typing import List, Dict import torch - -from ding.utils import MODEL_REGISTRY from torch import nn + try: from transformers import AutoTokenizer, AutoModelForTokenClassification except ImportError: from ditk import logging logging.warning("not found transformer, please install it using: pip install transformers") +from ding.utils import MODEL_REGISTRY @MODEL_REGISTRY.register('language_transformer') class LanguageTransformer(nn.Module): + r""" + Overview: + The LanguageTransformer network. Download a pre-trained language model and add head on it. + Interfaces: + ``__init__``, ``forward`` + """ def __init__( self, @@ -19,6 +26,17 @@ def __init__( embedding_size: int = 128, freeze_encoder: bool = True ) -> None: + """ + Overview: + Init the LanguageTransformer Model according to input arguments. + Arguments: + - model_name (:obj:`str`): The base language model name in huggingface, such as "bert-base-uncased". + - add_linear (:obj:`bool`): Whether to add a linear layer on the top of language model, defaults to be \ + ``False``. + - embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128. + - freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training, \ + defaults to be ``True``. + """ super().__init__() self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = AutoModelForTokenClassification.from_pretrained(model_name) @@ -29,7 +47,7 @@ def __init__( param.requires_grad = False if add_linear: - # Add an additional small, adjustable linear layer on top of BERT tuned through RL + # Add a small, adjustable linear layer on top of language model tuned through RL self.embedding_size = embedding_size self.linear = nn.Linear( self.model.config.hidden_size, embedding_size @@ -54,7 +72,30 @@ def _calc_embedding(self, x: list) -> torch.Tensor: return sentence_embedding - def forward(self, train_samples: list, candidate_samples: list) -> dict: + def forward(self, train_samples: List[str], candidate_samples: List[str]) -> Dict: + """ + Overview: + LanguageTransformer forward computation graph, input two lists of strings and predict their matching scores. + Arguments: + - train_samples (:obj:`List[str]`): One list of strings. + - candidate_samples (:obj:`List[str]`): The other list of strings to calculate the matching scores. + Returns: + - output (:obj:`Dict`): Output dict data, including the logit of matching scores and the \ + corresponding ``torch.distributions.Categorical`` object. + + Examples: + >>> test_pids = [1] + >>> cand_pids = [0, 2, 4] + >>> problems = [ \ + "This is problem 0", "This is the first question", "Second problem is here", "Another problem", \ + "This is the last problem" \ + ] + >>> ctxt_list = [problems[pid] for pid in test_pids] + >>> cands_list = [problems[pid] for pid in cand_pids] + >>> model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256) + >>> scores = model(ctxt_list, cands_list) + >>> assert scores.shape == (1, 3) + """ prompt_embedding = self._calc_embedding(train_samples) cands_embedding = self._calc_embedding(candidate_samples) scores = torch.mm(prompt_embedding, cands_embedding.t()) diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py index d0d6ffcbd0..9c22c5ee6b 100644 --- a/ding/model/template/procedure_cloning.py +++ b/ding/model/template/procedure_cloning.py @@ -9,12 +9,30 @@ from ..common import FCEncoder, ConvEncoder -class Block(nn.Module): +class PCTransformer(nn.Module): + """ + Overview: + The transformer block for neural network of algorithms related to Procedure cloning (PC). + Interfaces: + ``__init__``, ``forward``. + """ def __init__( self, cnn_hidden: int, att_hidden: int, att_heads: int, drop_p: float, max_T: int, n_att: int, feedforward_hidden: int, n_feedforward: int ) -> None: + """ + Overview: + Initialize the procedure cloning transformer model according to corresponding input arguments. + Arguments: + - cnn_hidden (:obj:`int`): The last channel dimension of CNN encoder, such as 32. + - att_hidden (:obj:`int`): The dimension of attention blocks, such as 32. + - att_heads (:obj:`int`): The number of heads in attention blocks, such as 4. + - drop_p (:obj:`float`): The drop out rate of attention, such as 0.5. + - max_T (:obj:`int`): The sequence length of procedure cloning, such as 4. + - n_attn (:obj:`int`): The number of attention layers, such as 4. + - feedforward_hidden (:obj:`int`): The number of feedforward layers, such as 4. + """ super().__init__() self.n_att = n_att self.n_feedforward = n_feedforward @@ -34,7 +52,20 @@ def __init__( self.norm_layer.extend([nn.LayerNorm(feedforward_hidden)] * n_feedforward) self.mask = torch.tril(torch.ones((max_T, max_T), dtype=torch.bool)).view(1, 1, max_T, max_T) - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + The unique execution (forward) method of PCTransformer. + Arguments: + - x (:obj:`torch.Tensor`): Sequential data of several hidden states. + Returns: + - output (:obj:`torch.Tensor`): A tensor with the same shape as the input. + Examples: + >>> model = PCTransformer(128, 128, 8, 0, 16, 2, 128, 2) + >>> h = torch.randn((2, 16, 128)) + >>> h = model(h) + >>> assert h.shape == torch.Size([2, 16, 128]) + """ for i in range(self.n_att): x = self.att_drop(self.attention_layer[i](x, self.mask)) x = self.norm_layer[i](x) @@ -46,37 +77,63 @@ def forward(self, x: torch.Tensor): @MODEL_REGISTRY.register('pc_mcts') class ProcedureCloningMCTS(nn.Module): + """ + Overview: + The neural network of algorithms related to Procedure cloning (PC). + Interfaces: + ``__init__``, ``forward``. + """ def __init__( self, obs_shape: SequenceType, action_dim: int, cnn_hidden_list: SequenceType = [128, 128, 256, 256, 256], - cnn_activation: Optional[nn.Module] = nn.ReLU(), + cnn_activation: nn.Module = nn.ReLU(), cnn_kernel_size: SequenceType = [3, 3, 3, 3, 3], cnn_stride: SequenceType = [1, 1, 1, 1, 1], - cnn_padding: Optional[SequenceType] = [1, 1, 1, 1, 1], + cnn_padding: SequenceType = [1, 1, 1, 1, 1], mlp_hidden_list: SequenceType = [256, 256], - mlp_activation: Optional[nn.Module] = nn.ReLU(), + mlp_activation: nn.Module = nn.ReLU(), att_heads: int = 8, att_hidden: int = 128, n_att: int = 4, n_feedforward: int = 2, feedforward_hidden: int = 256, drop_p: float = 0.5, - augment: bool = True, max_T: int = 17 ) -> None: + """ + Overview: + Initialize the MCTS procedure cloning model according to corresponding input arguments. + Arguments: + - obs_shape (:obj:`SequenceType`): Observation space shape, such as [4, 84, 84]. + - action_dim (:obj:`int`): Action space shape, such as 6. + - cnn_hidden_list (:obj:`SequenceType`): The cnn channel dims for each block, such as\ + [128, 128, 256, 256, 256]. + - cnn_activation (:obj:`nn.Module`): The activation function for cnn blocks, such as ``nn.ReLU()``. + - cnn_kernel_size (:obj:`SequenceType`): The kernel size for each cnn block, such as [3, 3, 3, 3, 3]. + - cnn_stride (:obj:`SequenceType`): The stride for each cnn block, such as [1, 1, 1, 1, 1]. + - cnn_padding (:obj:`SequenceType`): The padding for each cnn block, such as [1, 1, 1, 1, 1]. + - mlp_hidden_list (:obj:`SequenceType`): The last dim for this must match the last dim of \ + ``cnn_hidden_list``, such as [256, 256]. + - mlp_activation (:obj:`nn.Module`): The activation function for mlp layers, such as ``nn.ReLU()``. + - att_heads (:obj:`int`): The number of attention heads in transformer, such as 8. + - att_hidden (:obj:`int`): The number of attention dimension in transformer, such as 128. + - n_att (:obj:`int`): The number of attention blocks in transformer, such as 4. + - n_feedforward (:obj:`int`): The number of feedforward layers in transformer, such as 2. + - drop_p (:obj:`float`): The drop out rate of attention, such as 0.5. + - max_T (:obj:`int`): The sequence length of procedure cloning, such as 17. + """ super().__init__() - #Conv Encoder + # Conv Encoder self.embed_state = ConvEncoder( obs_shape, cnn_hidden_list, cnn_activation, cnn_kernel_size, cnn_stride, cnn_padding ) self.embed_action = FCEncoder(action_dim, mlp_hidden_list, activation=mlp_activation) self.cnn_hidden_list = cnn_hidden_list - self.augment = augment assert cnn_hidden_list[-1] == mlp_hidden_list[-1] layers = [] @@ -93,7 +150,7 @@ def __init__( layers.append(fc_block(feedforward_hidden, feedforward_hidden, activation=nn.ReLU())) self.layernorm2 = build_normalization('LN')(feedforward_hidden) - self.transformer = Block( + self.transformer = PCTransformer( cnn_hidden_list[-1], att_hidden, att_heads, drop_p, max_T, n_att, feedforward_hidden, n_feedforward ) @@ -101,8 +158,28 @@ def __init__( self.predict_action = torch.nn.Linear(cnn_hidden_list[-1], action_dim) def forward(self, states: torch.Tensor, goals: torch.Tensor, - actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - + actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + ProcedureCloningMCTS forward computation graph, input states tensor and goals tensor, \ + calculate the predicted states and actions. + Arguments: + - states (:obj:`torch.Tensor`): The observation of current time. + - goals (:obj:`torch.Tensor`): The target observation after a period. + - actions (:obj:`torch.Tensor`): The actions executed during the period. + Returns: + - outputs (:obj:`Tuple[torch.Tensor, torch.Tensor]`): Predicted states and actions. + Examples: + >>> inputs = { \ + 'states': torch.randn(2, 3, 64, 64), \ + 'goals': torch.randn(2, 3, 64, 64), \ + 'actions': torch.randn(2, 15, 9) \ + } + >>> model = ProcedureCloningMCTS(obs_shape=(3, 64, 64), action_dim=9) + >>> goal_preds, action_preds = model(inputs['states'], inputs['goals'], inputs['actions']) + >>> assert goal_preds.shape == (2, 256) + >>> assert action_preds.shape == (2, 16, 9) + """ B, T, _ = actions.shape # shape: (B, h_dim) @@ -123,7 +200,8 @@ def forward(self, states: torch.Tensor, goals: torch.Tensor, class BFSConvEncoder(nn.Module): """ - Overview: The ``BFSConvolution Encoder`` used to encode raw 2-dim observations. And output a feature map with the + Overview: + The ``BFSConvolution Encoder`` used to encode raw 3-dim observations. And output a feature map with the same height and width as input. Interfaces: ``__init__``, ``forward``. """ @@ -174,21 +252,42 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: - x (:obj:`torch.Tensor`): Env raw observation. Returns: - outputs (:obj:`torch.Tensor`): Output embedding tensor. - Shapes: - - outputs: :math:`(B, N, H, W)`, where ``N = hidden_size_list[-1]``. + Examples: + >>> model = BFSConvEncoder([3, 16, 16], [32, 32, 4], kernel_size=[3, 3, 3], stride=[1, 1, 1]\ + , padding=[1, 1, 1]) + >>> inputs = torch.randn(3, 16, 16).unsqueeze(0) + >>> outputs = model(inputs) + >>> assert outputs['logit'].shape == torch.Size([4, 16, 16]) """ return self.main(x) @MODEL_REGISTRY.register('pc_bfs') class ProcedureCloningBFS(nn.Module): + """ + Overview: + The neural network introduced in procedure cloning (PC) to process 3-dim observations.\ + Given an input, this model will perform several 3x3 convolutions and output a feature map with \ + the same height and width of input. The channel number of output will be the ``action_shape``. + Interfaces: + ``__init__``, ``forward``. + """ def __init__( self, - obs_shape: Union[int, SequenceType], - action_shape: Union[int, SequenceType], + obs_shape: SequenceType, + action_shape: int, encoder_hidden_size_list: SequenceType = [128, 128, 256, 256], ): + """ + Overview: + Init the ``BFSConvolution Encoder`` according to the provided arguments. + Arguments: + - obs_shape (:obj:`SequenceType`): Sequence of ``in_channel``, plus one or more ``input size``,\ + such as [4, 84, 84]. + - action_dim (:obj:`int`): Action space shape, such as 6. + - cnn_hidden_list (:obj:`SequenceType`): The cnn channel dims for each block, such as [128, 128, 256, 256]. + """ super().__init__() num_layers = len(encoder_hidden_size_list) @@ -207,6 +306,21 @@ def __init__( ) def forward(self, x: torch.Tensor) -> Dict: + """ + Overview: + The computation graph. Given a 3-dim observation, this function will return a tensor with the same\ + height and width. The channel number of output will be the ``action_shape``. + Arguments: + - x (:obj:`torch.Tensor`): The input observation tensor data. + Returns: + - outputs (:obj:`Dict`): The output dict of model's forward computation graph, \ + only contains a single key ``logit``. + Examples: + >>> model = ProcedureCloningBFS([3, 16, 16], 4) + >>> inputs = torch.randn(16, 16, 3).unsqueeze(0) + >>> outputs = model(inputs) + >>> assert outputs['logit'].shape == torch.Size([16, 16, 4]) + """ x = x.permute(0, 3, 1, 2) x = self._encoder(x) return {'logit': x.permute(0, 2, 3, 1)} diff --git a/ding/model/template/tests/test_procedure_cloning.py b/ding/model/template/tests/test_procedure_cloning.py index 5a52542879..b2bb197954 100644 --- a/ding/model/template/tests/test_procedure_cloning.py +++ b/ding/model/template/tests/test_procedure_cloning.py @@ -24,9 +24,6 @@ def test_procedure_cloning_mcts(self, obs_shape, action_dim): 'actions': torch.randn(B, T, action_dim) } model = ProcedureCloningMCTS(obs_shape=obs_shape, action_dim=action_dim) - - print(model) - goal_preds, action_preds = model(inputs['states'], inputs['goals'], inputs['actions']) assert goal_preds.shape == (B, obs_embeddings) assert action_preds.shape == (B, T + 1, action_dim) From 1d35eef6e8d3ae99d24ed03e821b9e88cf78918c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Sat, 7 Oct 2023 23:01:08 +0800 Subject: [PATCH 2/3] polish --- ding/model/template/bc.py | 6 +++--- ding/model/template/decision_transformer.py | 2 +- ding/model/template/language_transformer.py | 2 +- ding/model/template/procedure_cloning.py | 8 ++++---- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ding/model/template/bc.py b/ding/model/template/bc.py index 84499bc717..be74866e30 100644 --- a/ding/model/template/bc.py +++ b/ding/model/template/bc.py @@ -10,7 +10,7 @@ @MODEL_REGISTRY.register('discrete_bc') class DiscreteBC(nn.Module): - r""" + """ Overview: The DiscreteBC network. Interfaces: @@ -88,7 +88,7 @@ def __init__( ) def forward(self, x: torch.Tensor) -> Dict: - r""" + """ Overview: DiscreteBC forward computation graph, input observation tensor to predict q_value. Arguments: @@ -113,7 +113,7 @@ def forward(self, x: torch.Tensor) -> Dict: @MODEL_REGISTRY.register('continuous_bc') class ContinuousBC(nn.Module): - r""" + """ Overview: The ContinuousBC network. Interfaces: diff --git a/ding/model/template/decision_transformer.py b/ding/model/template/decision_transformer.py index 988cedef8a..3bc2b551cc 100644 --- a/ding/model/template/decision_transformer.py +++ b/ding/model/template/decision_transformer.py @@ -22,7 +22,7 @@ class MaskedCausalAttention(nn.Module): - r""" + """ Overview: The implementation of masked causal attention in decision transformer. Interfaces: diff --git a/ding/model/template/language_transformer.py b/ding/model/template/language_transformer.py index b9b52f14f5..cac2d69adf 100644 --- a/ding/model/template/language_transformer.py +++ b/ding/model/template/language_transformer.py @@ -12,7 +12,7 @@ @MODEL_REGISTRY.register('language_transformer') class LanguageTransformer(nn.Module): - r""" + """ Overview: The LanguageTransformer network. Download a pre-trained language model and add head on it. Interfaces: diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py index 9c22c5ee6b..692e49913d 100644 --- a/ding/model/template/procedure_cloning.py +++ b/ding/model/template/procedure_cloning.py @@ -78,10 +78,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @MODEL_REGISTRY.register('pc_mcts') class ProcedureCloningMCTS(nn.Module): """ - Overview: - The neural network of algorithms related to Procedure cloning (PC). - Interfaces: - ``__init__``, ``forward``. + Overview: + The neural network of algorithms related to Procedure cloning (PC). + Interfaces: + ``__init__``, ``forward``. """ def __init__( From ab445a76d02e535a61d876bb8d9bca22c14827b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 9 Oct 2023 10:50:25 +0800 Subject: [PATCH 3/3] polish doc --- ding/model/template/bc.py | 6 ++++-- ding/model/template/decision_transformer.py | 16 +++++++++------- ding/model/template/procedure_cloning.py | 7 ++++--- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/ding/model/template/bc.py b/ding/model/template/bc.py index be74866e30..ce58ca8c5f 100644 --- a/ding/model/template/bc.py +++ b/ding/model/template/bc.py @@ -44,6 +44,8 @@ def __init__( if ``None`` then default set it to ``nn.ReLU()``. - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \ ``ding.torch_utils.fc_block`` for more details. + - strides (:obj:`Optional[list]`): The strides for each convolution layers, such as [2, 2, 2]. The length \ + of this argument should be the same as ``encoder_hidden_size_list``. """ super(DiscreteBC, self).__init__() # For compatibility: 1, (1, ), [4, 32, 32] @@ -197,8 +199,8 @@ def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Dict: >>> model = ContinuousBC(32, 6, action_space='reparameterization') >>> inputs = torch.randn(4, 32) >>> outputs = model(inputs) - >>> assert isinstance(outputs, dict) and outputs['logit'][0].shape == torch.Size([4, 6])\ - and outputs['logit'][1].shape == torch.Size([4, 6]) + >>> assert isinstance(outputs, dict) and outputs['logit'][0].shape == torch.Size([4, 6]) + >>> assert outputs['logit'][1].shape == torch.Size([4, 6]) """ if self.action_space == 'regression': x = self.actor(inputs) diff --git a/ding/model/template/decision_transformer.py b/ding/model/template/decision_transformer.py index 3bc2b551cc..0b8f47a71d 100644 --- a/ding/model/template/decision_transformer.py +++ b/ding/model/template/decision_transformer.py @@ -24,7 +24,9 @@ class MaskedCausalAttention(nn.Module): """ Overview: - The implementation of masked causal attention in decision transformer. + The implementation of masked causal attention in decision transformer. The input of this module is a sequence \ + of several tokens. For the calculated hidden embedding for the i-th token, it is only related the 0 to i-1 \ + input tokens by applying a mask to the attention map. Thus, this module is called masked-causal attention. Interfaces: ``__init__``, ``forward`` """ @@ -34,7 +36,7 @@ def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None: Overview: Initialize the MaskedCausalAttention Model according to input arguments. Arguments: - - h_dim (:obj:`int`): The dimension of hidden states, such as 128. + - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128. - max_T (:obj:`int`): The max context length of the attention, such as 6. - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8. - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1. @@ -115,7 +117,7 @@ def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None: Overview: Initialize the Block Model according to input arguments. Arguments: - - h_dim (:obj:`int`): The dimension of hidden states, such as 128. + - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128. - max_T (:obj:`int`): The max context length of the attention, such as 6. - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8. - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1. @@ -182,16 +184,16 @@ def __init__( Overview: Initialize the DecisionTransformer Model according to input arguments. Arguments: - - obs_shape (:obj:`Union[int, SequenceType]`): Dimension of state, such as 128, (4, 84, 84). + - obs_shape (:obj:`Union[int, SequenceType]`): Dimension of state, such as 128 or (4, 84, 84). - act_dim (:obj:`int`): The dimension of actions, such as 6. - n_blocks (:obj:`int`): The number of transformer blocks in the decision transformer, such as 3. - - h_dim (:obj:`int`): The dimension of hidden states, such as 128. + - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128. - context_len (:obj:`int`): The max context length of the attention, such as 6. - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8. - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1. - max_timestep (:obj:`int`): The max length of the total sequence, defaults to be 4096. - state_encoder (:obj:`Optional[nn.Module]`): The encoder to pre-process the given input. If it is set to \ - None, the raw state will be pushed into the transformer. + None, the raw state will be pushed into the transformer. - continuous (:obj:`bool`): Whether the action space is continuous, defaults to be ``False``. """ super().__init__() @@ -254,7 +256,7 @@ def forward( - states (:obj:`torch.Tensor`): The sequence of states. - actions (:obj:`torch.Tensor`): The sequence of actions. - returns_to_go (:obj:`torch.Tensor`): The sequence of return-to-go. - - tar (:obj:`Optional[int]`): The targe index. + - tar (:obj:`Optional[int]`): Whether to predict action, regardless of index. Returns: - output (:obj:`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`): Output contains three tensors, \ they are correspondingly the predicted states, predicted actions and predicted return-to-go. diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py index 692e49913d..4f03c8a4bf 100644 --- a/ding/model/template/procedure_cloning.py +++ b/ding/model/template/procedure_cloning.py @@ -31,7 +31,8 @@ def __init__( - drop_p (:obj:`float`): The drop out rate of attention, such as 0.5. - max_T (:obj:`int`): The sequence length of procedure cloning, such as 4. - n_attn (:obj:`int`): The number of attention layers, such as 4. - - feedforward_hidden (:obj:`int`): The number of feedforward layers, such as 4. + - feedforward_hidden (:obj:`int`):The dimension of feedforward layers, such as 32. + - n_feedforward (:obj:`int`): The number of feedforward layers, such as 4. """ super().__init__() self.n_att = n_att @@ -308,8 +309,8 @@ def __init__( def forward(self, x: torch.Tensor) -> Dict: """ Overview: - The computation graph. Given a 3-dim observation, this function will return a tensor with the same\ - height and width. The channel number of output will be the ``action_shape``. + The computation graph. Given a 3-dim observation, this function will return a tensor with the same \ + height and width. The channel number of output will be the ``action_shape``. Arguments: - x (:obj:`torch.Tensor`): The input observation tensor data. Returns: