Skip to content

Commit

Permalink
[Fix] Revise override in init_cfg (#930)
Browse files Browse the repository at this point in the history
* [Fix] Config deep copy in initialize_override

* add asserts&comments

* add test

* test org init_cfg

* test override without name

* typo
  • Loading branch information
MeowZheng authored Apr 12, 2021
1 parent 375605f commit 2fadb1a
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 13 deletions.
34 changes: 21 additions & 13 deletions mmcv/cnn/utils/weight_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,13 +376,26 @@ def _initialize_override(module, override, cfg):
override = [override] if isinstance(override, dict) else override

for override_ in override:
if 'type' not in override_.keys():
override_.update(cfg)
name = override_.pop('name', None)

cp_override = copy.deepcopy(override_)
name = cp_override.pop('name', None)
if name is None:
raise ValueError('`override` must contain the key "name",'
f'but got {cp_override}')
# if override only has name kay, it means use args in init_cfg
if not cp_override:
cp_override.update(cfg)
# if override has name key and other args except type key, it will
# raise error
elif 'type' not in cp_override.keys():
raise ValueError(
f'`override` need "type" key, but got {cp_override}')

if hasattr(module, name):
_initialize(getattr(module, name), override_, wholemodule=True)
_initialize(getattr(module, name), cp_override, wholemodule=True)
else:
raise RuntimeError(f'module did not have attribute {name}')
raise RuntimeError(f'module did not have attribute {name}, '
f'but init_cfg is {cp_override}.')


def initialize(module, init_cfg):
Expand All @@ -394,10 +407,9 @@ def initialize(module, init_cfg):
define initializer. OpenMMLab has implemented 6 initializers
including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
``Kaiming``, and ``Pretrained``.
Example:
>>> module = nn.Linear(2, 3, bias=True)
>>> init_cfg = dict(type='Constant', val =1 , bias =2)
>>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2)
>>> initialize(module, init_cfg)
>>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2))
Expand All @@ -407,11 +419,7 @@ def initialize(module, init_cfg):
dict(type='Constant', layer='Linear', val=2)]
>>> initialize(module, init_cfg)
>>> # Omitting ``'layer'`` initialize module with same configuration
>>> init_cfg = dict(type='Constant', val=1, bias=2)
>>> initialize(module, init_cfg)
>>> # define key``'override'`` to initialize some specific override in
>>> # define key``'override'`` to initialize some specific part in
>>> # module
>>> class FooNet(nn.Module):
>>> def __init__(self):
Expand All @@ -420,7 +428,7 @@ def initialize(module, init_cfg):
>>> self.reg = nn.Conv2d(16, 10, 3)
>>> self.cls = nn.Conv2d(16, 5, 3)
>>> model = FooNet()
>>> init_cfg = dict(type='Constant', val=1, bias=2,
>>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
>>> override=dict(type='Constant', name='reg', val=3, bias=4))
>>> initialize(model, init_cfg)
Expand Down
57 changes: 57 additions & 0 deletions tests/test_cnn/test_weight_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,13 +266,17 @@ def test_initialize():
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
foonet = FooModule()

# test layer key
init_cfg = dict(type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)
initialize(model, init_cfg)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
assert init_cfg == dict(
type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)

# test init_cfg with list type
init_cfg = [
dict(type='Constant', layer='Conv2d', val=1, bias=2),
dict(type='Constant', layer='Linear', val=3, bias=4)
Expand All @@ -282,7 +286,12 @@ def test_initialize():
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 4.))
assert init_cfg == [
dict(type='Constant', layer='Conv2d', val=1, bias=2),
dict(type='Constant', layer='Linear', val=3, bias=4)
]

# test layer key and override key
init_cfg = dict(
type='Constant',
val=1,
Expand All @@ -302,6 +311,31 @@ def test_initialize():
torch.full(foonet.conv2d_2.weight.shape, 3.))
assert torch.equal(foonet.conv2d_2.bias,
torch.full(foonet.conv2d_2.bias.shape, 4.))
assert init_cfg == dict(
type='Constant',
val=1,
bias=2,
layer=['Conv2d', 'Linear'],
override=dict(type='Constant', name='conv2d_2', val=3, bias=4))

# test override key
init_cfg = dict(
type='Constant', val=5, bias=6, override=dict(name='conv2d_2'))
initialize(foonet, init_cfg)
assert not torch.equal(foonet.linear.weight,
torch.full(foonet.linear.weight.shape, 5.))
assert not torch.equal(foonet.linear.bias,
torch.full(foonet.linear.bias.shape, 6.))
assert not torch.equal(foonet.conv2d.weight,
torch.full(foonet.conv2d.weight.shape, 5.))
assert not torch.equal(foonet.conv2d.bias,
torch.full(foonet.conv2d.bias.shape, 6.))
assert torch.equal(foonet.conv2d_2.weight,
torch.full(foonet.conv2d_2.weight.shape, 5.))
assert torch.equal(foonet.conv2d_2.bias,
torch.full(foonet.conv2d_2.bias.shape, 6.))
assert init_cfg == dict(
type='Constant', val=5, bias=6, override=dict(name='conv2d_2'))

init_cfg = dict(
type='Pretrained',
Expand All @@ -325,6 +359,11 @@ def test_initialize():
torch.full(foonet.conv2d_2.weight.shape, 3.))
assert torch.equal(foonet.conv2d_2.bias,
torch.full(foonet.conv2d_2.bias.shape, 4.))
assert init_cfg == dict(
type='Pretrained',
checkpoint='modelA.pth',
override=dict(type='Constant', name='conv2d_2', val=3, bias=4))

# test init_cfg type
with pytest.raises(TypeError):
init_cfg = 'init_cfg'
Expand Down Expand Up @@ -362,3 +401,21 @@ def test_initialize():
dict(type='Constant', name='conv2d_3', val=5, bias=6)
])
initialize(foonet, init_cfg)

# test override with args except type key
with pytest.raises(ValueError):
init_cfg = dict(
type='Constant',
val=1,
bias=2,
override=dict(name='conv2d_2', val=3, bias=4))
initialize(foonet, init_cfg)

# test override without name
with pytest.raises(ValueError):
init_cfg = dict(
type='Constant',
val=1,
bias=2,
override=dict(type='Constant', val=3, bias=4))
initialize(foonet, init_cfg)

0 comments on commit 2fadb1a

Please sign in to comment.