polish(pu): add LN and GN norm_type support in ResBlock (#660)
* polish(pu): add LN and GN norm_type support in ResBlock

* polish(pu): polish norm_type branch in conv2d_block
puyuan1996 authored May 11, 2023
1 parent fa521b0 commit a8f0ac9
Showing 4 changed files with 48 additions and 32 deletions.
51 changes: 32 additions & 19 deletions ding/torch_utils/network/nn_module.py
@@ -90,7 +90,7 @@ def conv1d_block(
         - activation (:obj:`nn.Module`): the optional activation function
         - norm_type (:obj:`str`): type of the normalization
     Returns:
-        - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the 1 dim convlution layer
+        - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the 1 dim convolution layer
     .. note::
@@ -116,23 +116,26 @@ def conv2d_block(
     pad_type: str = 'zero',
     activation: nn.Module = None,
     norm_type: str = None,
+    num_groups_for_gn: int = 1,
     bias: bool = True
 ) -> nn.Sequential:
     r"""
     Overview:
         Create a 2-dim convolution layer with activation and normalization.
     Arguments:
-        - in_channels (:obj:`int`): Number of channels in the input tensor
-        - out_channels (:obj:`int`): Number of channels in the output tensor
-        - kernel_size (:obj:`int`): Size of the convolving kernel
-        - stride (:obj:`int`): Stride of the convolution
-        - padding (:obj:`int`): Zero-padding added to both sides of the input
-        - dilation (:obj:`int`): Spacing between kernel elements
-        - groups (:obj:`int`): Number of blocked connections from input channels to output channels
-        - pad_type (:obj:`str`): the way to add padding, include ['zero', 'reflect', 'replicate'], default: None
-        - activation (:obj:`nn.Module`): the optional activation function
-        - norm_type (:obj:`str`): type of the normalization, default set to None, now support ['BN', 'IN', 'SyncBN']
-        - bias (:obj:`bool`): whether adds a learnable bias to the nn.Conv2d. default set to True
+        - in_channels (:obj:`int`): Number of channels in the input tensor.
+        - out_channels (:obj:`int`): Number of channels in the output tensor.
+        - kernel_size (:obj:`int`): Size of the convolving kernel.
+        - stride (:obj:`int`): Stride of the convolution.
+        - padding (:obj:`int`): Zero-padding added to both sides of the input.
+        - dilation (:obj:`int`): Spacing between kernel elements.
+        - groups (:obj:`int`): Number of blocked connections from input channels to output channels.
+        - pad_type (:obj:`str`): the way to add padding, includes ['zero', 'reflect', 'replicate'], default: 'zero'.
+        - activation (:obj:`nn.Module`): the optional activation function.
+        - norm_type (:obj:`str`): The type of the normalization, now supports ['BN', 'LN', 'IN', 'GN', 'SyncBN'],
+            default set to None, which means no normalization.
+        - num_groups_for_gn (:obj:`int`): Number of groups for GroupNorm, default set to 1.
+        - bias (:obj:`bool`): whether to add a learnable bias to the nn.Conv2d. Default set to True.
     Returns:
         - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the 2 dim convolution layer
@@ -163,7 +166,17 @@ def conv2d_block(
         )
     )
     if norm_type is not None:
-        block.append(build_normalization(norm_type, dim=2)(out_channels))
+        if norm_type == 'LN':
+            # LN is implemented as GroupNorm with 1 group.
+            block.append(nn.GroupNorm(1, out_channels))
+        elif norm_type == 'GN':
+            block.append(nn.GroupNorm(num_groups_for_gn, out_channels))
+        elif norm_type in ['BN', 'IN', 'SyncBN']:
+            block.append(build_normalization(norm_type, dim=2)(out_channels))
+        else:
+            raise KeyError("Invalid value in norm_type: {}. The valid norm_type values are "
+                           "BN, LN, IN, GN and SyncBN.".format(norm_type))
+
     if activation is not None:
         block.append(activation)
     return sequential_pack(block)
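For reference, the 'LN' branch above relies on nn.GroupNorm(1, C) normalizing each sample over all of its (C, H, W) features, which is exactly a LayerNorm over the whole feature map. Below is a minimal usage sketch of the new behavior (illustrative only, not part of this commit; it assumes conv2d_block is importable from ding.torch_utils.network.nn_module):

```python
# Illustrative sketch, not part of this commit.
import torch
import torch.nn as nn
from ding.torch_utils.network.nn_module import conv2d_block  # assumed import path

x = torch.randn(4, 3, 32, 32)
# 'LN' now builds nn.GroupNorm(1, out_channels); 'GN' uses num_groups_for_gn groups
# (num_groups_for_gn must evenly divide out_channels).
ln_block = conv2d_block(3, 16, 3, stride=1, padding=1, norm_type='LN')
gn_block = conv2d_block(3, 16, 3, stride=1, padding=1, norm_type='GN', num_groups_for_gn=4)
assert ln_block(x).shape == gn_block(x).shape == (4, 16, 32, 32)

# GroupNorm(1, C) matches a per-sample normalization over (C, H, W):
y = torch.randn(4, 16, 8, 8)
gn1 = nn.GroupNorm(1, 16, affine=False)
mean = y.mean(dim=(1, 2, 3), keepdim=True)
var = y.var(dim=(1, 2, 3), keepdim=True, unbiased=False)
assert torch.allclose(gn1(y), (y - mean) / torch.sqrt(var + 1e-5), atol=1e-5)
```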
@@ -182,7 +195,7 @@ def deconv2d_block(
 ) -> nn.Sequential:
     r"""
     Overview:
-        Create a 2-dim transopse convlution layer with activation and normalization
+        Create a 2-dim transpose convolution layer with activation and normalization
     Arguments:
         - in_channels (:obj:`int`): Number of channels in the input tensor
         - out_channels (:obj:`int`): Number of channels in the output tensor
@@ -194,7 +207,7 @@ def deconv2d_block(
         - norm_type (:obj:`str`): type of the normalization
     Returns:
         - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the 2-dim \
-            transpose convlution layer
+            transpose convolution layer
     .. note::
@@ -486,7 +499,7 @@ def one_hot(val: torch.LongTensor, num: int, num_first: bool = False) -> torch.F
 class NearestUpsample(nn.Module):
     r"""
     Overview:
-        Upsamples the input to the given member varible scale_factor using mode nearest
+        Upsamples the input to the given member variable scale_factor using mode nearest
     Interface:
         forward
     """
@@ -516,7 +529,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class BilinearUpsample(nn.Module):
     r"""
     Overview:
-        Upsamples the input to the given member varible scale_factor using mode biliner
+        Upsamples the input to the given member variable scale_factor using mode bilinear
     Interface:
         forward
     """
@@ -601,7 +614,7 @@ def _scale_noise(self, size: Union[int, Tuple]):
     def reset_noise(self):
         r"""
         Overview:
-            Reset noise settinngs in the layer.
+            Reset noise settings in the layer.
         """
         is_cuda = self.weight_mu.is_cuda
         in_noise = self._scale_noise(self.in_channels).to(torch.device("cuda" if is_cuda else "cpu"))
@@ -663,7 +676,7 @@ def noise_block(
         - norm_type (:obj:`str`): type of the normalization
         - use_dropout (:obj:`bool`): whether to use dropout in the fully-connected block
         - dropout_probability (:obj:`float`): probability of an element to be zeroed in the dropout. Default: 0.5
-        - simga0 (:obj:`float`): the sigma0 is the defalut noise volumn when init NoiseLinearLayer
+        - sigma0 (:obj:`float`): the sigma0 is the default noise volume when initializing NoiseLinearLayer
     Returns:
         - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block
8 changes: 5 additions & 3 deletions ding/torch_utils/network/normalization.py
@@ -7,7 +7,7 @@ def build_normalization(norm_type: str, dim: Optional[int] = None) -> nn.Module:
     Overview:
         Build the corresponding normalization module
     Arguments:
-        - norm_type (:obj:`str`): type of the normaliztion, now support ['BN', 'IN', 'SyncBN', 'AdaptiveIN']
+        - norm_type (:obj:`str`): type of the normalization, now supports ['BN', 'LN', 'IN', 'SyncBN']
         - dim (:obj:`int`): dimension of the normalization, when norm_type is in [BN, IN]
     Returns:
         - norm_func (:obj:`nn.Module`): the corresponding normalization function
@@ -18,17 +18,19 @@ def build_normalization(norm_type: str, dim: Optional[int] = None) -> nn.Module:
     if dim is None:
         key = norm_type
     else:
-        if norm_type in ['BN', 'IN', 'SyncBN']:
+        if norm_type in ['BN', 'IN']:
             key = norm_type + str(dim)
-        elif norm_type in ['LN']:
+        elif norm_type in ['LN', 'SyncBN']:
             key = norm_type
         else:
             raise NotImplementedError("not support indicated dim when creates {}".format(norm_type))
     norm_func = {
         'BN1': nn.BatchNorm1d,
         'BN2': nn.BatchNorm2d,
+        'LN': nn.LayerNorm,
         'IN1': nn.InstanceNorm1d,
         'IN2': nn.InstanceNorm2d,
+        'SyncBN': nn.SyncBatchNorm,
     }
     if key in norm_func.keys():
         return norm_func[key]
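After this change, 'BN' and 'IN' keys are suffixed with dim, while 'LN' and 'SyncBN' ignore it and map straight to their modules. A hedged usage sketch (the import path is assumed from the file location above):

```python
# Hedged usage sketch of the updated dispatch; import path assumed.
from ding.torch_utils.network.normalization import build_normalization

bn2d = build_normalization('BN', dim=2)(64)      # -> nn.BatchNorm2d(64)
in1d = build_normalization('IN', dim=1)(64)      # -> nn.InstanceNorm1d(64)
ln = build_normalization('LN', dim=2)(64)        # dim ignored -> nn.LayerNorm(64)
sync = build_normalization('SyncBN', dim=2)(64)  # dim ignored -> nn.SyncBatchNorm(64)
print(type(bn2d).__name__, type(in1d).__name__, type(ln).__name__, type(sync).__name__)
# BatchNorm2d InstanceNorm1d LayerNorm SyncBatchNorm
```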
6 changes: 3 additions & 3 deletions ding/torch_utils/network/res_block.py
@@ -41,7 +41,7 @@ def __init__(
         - in_channels (:obj:`int`): Number of channels in the input tensor.
         - activation (:obj:`nn.Module`): the optional activation function.
         - norm_type (:obj:`str`): type of the normalization, default set to 'BN' (Batch Normalization), \
-            supports ['BN', 'IN', 'SyncBN', None].
+            supports ['BN', 'LN', 'IN', 'GN', 'SyncBN', None].
         - res_type (:obj:`str`): type of residual block, supports ['basic', 'bottleneck', 'downsample'].
         - bias (:obj:`bool`): whether to add a learnable bias to the conv2d_block. Default set to True.
         - out_channels (:obj:`int`): Number of channels in the output tensor, default set to None,
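With this commit, a ResBlock can be built with any of the supported norm types. A small sketch following the positional order used by the updated test below (in_channels, activation, norm_type, res_type); the import path is assumed:

```python
# Sketch following the argument order used in test_res_block.py below.
import torch
import torch.nn as nn
from ding.torch_utils.network.res_block import ResBlock  # assumed import path

x = torch.randn(4, 64, 8, 8)
for norm_type in ['BN', 'LN', 'IN', 'GN', None]:
    block = ResBlock(64, nn.ReLU(), norm_type, 'basic')
    # 'basic' residual blocks preserve the input shape.
    assert block(x).shape == x.shape
```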
@@ -101,15 +101,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class ResFCBlock(nn.Module):
-    r'''
+    r"""
     Overview:
         Residual Block with 2 fully connected layers.
         x -> fc1 -> norm -> act -> fc2 -> norm -> act -> out
         \_____________________________________/+
     Interfaces:
         forward
-    '''
+    """
 
     def __init__(self, in_channels: int, activation: nn.Module = nn.ReLU(), norm_type: str = 'BN'):
         r"""
15 changes: 8 additions & 7 deletions ding/torch_utils/network/tests/test_res_block.py
@@ -17,13 +17,14 @@ class TestResBlock:
     def test_res_blcok(self):
         input = torch.rand(batch_size, in_channels, 2, 3).requires_grad_(True)
         for r in res_type:
-            model = ResBlock(in_channels, activation, norm_type, r)
-            output = model(input)
-            loss = output.mean()
-            loss.backward()
-            if r in res_type_classic:
-                assert output.shape == input.shape
-                assert isinstance(input.grad, torch.Tensor)
+            for norm_type in ['BN', 'LN', 'IN', 'GN', None]:
+                model = ResBlock(in_channels, activation, norm_type, r)
+                output = model(input)
+                loss = output.mean()
+                loss.backward()
+                if r in res_type_classic:
+                    assert output.shape == input.shape
+                    assert isinstance(input.grad, torch.Tensor)
 
     def test_res_fc_block(self):
         input = torch.rand(batch_size, in_channels).requires_grad_(True)
