[Docs] Update weight initialization in cnn.md #912

Merged 15 commits on May 13, 2021

Changes from 5 commits
290 changes: 276 additions & 14 deletions docs/cnn.md
@@ -74,23 +74,285 @@ conv = ConvModule(

### Weight initialization

During training, a proper initialization strategy is beneficial for speeding up
training or obtaining higher performance. In MMCV, we provide some commonly used
methods for initializing modules like `nn.Conv2d`. We also provide high-level
APIs for initializing an entire model containing one or more modules.

#### **Initialization of module**

Initializing modules such as `nn.Conv2d`, `nn.Linear`, and so on.

We provide the following initialization methods.

- constant_init

Initialize module parameters with constant values.

```python
>>> import torch.nn as nn
>>> from mmcv.cnn.utils.weight_init import constant_init
>>> conv1 = nn.Conv2d(3, 3, 1)
>>> # constant_init(module, val, bias=0)
>>> constant_init(conv1, 1, 0)
```

- xavier_init

Initialize module parameters with values according to the method
described in [Understanding the difficulty of training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010)](http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)

```python
>>> import torch.nn as nn
>>> from mmcv.cnn.utils.weight_init import xavier_init
>>> conv1 = nn.Conv2d(3, 3, 1)
>>> # xavier_init(module, gain=1, bias=0, distribution='normal')
>>> xavier_init(conv1, distribution='normal')
```
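
As a point of reference, the Glorot scheme draws weights with standard deviation `gain * sqrt(2 / (fan_in + fan_out))`, where for a conv layer `fan_in = in_channels * kernel_h * kernel_w`. A stdlib-only sketch of that computation (illustrative only; the helper name is made up, this is not mmcv code):

```python
import math

def xavier_std(in_channels, out_channels, kernel_size, gain=1.0):
    """Standard deviation used by Xavier/Glorot normal initialization."""
    fan_in = in_channels * kernel_size * kernel_size
    fan_out = out_channels * kernel_size * kernel_size
    return gain * math.sqrt(2.0 / (fan_in + fan_out))

# For the 1x1 conv above (3 -> 3 channels): fan_in = fan_out = 3
print(round(xavier_std(3, 3, 1), 4))  # 0.5774
```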

- normal_init

Initialize module parameters with values drawn from a normal distribution.

```python
>>> import torch.nn as nn
>>> from mmcv.cnn.utils.weight_init import normal_init
>>> conv1 = nn.Conv2d(3, 3, 1)
>>> # normal_init(module, mean=0, std=1, bias=0)
>>> normal_init(conv1, std=0.01, bias=0)
```

- uniform_init

Initialize module parameters with values drawn from a uniform distribution.

```python
>>> import torch.nn as nn
>>> from mmcv.cnn.utils.weight_init import uniform_init
>>> conv1 = nn.Conv2d(3, 3, 1)
>>> # uniform_init(module, a=0, b=1, bias=0)
>>> uniform_init(conv1, a=0, b=1)
```

- kaiming_init

Initialize module parameters with values according to the method
described in [Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification - He, K. et al. (2015)](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf)

```python
>>> import torch.nn as nn
>>> from mmcv.cnn.utils.weight_init import kaiming_init
>>> conv1 = nn.Conv2d(3, 3, 1)
>>> # kaiming_init(module, a=0, mode='fan_out', nonlinearity='relu', bias=0, distribution='normal')
>>> kaiming_init(conv1)
```
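
For reference, the He scheme in `'normal'` mode draws weights with standard deviation `gain / sqrt(fan)`, where the gain is `sqrt(2)` for ReLU. A stdlib-only sketch (illustrative only; the helper name is made up, this is not mmcv code):

```python
import math

def kaiming_std(fan, nonlinearity='relu', a=0):
    """Standard deviation used by Kaiming (He) normal initialization."""
    if nonlinearity == 'relu':
        gain = math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        gain = math.sqrt(2.0 / (1 + a ** 2))
    else:
        raise ValueError(f'unsupported nonlinearity: {nonlinearity}')
    return gain / math.sqrt(fan)

# 1x1 conv with 3 output channels in 'fan_out' mode: fan = 3 * 1 * 1
print(round(kaiming_std(3), 4))  # 0.8165
```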

- caffe2_xavier_init

The Caffe2 implementation of Xavier initialization. It corresponds to `kaiming_init` with a uniform distribution and `fan_in` mode.

```python
>>> import torch.nn as nn
>>> from mmcv.cnn.utils.weight_init import caffe2_xavier_init
>>> conv1 = nn.Conv2d(3, 3, 1)
>>> # caffe2_xavier_init(module, bias=0)
>>> caffe2_xavier_init(conv1)
```

- bias_init_with_prob

Initialize conv/fc bias value according to given probability proposed in [Focal Loss for Dense Object Detection](https://arxiv.org/pdf/1708.02002.pdf).

```python
>>> from mmcv.cnn.utils.weight_init import bias_init_with_prob
>>> # bias_init_with_prob is proposed in Focal Loss
>>> bias = bias_init_with_prob(0.01)
>>> bias
-4.59511985013459
```
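
The value comes from inverting the sigmoid: the bias is set so that a sigmoid over the initial logits outputs `prior_prob`. A stdlib-only check (a reference reimplementation for illustration, not the mmcv function itself):

```python
import math

def bias_init_with_prob_ref(prior_prob):
    # Reference formula: the inverse sigmoid (logit) of the prior probability
    return float(-math.log((1 - prior_prob) / prior_prob))

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

bias = bias_init_with_prob_ref(0.01)
print(bias)                     # -4.59511985013459
print(round(sigmoid(bias), 4))  # 0.01 -- the sigmoid recovers the prior
```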

#### **Initialization of model**

Based on these initialization methods, we define the corresponding initialization classes and register them in `INITIALIZERS`, so that a model can be initialized from its configuration.

We provide the following initialization classes.

- BaseInit
- ConstantInit
- XavierInit
- NormalInit
- UniformInit
- KaimingInit
- Caffe2XavierInit
- PretrainedInit
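
The registry mechanism can be sketched in a few lines of plain Python; this is a toy stand-in for illustration, not mmcv's actual `Registry` implementation:

```python
class Registry:
    """Minimal registry: maps a string type name to a class."""

    def __init__(self):
        self._classes = {}

    def register_module(self, name):
        def _register(cls):
            self._classes[name] = cls
            return cls
        return _register

    def build(self, cfg):
        # Pop 'type' to look up the class; the rest become constructor kwargs.
        cfg = dict(cfg)
        cls = self._classes[cfg.pop('type')]
        return cls(**cfg)

INITIALIZERS = Registry()

@INITIALIZERS.register_module('Constant')
class ConstantInit:
    def __init__(self, val, bias=0, layer=None):
        self.val, self.bias, self.layer = val, bias, layer

init = INITIALIZERS.build(dict(type='Constant', val=1, layer='Conv2d'))
print(init.val, init.layer)  # 1 Conv2d
```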

Before we go deeper into the usage of `initialize`, a brief look at its design
principles is helpful.

- If neither the `layer` key nor the `override` key is defined, nothing will be initialized.
- If `override` is defined but `layer` is not, only the sub-modules named in `override` will be initialized.
- If only `layer` is defined, only the layer types listed in `layer` will be initialized.
- If both `override` and `layer` are defined, `override` has higher priority and overrides the initialization of the sub-modules it names.
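
The rules above can be condensed into a small decision function; this is an illustrative sketch of the precedence, not mmcv's internal logic:

```python
def targets_of(init_cfg):
    """Which sub-modules an init_cfg would touch, per the layer/override rules."""
    layer = init_cfg.get('layer')
    override = init_cfg.get('override')
    targets = []
    if layer:
        targets += layer if isinstance(layer, list) else [layer]
    if override:
        # override wins for the named attribute, on top of any layer matches
        targets.append(override['name'])
    return targets  # empty list -> nothing is initialized

print(targets_of(dict(type='Constant', val=1)))  # []
print(targets_of(dict(type='Constant', val=1, layer='Conv2d',
                      override=dict(type='Constant', name='reg', val=3))))
```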

Now, it is time to introduce the usage of `initialize` in detail.

- Initialize the whole module with the same configuration

Define `layer` to initialize all listed layer types with the same configuration.

```python
import torch.nn as nn
from mmcv.cnn.utils.weight_init import initialize

class FooNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.feat = nn.Conv1d(3, 1, 3)
        self.reg = nn.Conv2d(3, 3, 3)
        self.cls = nn.Linear(1, 2)

model = FooNet()
init_cfg = dict(type='Constant', layer=['Conv1d', 'Conv2d', 'Linear'], val=1)
# initialize whole module with same configuration
initialize(model, init_cfg)
```

- Initialize specific layers with different configurations

Define a list of configs to initialize different layer types with different configurations.

```python
import torch.nn as nn
from mmcv.cnn.utils.weight_init import initialize

class FooNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.feat = nn.Conv1d(3, 1, 3)
        self.reg = nn.Conv2d(3, 3, 3)
        self.cls = nn.Linear(1, 2)

model = FooNet()
init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
            dict(type='Constant', layer='Conv2d', val=2),
            dict(type='Constant', layer='Linear', val=3)]
# nn.Conv1d will be initialized with dict(type='Constant', val=1)
# nn.Conv2d will be initialized with dict(type='Constant', val=2)
# nn.Linear will be initialized with dict(type='Constant', val=3)
initialize(model, init_cfg)
```

- Initialize a sub-module by its attribute name

Define `override` to initialize a sub-module by its attribute name.

```python
import torch.nn as nn
from mmcv.cnn.utils.weight_init import initialize

class FooNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.feat = nn.Conv1d(3, 1, 3)
        self.reg = nn.Conv2d(3, 3, 3)
        self.cls = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1, 2))

model = FooNet()
init_cfg = dict(type='Constant', val=1, bias=2, layer=['Conv1d', 'Conv2d'],
                override=dict(type='Constant', name='reg', val=3, bias=4))
# self.feat and self.cls will be initialized with dict(type='Constant', val=1, bias=2)
# The module called 'reg' will be initialized with dict(type='Constant', val=3, bias=4)
initialize(model, init_cfg)
```

- Initialize weights with a pretrained model

```python
import torch.nn as nn
import torchvision.models as models
from mmcv.cnn.utils.weight_init import initialize

# initialize weights with the whole model
model = models.resnet50()
init_cfg = dict(type='Pretrained',
                checkpoint='torchvision://resnet50')
initialize(model, init_cfg)

# initialize weights of a sub-module with the specific part of a pretrained model by using 'prefix'
model = models.resnet50()
url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
      'retinanet_r50_fpn_1x_coco/'\
      'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
init_cfg = dict(type='Pretrained',
                checkpoint=url, prefix='backbone.')
initialize(model, init_cfg)
```
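
Under the hood, `prefix` keeps only the matching keys of the pretrained state dict and strips the prefix before loading. The key manipulation can be sketched with a plain dict (hypothetical checkpoint keys for illustration):

```python
def strip_prefix(state_dict, prefix):
    """Keep keys starting with `prefix` and remove the prefix itself."""
    return {k[len(prefix):]: v for k, v in state_dict.items()
            if k.startswith(prefix)}

# hypothetical checkpoint keys
ckpt = {'backbone.conv1.weight': 1, 'backbone.bn1.weight': 2,
        'bbox_head.fc.weight': 3}
print(strip_prefix(ckpt, 'backbone.'))  # {'conv1.weight': 1, 'bn1.weight': 2}
```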

- Initialize models inherited from BaseModule, Sequential, ModuleList

`BaseModule` inherits from `torch.nn.Module`, and the only difference between them is that `BaseModule` implements `init_weights`.

`Sequential` inherits from `BaseModule` and `torch.nn.Sequential`.

`ModuleList` inherits from `BaseModule` and `torch.nn.ModuleList`.

```python
import torch.nn as nn
from mmcv.runner.base_module import BaseModule, Sequential, ModuleList

class FooConv1d(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1d = nn.Conv1d(4, 1, 4)

    def forward(self, x):
        return self.conv1d(x)

class FooConv2d(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv2d = nn.Conv2d(3, 1, 3)

    def forward(self, x):
        return self.conv2d(x)

# BaseModule
init_cfg = dict(type='Constant', layer='Conv1d', val=0., bias=1.)
model = FooConv1d(init_cfg)
model.init_weights()

# Sequential
init_cfg1 = dict(type='Constant', layer='Conv1d', val=0., bias=1.)
init_cfg2 = dict(type='Constant', layer='Conv2d', val=2., bias=3.)
model1 = FooConv1d(init_cfg1)
model2 = FooConv2d(init_cfg2)
seq_model = Sequential(model1, model2)
seq_model.init_weights()
# inner init_cfg has higher priority
model1 = FooConv1d(init_cfg1)
model2 = FooConv2d(init_cfg2)
init_cfg = dict(type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)
seq_model = Sequential(model1, model2, init_cfg=init_cfg)
seq_model.init_weights()

# ModuleList
model1 = FooConv1d(init_cfg1)
model2 = FooConv2d(init_cfg2)
modellist = ModuleList([model1, model2])
modellist.init_weights()
# inner init_cfg has higher priority
init_cfg = dict(type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)
model1 = FooConv1d(init_cfg1)
model2 = FooConv2d(init_cfg2)
modellist = ModuleList(model1, model2, init_cfg=init_cfg)
modellist.init_weights()
```

### Model Zoo

6 changes: 3 additions & 3 deletions mmcv/cnn/utils/weight_init.py
@@ -73,7 +73,7 @@ def caffe2_xavier_init(module, bias=0):


def bias_init_with_prob(prior_prob):
"""initialize conv/fc bias value according to giving probability."""
"""initialize conv/fc bias value according to given probability."""
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
return bias_init

@@ -426,11 +426,11 @@ def initialize(module, init_cfg):

>>> model = ResNet(depth=50)
>>> # Initialize weights with the pretrained model.
>>> init_cfg = dict(type='PretrainedInit',
>>> init_cfg = dict(type='Pretrained',
checkpoint='torchvision://resnet50')
>>> initialize(model, init_cfg)

>>> # Intialize weights of a sub-module with the specific part of
>>> # Initialize weights of a sub-module with the specific part of
>>> # a pretrained model by using "prefix".
>>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
>>> 'retinanet_r50_fpn_1x_coco/'\