doc(whl): add code doc for LT,DT,PC,BC models (#734)
* doc(whl): add code doc for LT,DT,PC,BC models

* polish

* polish doc
kxzxvbk authored Oct 9, 2023
1 parent 92ac919 commit b7f703e
Showing 5 changed files with 375 additions and 54 deletions.
38 changes: 29 additions & 9 deletions ding/model/template/bc.py
@@ -1,16 +1,21 @@
from typing import Union, Optional, Dict, Callable, List
from typing import Union, Optional, Dict
import torch
import torch.nn as nn
from easydict import EasyDict

from ding.torch_utils import get_lstm
from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, \
MultiHead, RegressionHead, ReparameterizationHead


@MODEL_REGISTRY.register('discrete_bc')
class DiscreteBC(nn.Module):
"""
Overview:
The DiscreteBC network.
Interfaces:
``__init__``, ``forward``
"""

def __init__(
self,
@@ -36,9 +41,11 @@ def __init__(
- head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network.
- head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output
- activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
if ``None`` then default set it to ``nn.ReLU()``
if ``None`` then default set it to ``nn.ReLU()``.
- norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
``ding.torch_utils.fc_block`` for more details.
- strides (:obj:`Optional[list]`): The strides for each convolution layers, such as [2, 2, 2]. The length \
of this argument should be the same as ``encoder_hidden_size_list``.
"""
super(DiscreteBC, self).__init__()
# For compatibility: 1, (1, ), [4, 32, 32]
@@ -83,7 +90,7 @@ def __init__(
)

def forward(self, x: torch.Tensor) -> Dict:
r"""
"""
Overview:
DiscreteBC forward computation graph, input observation tensor to predict q_value.
Arguments:
@@ -108,7 +115,7 @@ def forward(self, x: torch.Tensor) -> Dict:

@MODEL_REGISTRY.register('continuous_bc')
class ContinuousBC(nn.Module):
r"""
"""
Overview:
The ContinuousBC network.
Interfaces:
@@ -127,7 +134,7 @@ def __init__(
) -> None:
"""
Overview:
Initailize the ContinuousBC Model according to input arguments.
Initialize the ContinuousBC Model according to input arguments.
Arguments:
- obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
- action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \
@@ -173,14 +180,27 @@ def __init__(
)
)

def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Dict:
"""
Overview:
The unique execution (forward) method of ContinuousBC method.
The unique execution (forward) method of ContinuousBC.
Arguments:
- inputs (:obj:`torch.Tensor`): Observation data, defaults to tensor.
Returns:
- output (:obj:`Dict`): Output dict data, including differnet key-values among distinct action_space.
- output (:obj:`Dict`): Output dict data, including different key-values among distinct action_space.
Examples (Regression):
>>> model = ContinuousBC(32, 6, action_space='regression')
>>> inputs = torch.randn(4, 32)
>>> outputs = model(inputs)
>>> assert isinstance(outputs, dict) and outputs['action'].shape == torch.Size([4, 6])
Examples (Reparameterization):
>>> model = ContinuousBC(32, 6, action_space='reparameterization')
>>> inputs = torch.randn(4, 32)
>>> outputs = model(inputs)
>>> assert isinstance(outputs, dict) and outputs['logit'][0].shape == torch.Size([4, 6])
>>> assert outputs['logit'][1].shape == torch.Size([4, 6])
"""
if self.action_space == 'regression':
x = self.actor(inputs)
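For quick reference, here is a minimal usage sketch of DiscreteBC in the same doctest style as the examples above. It assumes the default encoder settings and that the head exposes its prediction under the ``logit`` key, mirroring the ContinuousBC examples; treat it as an illustration rather than the repository's own test code.

>>> model = DiscreteBC(32, 6)
>>> inputs = torch.randn(4, 32)
>>> outputs = model(inputs)
>>> assert isinstance(outputs, dict) and outputs['logit'].shape == torch.Size([4, 6])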
192 changes: 170 additions & 22 deletions ding/model/template/decision_transformer.py
@@ -13,14 +13,34 @@
"""

import math
from typing import Union, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.utils import SequenceType


class MaskedCausalAttention(nn.Module):

def __init__(self, h_dim, max_T, n_heads, drop_p):
"""
Overview:
The implementation of masked causal attention in the decision transformer. The input of this module is a sequence \
of several tokens. The hidden embedding computed for the i-th token depends only on the current and earlier input \
tokens, which is enforced by applying a causal mask to the attention map. Thus, this module is called masked causal attention.
Interfaces:
``__init__``, ``forward``
"""

def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None:
"""
Overview:
Initialize the MaskedCausalAttention Model according to input arguments.
Arguments:
- h_dim (:obj:`int`): The dimension of the hidden layers, such as 128.
- max_T (:obj:`int`): The max context length of the attention, such as 6.
- n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
- drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
"""
super().__init__()

self.n_heads = n_heads
@@ -42,7 +62,22 @@ def __init__(self, h_dim, max_T, n_heads, drop_p):
# during backpropagation
self.register_buffer('mask', mask)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overview:
MaskedCausalAttention forward computation graph, input a sequence tensor \
and return a tensor with the same shape.
Arguments:
- x (:obj:`torch.Tensor`): The input tensor.
Returns:
- out (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input.
Examples:
>>> inputs = torch.randn(2, 4, 64)
>>> model = MaskedCausalAttention(64, 5, 4, 0.1)
>>> outputs = model(inputs)
>>> assert outputs.shape == torch.Size([2, 4, 64])
"""
B, T, C = x.shape # batch size, seq length, h_dim * n_heads

N, D = self.n_heads, C // self.n_heads # N = num heads, D = attention dim
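As an aside on the masking described in the MaskedCausalAttention docstring above, the usual implementation registers a lower-triangular matrix as a buffer and uses it to blank out attention scores for future positions. The following doctest-style sketch illustrates that generic mechanism only (using the ``torch`` and ``F`` imports already present in this file); it is not the exact code hidden in the collapsed lines.

>>> max_T = 5
>>> mask = torch.tril(torch.ones((max_T, max_T))).view(1, 1, max_T, max_T)  # 1 = attend, 0 = blocked
>>> weights = torch.randn(1, 1, max_T, max_T)  # raw attention scores
>>> weights = weights.masked_fill(mask == 0, float('-inf'))  # hide future positions
>>> normalized = F.softmax(weights, dim=-1)  # each token attends only to itself and earlier tokens
>>> assert torch.allclose(normalized.sum(dim=-1), torch.ones(1, 1, max_T))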
@@ -70,8 +105,23 @@ def forward(self, x):


class Block(nn.Module):

def __init__(self, h_dim, max_T, n_heads, drop_p):
"""
Overview:
The implementation of a transformer block in decision transformer.
Interfaces:
``__init__``, ``forward``
"""

def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None:
"""
Overview:
Initialize the Block Model according to input arguments.
Arguments:
- h_dim (:obj:`int`): The dimension of the hidden layers, such as 128.
- max_T (:obj:`int`): The max context length of the attention, such as 6.
- n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
- drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
"""
super().__init__()
self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p)
self.mlp = nn.Sequential(
@@ -83,7 +133,22 @@ def __init__(self, h_dim, max_T, n_heads, drop_p):
self.ln1 = nn.LayerNorm(h_dim)
self.ln2 = nn.LayerNorm(h_dim)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overview:
Forward computation graph of the decision transformer block, input a sequence tensor \
and return a tensor with the same shape.
Arguments:
- x (:obj:`torch.Tensor`): The input tensor.
Returns:
- output (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input.
Examples:
>>> inputs = torch.randn(2, 4, 64)
>>> model = Block(64, 5, 4, 0.1)
>>> outputs = model(inputs)
>>> assert outputs.shape == torch.Size([2, 4, 64])
"""
# Attention -> LayerNorm -> MLP -> LayerNorm
x = x + self.attention(x) # residual
x = self.ln1(x)
@@ -95,20 +160,42 @@ def forward(self, x):


class DecisionTransformer(nn.Module):
"""
Overview:
The implementation of decision transformer.
Interfaces:
``__init__``, ``forward``, ``configure_optimizers``
"""

def __init__(
self,
state_dim,
act_dim,
n_blocks,
h_dim,
context_len,
n_heads,
drop_p,
max_timestep=4096,
state_encoder=None,
continuous=False
state_dim: Union[int, SequenceType],
act_dim: int,
n_blocks: int,
h_dim: int,
context_len: int,
n_heads: int,
drop_p: float,
max_timestep: int = 4096,
state_encoder: Optional[nn.Module] = None,
continuous: bool = False
):
"""
Overview:
Initialize the DecisionTransformer Model according to input arguments.
Arguments:
- state_dim (:obj:`Union[int, SequenceType]`): The dimension of the state, such as 128 or (4, 84, 84).
- act_dim (:obj:`int`): The dimension of actions, such as 6.
- n_blocks (:obj:`int`): The number of transformer blocks in the decision transformer, such as 3.
- h_dim (:obj:`int`): The dimension of the hidden layers, such as 128.
- context_len (:obj:`int`): The max context length of the attention, such as 6.
- n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
- drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
- max_timestep (:obj:`int`): The max length of the total sequence, defaults to 4096.
- state_encoder (:obj:`Optional[nn.Module]`): The encoder used to pre-process the given input. If it is set to \
None, the raw state will be fed into the transformer.
- continuous (:obj:`bool`): Whether the action space is continuous, defaults to ``False``.
"""
super().__init__()

self.state_dim = state_dim
@@ -152,7 +239,60 @@ def __init__(
self.embed_action = nn.Sequential(nn.Embedding(act_dim, h_dim), nn.Tanh())
self.transformer = nn.Sequential(*blocks)

def forward(self, timesteps, states, actions, returns_to_go, tar=None):
def forward(
self,
timesteps: torch.Tensor,
states: torch.Tensor,
actions: torch.Tensor,
returns_to_go: torch.Tensor,
tar: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Overview:
Forward computation graph of the decision transformer. It takes the sequences of timesteps, states, actions \
and returns-to-go as input, and outputs the predicted states, actions and returns-to-go.
Arguments:
- timesteps (:obj:`torch.Tensor`): The timesteps of the input sequence.
- states (:obj:`torch.Tensor`): The sequence of states.
- actions (:obj:`torch.Tensor`): The sequence of actions.
- returns_to_go (:obj:`torch.Tensor`): The sequence of return-to-go.
- tar (:obj:`Optional[int]`): Whether to predict action, regardless of index.
Returns:
- output (:obj:`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`): The output contains three tensors: the \
predicted states, the predicted actions and the predicted returns-to-go.
Examples:
>>> B, T = 4, 6
>>> state_dim = 3
>>> act_dim = 2
>>> DT_model = DecisionTransformer(\
state_dim=state_dim,\
act_dim=act_dim,\
n_blocks=3,\
h_dim=8,\
context_len=T,\
n_heads=2,\
drop_p=0.1,\
)
>>> timesteps = torch.randint(0, 100, [B, T], dtype=torch.long) # B x T
>>> states = torch.randn([B, T, state_dim]) # B x T x state_dim
>>> actions = torch.randint(0, act_dim, [B, T, 1])
>>> action_target = torch.randint(0, act_dim, [B, T, 1])
>>> returns_to_go = torch.tensor([1, 0.8, 0.6, 0.4, 0.2, 0.]).repeat([B, 1]).unsqueeze(-1).float()
>>> traj_mask = torch.ones([B, T], dtype=torch.long) # B x T
>>> actions = actions.squeeze(-1)
>>> state_preds, action_preds, return_preds = DT_model.forward(\
timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go\
)
>>> assert state_preds.shape == torch.Size([B, T, state_dim])
>>> assert return_preds.shape == torch.Size([B, T, 1])
>>> assert action_preds.shape == torch.Size([B, T, act_dim])
"""
B, T = states.shape[0], states.shape[1]
if self.state_encoder is None:
time_embeddings = self.embed_timestep(timesteps)
@@ -217,12 +357,20 @@ def forward(self, timesteps, states, actions, returns_to_go, tar=None):

return state_preds, action_preds, return_preds

def configure_optimizers(self, weight_decay, learning_rate, betas=(0.9, 0.95)):
def configure_optimizers(
self, weight_decay: float, learning_rate: float, betas: Tuple[float, float] = (0.9, 0.95)
) -> torch.optim.optimizer.Optimizer:
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
We are then returning the PyTorch optimizer object.
Overview:
This function returns an optimizer given the input arguments. \
We are separating out all parameters of the model into two buckets: those that will experience \
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
Arguments:
- weight_decay (:obj:`float`): The weight decay of the optimizer.
- learning_rate (:obj:`float`): The learning rate of the optimizer.
- betas (:obj:`Tuple[float, float]`): The betas for Adam optimizer.
Returns:
- optimizer (:obj:`torch.optim.optimizer.Optimizer`): The desired optimizer.
"""

# separate out all parameters to those that will and won't experience regularizing weight decay
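To close out this file's changes, here is a hedged usage sketch for ``configure_optimizers``. The constructor and method arguments follow the signatures documented above; the concrete optimizer class is implementation-dependent, so only the documented ``Optimizer`` return type is asserted.

>>> model = DecisionTransformer(state_dim=3, act_dim=2, n_blocks=3, h_dim=8, context_len=6, n_heads=2, drop_p=0.1)
>>> optimizer = model.configure_optimizers(weight_decay=1e-4, learning_rate=3e-4)
>>> assert isinstance(optimizer, torch.optim.Optimizer)
>>> # the optimizer is then used in an ordinary training loop:
>>> # loss.backward(); optimizer.step(); optimizer.zero_grad()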
