Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

doc(whl): add code doc for LT,DT,PC,BC models #734

Merged
merged 3 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions ding/model/template/bc.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from typing import Union, Optional, Dict, Callable, List
from typing import Union, Optional, Dict
import torch
import torch.nn as nn
from easydict import EasyDict

from ding.torch_utils import get_lstm
from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, \
MultiHead, RegressionHead, ReparameterizationHead


@MODEL_REGISTRY.register('discrete_bc')
class DiscreteBC(nn.Module):
r"""
Overview:
The DiscreteBC network.
Interfaces:
``__init__``, ``forward``
"""

def __init__(
self,
Expand All @@ -36,7 +41,7 @@ def __init__(
- head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network.
- head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output
- activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
if ``None`` then default set it to ``nn.ReLU()``
if ``None`` then default set it to ``nn.ReLU()``.
- norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
kxzxvbk marked this conversation as resolved.
Show resolved Hide resolved
``ding.torch_utils.fc_block`` for more details.
"""
Expand Down Expand Up @@ -127,7 +132,7 @@ def __init__(
) -> None:
"""
Overview:
Initailize the ContinuousBC Model according to input arguments.
Initialize the ContinuousBC Model according to input arguments.
Arguments:
- obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
- action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \
Expand Down Expand Up @@ -173,14 +178,27 @@ def __init__(
)
)

def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Dict:
"""
PaParaZz1 marked this conversation as resolved.
Show resolved Hide resolved
Overview:
The unique execution (forward) method of ContinuousBC method.
The unique execution (forward) method of ContinuousBC.
Arguments:
- inputs (:obj:`torch.Tensor`): Observation data, defaults to tensor.
Returns:
- output (:obj:`Dict`): Output dict data, including differnet key-values among distinct action_space.
- output (:obj:`Dict`): Output dict data, including different key-values among distinct action_space.

Examples (Regression):
>>> model = ContinuousBC(32, 6, action_space='regression')
>>> inputs = torch.randn(4, 32)
>>> outputs = model(inputs)
>>> assert isinstance(outputs, dict) and outputs['action'].shape == torch.Size([4, 6])

Examples (Reparameterization):
>>> model = ContinuousBC(32, 6, action_space='reparameterization')
>>> inputs = torch.randn(4, 32)
>>> outputs = model(inputs)
>>> assert isinstance(outputs, dict) and outputs['logit'][0].shape == torch.Size([4, 6])\
and outputs['logit'][1].shape == torch.Size([4, 6])
kxzxvbk marked this conversation as resolved.
Show resolved Hide resolved
"""
if self.action_space == 'regression':
x = self.actor(inputs)
Expand Down
190 changes: 168 additions & 22 deletions ding/model/template/decision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,32 @@
"""

import math
from typing import Union, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.utils import SequenceType


class MaskedCausalAttention(nn.Module):

def __init__(self, h_dim, max_T, n_heads, drop_p):
r"""
Overview:
The implementation of masked causal attention in decision transformer.
kxzxvbk marked this conversation as resolved.
Show resolved Hide resolved
Interfaces:
``__init__``, ``forward``
"""

def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None:
"""
Overview:
Initialize the MaskedCausalAttention Model according to input arguments.
Arguments:
- h_dim (:obj:`int`): The dimension of hidden states, such as 128.
kxzxvbk marked this conversation as resolved.
Show resolved Hide resolved
- max_T (:obj:`int`): The max context length of the attention, such as 6.
- n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
- drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
"""
super().__init__()

self.n_heads = n_heads
Expand All @@ -42,7 +60,22 @@ def __init__(self, h_dim, max_T, n_heads, drop_p):
# during backpropagation
self.register_buffer('mask', mask)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overview:
MaskedCausalAttention forward computation graph, input a sequence tensor \
and return a tensor with the same shape.
Arguments:
- x (:obj:`torch.Tensor`): The input tensor.
Returns:
- out (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input.

Examples:
>>> inputs = torch.randn(2, 4, 64)
>>> model = MaskedCausalAttention(64, 5, 4, 0.1)
>>> outputs = model(inputs)
>>> assert outputs.shape == torch.Size([2, 4, 64])
"""
B, T, C = x.shape # batch size, seq length, h_dim * n_heads

N, D = self.n_heads, C // self.n_heads # N = num heads, D = attention dim
Expand Down Expand Up @@ -70,8 +103,23 @@ def forward(self, x):


class Block(nn.Module):

def __init__(self, h_dim, max_T, n_heads, drop_p):
"""
Overview:
The implementation of a transformer block in decision transformer.
Interfaces:
``__init__``, ``forward``
"""

def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None:
"""
Overview:
Initialize the Block Model according to input arguments.
Arguments:
- h_dim (:obj:`int`): The dimension of hidden states, such as 128.
kxzxvbk marked this conversation as resolved.
Show resolved Hide resolved
- max_T (:obj:`int`): The max context length of the attention, such as 6.
- n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
- drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
"""
super().__init__()
self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p)
self.mlp = nn.Sequential(
Expand All @@ -83,7 +131,22 @@ def __init__(self, h_dim, max_T, n_heads, drop_p):
self.ln1 = nn.LayerNorm(h_dim)
self.ln2 = nn.LayerNorm(h_dim)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overview:
Forward computation graph of the decision transformer block, input a sequence tensor \
and return a tensor with the same shape.
Arguments:
- x (:obj:`torch.Tensor`): The input tensor.
Returns:
- output (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input.

Examples:
>>> inputs = torch.randn(2, 4, 64)
>>> model = Block(64, 5, 4, 0.1)
>>> outputs = model(inputs)
>>> outputs.shape == torch.Size([2, 4, 64])
"""
# Attention -> LayerNorm -> MLP -> LayerNorm
x = x + self.attention(x) # residual
x = self.ln1(x)
Expand All @@ -95,20 +158,42 @@ def forward(self, x):


class DecisionTransformer(nn.Module):
"""
Overview:
The implementation of decision transformer.
Interfaces:
``__init__``, ``forward``, ``configure_optimizers``
"""

def __init__(
self,
state_dim,
act_dim,
n_blocks,
h_dim,
context_len,
n_heads,
drop_p,
max_timestep=4096,
state_encoder=None,
continuous=False
state_dim: Union[int, SequenceType],
act_dim: int,
n_blocks: int,
h_dim: int,
context_len: int,
n_heads: int,
drop_p: float,
max_timestep: int = 4096,
state_encoder: Optional[nn.Module] = None,
continuous: bool = False
):
"""
Overview:
Initialize the DecisionTransformer Model according to input arguments.
Arguments:
- obs_shape (:obj:`Union[int, SequenceType]`): Dimension of state, such as 128, (4, 84, 84).
kxzxvbk marked this conversation as resolved.
Show resolved Hide resolved
- act_dim (:obj:`int`): The dimension of actions, such as 6.
- n_blocks (:obj:`int`): The number of transformer blocks in the decision transformer, such as 3.
- h_dim (:obj:`int`): The dimension of hidden states, such as 128.
kxzxvbk marked this conversation as resolved.
Show resolved Hide resolved
- context_len (:obj:`int`): The max context length of the attention, such as 6.
- n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
- drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
- max_timestep (:obj:`int`): The max length of the total sequence, defaults to be 4096.
- state_encoder (:obj:`Optional[nn.Module]`): The encoder to pre-process the given input. If it is set to \
None, the raw state will be pushed into the transformer.
kxzxvbk marked this conversation as resolved.
Show resolved Hide resolved
- continuous (:obj:`bool`): Whether the action space is continuous, defaults to be ``False``.
"""
super().__init__()

self.state_dim = state_dim
Expand Down Expand Up @@ -152,7 +237,60 @@ def __init__(
self.embed_action = nn.Sequential(nn.Embedding(act_dim, h_dim), nn.Tanh())
self.transformer = nn.Sequential(*blocks)

def forward(self, timesteps, states, actions, returns_to_go, tar=None):
def forward(
self,
timesteps: torch.Tensor,
states: torch.Tensor,
actions: torch.Tensor,
returns_to_go: torch.Tensor,
tar: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Overview:
Forward computation graph of the decision transformer, input a sequence tensor \
and return a tensor with the same shape.
Arguments:
- timesteps (:obj:`torch.Tensor`): The timestep for input sequence.
- states (:obj:`torch.Tensor`): The sequence of states.
- actions (:obj:`torch.Tensor`): The sequence of actions.
- returns_to_go (:obj:`torch.Tensor`): The sequence of return-to-go.
- tar (:obj:`Optional[int]`): The targe index.
kxzxvbk marked this conversation as resolved.
Show resolved Hide resolved
Returns:
- output (:obj:`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`): Output contains three tensors, \
they are correspondingly the predicted states, predicted actions and predicted return-to-go.

Examples:
>>> B, T = 4, 6
>>> state_dim = 3
>>> act_dim = 2
>>> DT_model = DecisionTransformer(\
state_dim=state_dim,\
act_dim=act_dim,\
n_blocks=3,\
h_dim=8,\
context_len=T,\
n_heads=2,\
drop_p=0.1,\
)

>>> timesteps = torch.randint(0, 100, [B, 3 * T - 1, 1], dtype=torch.long) # B x T
>>> states = torch.randn([B, T, state_dim]) # B x T x state_dim

>>> actions = torch.randint(0, act_dim, [B, T, 1])
>>> action_target = torch.randint(0, act_dim, [B, T, 1])
>>> returns_to_go_sample = torch.tensor([1, 0.8, 0.6, 0.4, 0.2, 0.]).repeat([B, 1]).unsqueeze(-1).float()

>>> traj_mask = torch.ones([B, T], dtype=torch.long) # B x T
>>> actions = actions.squeeze(-1)

>>> state_preds, action_preds, return_preds = DT_model.forward(\
timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go\
)

>>> assert state_preds.shape == torch.Size([B, T, state_dim])
>>> assert return_preds.shape == torch.Size([B, T, 1])
>>> assert action_preds.shape == torch.Size([B, T, act_dim])
"""
B, T = states.shape[0], states.shape[1]
if self.state_encoder is None:
time_embeddings = self.embed_timestep(timesteps)
Expand Down Expand Up @@ -217,12 +355,20 @@ def forward(self, timesteps, states, actions, returns_to_go, tar=None):

return state_preds, action_preds, return_preds

def configure_optimizers(self, weight_decay, learning_rate, betas=(0.9, 0.95)):
def configure_optimizers(
self, weight_decay: float, learning_rate: float, betas: Tuple[float, float] = (0.9, 0.95)
) -> torch.optim.optimizer.Optimizer:
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
We are then returning the PyTorch optimizer object.
Overview:
This function returns an optimizer given the input arguments. \
We are separating out all parameters of the model into two buckets: those that will experience \
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
Arguments:
- weight_decay (:obj:`float`): The weigh decay of the optimizer.
- learning_rate (:obj:`float`): The learning rate of the optimizer.
- betas (:obj:`Tuple[float, float]`): The betas for Adam optimizer.
Outputs:
- optimizer (:obj:`torch.optim.optimizer.Optimizer`): The desired optimizer.
"""

# separate out all parameters to those that will and won't experience regularizing weight decay
Expand Down
Loading
Loading