fix(pu): fix last_linear_layer_weight_bias_init_zero in MLP and add its unittest (#650)

* fix(pu): fix last_linear_layer_weight_bias_init_zero in MLP and add its unittest

* polish(pu): polish unittest of mlp

* style(pu): yapf format

* style(pu): flake8 format

* polish(pu): polish the output_activation and output_norm in MLP

* style(pu): polish the annotations in MLP, yapf format

* style(pu): flake8 style fix

* fix(pu): fix output_activation and output_norm in MLP
puyuan1996 authored Apr 25, 2023
1 parent aefddac commit 9e7002f
Showing 2 changed files with 79 additions and 45 deletions.
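For orientation, the diff below replaces MLP's module-/string-valued output-layer arguments with boolean flags. A minimal before/after call sketch based on the signatures shown in the diff (the channel sizes 64/128/32 are illustrative, not taken from the commit):

```python
import torch.nn as nn
from ding.torch_utils.network.nn_module import MLP

# Before this commit, the output layer's activation/normalization were passed explicitly:
#   MLP(64, 128, 32, layer_num=2, activation=nn.ReLU(), norm_type='BN',
#       output_activation=nn.Identity(), output_norm_type=None)

# After this commit, two booleans control whether the output layer reuses the
# same activation/normalization as the front layers:
block = MLP(
    64, 128, 32, layer_num=2,
    activation=nn.ReLU(),
    norm_type='BN',
    output_activation=False,   # no activation after the final Linear
    output_norm=False,         # no normalization after the final Linear
    last_linear_layer_init_zero=True,
)
```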
62 changes: 33 additions & 29 deletions ding/torch_utils/network/nn_module.py
@@ -314,8 +314,8 @@ def MLP(
norm_type: str = None,
use_dropout: bool = False,
dropout_probability: float = 0.5,
output_activation: nn.Module = None,
output_norm_type: str = None,
output_activation: bool = True,
output_norm: bool = True,
last_linear_layer_init_zero: bool = False
):
r"""
@@ -328,15 +328,18 @@ def MLP(
- hidden_channels (:obj:`int`): Number of channels in the hidden tensor.
- out_channels (:obj:`int`): Number of channels in the output tensor.
- layer_num (:obj:`int`): Number of layers.
- layer_fn (:obj:`Callable`): layer function.
- activation (:obj:`nn.Module`): the optional activation function.
- norm_type (:obj:`str`): type of the normalization.
- use_dropout (:obj:`bool`): whether to use dropout in the fully-connected block.
- dropout_probability (:obj:`float`): probability of an element to be zeroed in the dropout. Default: 0.5.
- output_activation (:obj:`nn.Module`): the optional activation function in the last layer.
- output_norm_type (:obj:`str`): type of the normalization in the last layer.
- last_linear_layer_init_zero (:obj:`bool`): zero initialization for the last linear layer (including w and b),
which can provide stable zero outputs in the beginning.
- layer_fn (:obj:`Callable`): Layer function.
- activation (:obj:`nn.Module`): The optional activation function.
- norm_type (:obj:`str`): The type of the normalization.
- use_dropout (:obj:`bool`): Whether to use dropout in the fully-connected block.
- dropout_probability (:obj:`float`): The probability of an element to be zeroed in the dropout. Default: 0.5.
- output_activation (:obj:`bool`): Whether to use activation in the output layer. If True,
we use the same activation as front layers. Default: True.
- output_norm (:obj:`bool`): Whether to use normalization in the output layer. If True,
we use the same normalization as front layers. Default: True.
- last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last linear layer
(including w and b), which can provide stable zero outputs in the beginning,
usually used in the policy network in RL settings.
Returns:
- block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block.
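To illustrate what the two flags control, a small sketch (channel sizes illustrative) that builds one block with and one without the output norm/activation; per the construction logic in the next hunk, the first should end with Linear followed by the shared norm and activation, the second with a bare Linear:

```python
import torch.nn as nn
from ding.torch_utils.network.nn_module import MLP

with_head_extras = MLP(8, 16, 4, layer_num=2, activation=nn.ReLU(), norm_type='LN',
                       output_activation=True, output_norm=True)
bare_head = MLP(8, 16, 4, layer_num=2, activation=nn.ReLU(), norm_type='LN',
                output_activation=False, output_norm=False)

# The first block's final Linear is followed by the shared norm and activation;
# the second block ends with the final Linear alone.
print(with_head_extras)
print(bare_head)
```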
@@ -361,30 +364,31 @@ def MLP(
if use_dropout:
block.append(nn.Dropout(dropout_probability))

# the last layer
# The last layer
in_channels = channels[-2]
out_channels = channels[-1]
if output_activation is None and output_norm_type is None:
# the last layer use the same norm and activation as front layers
block.append(layer_fn(in_channels, out_channels))
block.append(layer_fn(in_channels, out_channels))
"""
In the final layer of a neural network, whether to use normalization and activation are typically determined
based on user specifications. These specifications depend on the problem at hand and the desired properties of
the model's output.
"""
if output_norm is True:
# The last layer uses the same norm as front layers.
if norm_type is not None:
block.append(build_normalization(norm_type, dim=1)(out_channels))
if output_activation is True:
# The last layer uses the same activation as front layers.
if activation is not None:
block.append(activation)
if use_dropout:
block.append(nn.Dropout(dropout_probability))
else:
# the last layer use the specific norm and activation
block.append(layer_fn(in_channels, out_channels))
if output_norm_type is not None:
block.append(build_normalization(output_norm_type, dim=1)(out_channels))
if output_activation is not None:
block.append(output_activation)
if use_dropout:
block.append(nn.Dropout(dropout_probability))
if last_linear_layer_init_zero:
block[-2].weight.data.fill_(0)
block[-2].bias.data.fill_(0)

if last_linear_layer_init_zero:
# Locate the last linear layer and initialize its weights and biases to 0.
for _, layer in enumerate(reversed(block)):
if isinstance(layer, nn.Linear):
nn.init.zeros_(layer.weight)
nn.init.zeros_(layer.bias)
break

return sequential_pack(block)
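A quick sanity check of the new zero-initialization path, mirroring the reversed-iteration loop above (the channel sizes are illustrative, not from the commit):

```python
import torch
import torch.nn as nn
from ding.torch_utils.network.nn_module import MLP

block = MLP(
    16, 32, 8, layer_num=3,
    activation=nn.ReLU(),
    norm_type='LN',
    output_activation=True,
    output_norm=True,
    last_linear_layer_init_zero=True,
)

# Walk the block in reverse, as the new code does, to find the last Linear layer
# and confirm that its weight and bias were zero-initialized.
last_linear = next(m for m in reversed(block) if isinstance(m, nn.Linear))
assert torch.all(last_linear.weight == 0) and torch.all(last_linear.bias == 0)
```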

62 changes: 46 additions & 16 deletions ding/torch_utils/network/tests/test_nn_module.py
@@ -1,6 +1,8 @@
import torch
import pytest
from ding.torch_utils import build_activation, build_normalization
import torch
from torch.testing import assert_allclose

from ding.torch_utils import build_activation
from ding.torch_utils.network.nn_module import MLP, conv1d_block, conv2d_block, fc_block, deconv2d_block, \
ChannelShuffle, one_hot, NearestUpsample, BilinearUpsample, binary_encode, weight_init_, NaiveFlatten, \
normed_linear, normed_conv2d
@@ -44,20 +46,48 @@ def test_weight_init(self):
weight_init_(weight, 'xxx')

def test_mlp(self):
input = torch.rand(batch_size, in_channels).requires_grad_(True)
block = MLP(
in_channels=in_channels,
hidden_channels=hidden_channels,
out_channels=out_channels,
layer_num=2,
activation=torch.nn.ReLU(inplace=True),
norm_type='BN',
output_activation=torch.nn.Identity(),
output_norm_type=None,
last_linear_layer_init_zero=True
)
output = self.run_model(input, block)
assert output.shape == (batch_size, out_channels)
layer_num = 3
input_tensor = torch.rand(batch_size, in_channels).requires_grad_(True)

for output_activation in [True, False]:
for output_norm in [True, False]:
for activation in [torch.nn.ReLU(), torch.nn.LeakyReLU(), torch.nn.Tanh(), None]:
for norm_type in ["LN", "BN", None]:
# Test case 1: MLP without last linear layer initialized to 0.
model = MLP(
in_channels,
hidden_channels,
out_channels,
layer_num,
activation=activation,
norm_type=norm_type,
output_activation=output_activation,
output_norm=output_norm
)
output_tensor = self.run_model(input_tensor, model)
assert output_tensor.shape == (batch_size, out_channels)

# Test case 2: MLP with last linear layer initialized to 0.
model = MLP(
in_channels,
hidden_channels,
out_channels,
layer_num,
activation=activation,
norm_type=norm_type,
output_activation=output_activation,
output_norm=output_norm,
last_linear_layer_init_zero=True
)
output_tensor = self.run_model(input_tensor, model)
assert output_tensor.shape == (batch_size, out_channels)
last_linear_layer = None
for layer in reversed(model):
if isinstance(layer, torch.nn.Linear):
last_linear_layer = layer
break
assert_allclose(last_linear_layer.weight, torch.zeros_like(last_linear_layer.weight))
assert_allclose(last_linear_layer.bias, torch.zeros_like(last_linear_layer.bias))

def test_conv1d_block(self):
length = 2

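As a further, hypothetical check of the docstring's claim that zero-initializing the last linear layer "can provide stable zero outputs in the beginning": with a zero-preserving output activation such as ReLU and no output normalization, a freshly built block should return all zeros (sizes illustrative):

```python
import torch
import torch.nn as nn
from ding.torch_utils.network.nn_module import MLP

block = MLP(
    4, 8, 2, layer_num=2,
    activation=nn.ReLU(),
    norm_type=None,
    output_activation=True,      # ReLU maps 0 to 0, so zeros survive the output layer
    output_norm=False,
    last_linear_layer_init_zero=True,
)
out = block(torch.rand(5, 4))
assert torch.all(out == 0)
```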