diff --git a/mmcv/cnn/utils/weight_init.py b/mmcv/cnn/utils/weight_init.py index a5b84494bb..154add8bca 100644 --- a/mmcv/cnn/utils/weight_init.py +++ b/mmcv/cnn/utils/weight_init.py @@ -376,13 +376,26 @@ def _initialize_override(module, override, cfg): override = [override] if isinstance(override, dict) else override for override_ in override: - if 'type' not in override_.keys(): - override_.update(cfg) - name = override_.pop('name', None) + + cp_override = copy.deepcopy(override_) + name = cp_override.pop('name', None) + if name is None: + raise ValueError('`override` must contain the key "name",' + f'but got {cp_override}') + # if override only has name key, it means use args in init_cfg + if not cp_override: + cp_override.update(cfg) + # if override has name key and other args except type key, it will + # raise error + elif 'type' not in cp_override.keys(): + raise ValueError( + f'`override` need "type" key, but got {cp_override}') + if hasattr(module, name): - _initialize(getattr(module, name), override_, wholemodule=True) + _initialize(getattr(module, name), cp_override, wholemodule=True) else: - raise RuntimeError(f'module did not have attribute {name}') + raise RuntimeError(f'module did not have attribute {name}, ' + f'but init_cfg is {cp_override}.') def initialize(module, init_cfg): @@ -394,10 +407,9 @@ def initialize(module, init_cfg): define initializer. OpenMMLab has implemented 6 initializers including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``, ``Kaiming``, and ``Pretrained``. 
- Example: >>> module = nn.Linear(2, 3, bias=True) - >>> init_cfg = dict(type='Constant', val =1 , bias =2) + >>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2) >>> initialize(module, init_cfg) >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2)) @@ -407,11 +419,7 @@ def initialize(module, init_cfg): dict(type='Constant', layer='Linear', val=2)] >>> initialize(module, init_cfg) - >>> # Omitting ``'layer'`` initialize module with same configuration - >>> init_cfg = dict(type='Constant', val=1, bias=2) - >>> initialize(module, init_cfg) - - >>> # define key``'override'`` to initialize some specific override in + >>> # define key``'override'`` to initialize some specific part in >>> # module >>> class FooNet(nn.Module): >>> def __init__(self): @@ -420,7 +428,7 @@ def initialize(module, init_cfg): >>> self.reg = nn.Conv2d(16, 10, 3) >>> self.cls = nn.Conv2d(16, 5, 3) >>> model = FooNet() - >>> init_cfg = dict(type='Constant', val=1, bias=2, + >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d', >>> override=dict(type='Constant', name='reg', val=3, bias=4)) >>> initialize(model, init_cfg) diff --git a/tests/test_cnn/test_weight_init.py b/tests/test_cnn/test_weight_init.py index 10870162b2..343079c45e 100644 --- a/tests/test_cnn/test_weight_init.py +++ b/tests/test_cnn/test_weight_init.py @@ -266,13 +266,17 @@ def test_initialize(): model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2)) foonet = FooModule() + # test layer key init_cfg = dict(type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2) initialize(model, init_cfg) assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.)) assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.)) assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.)) assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.)) + assert init_cfg == dict( + type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2) + # test 
init_cfg with list type init_cfg = [ dict(type='Constant', layer='Conv2d', val=1, bias=2), dict(type='Constant', layer='Linear', val=3, bias=4) @@ -282,7 +286,12 @@ def test_initialize(): assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.)) assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.)) assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 4.)) + assert init_cfg == [ + dict(type='Constant', layer='Conv2d', val=1, bias=2), + dict(type='Constant', layer='Linear', val=3, bias=4) + ] + # test layer key and override key init_cfg = dict( type='Constant', val=1, @@ -302,6 +311,31 @@ def test_initialize(): torch.full(foonet.conv2d_2.weight.shape, 3.)) assert torch.equal(foonet.conv2d_2.bias, torch.full(foonet.conv2d_2.bias.shape, 4.)) + assert init_cfg == dict( + type='Constant', + val=1, + bias=2, + layer=['Conv2d', 'Linear'], + override=dict(type='Constant', name='conv2d_2', val=3, bias=4)) + + # test override key + init_cfg = dict( + type='Constant', val=5, bias=6, override=dict(name='conv2d_2')) + initialize(foonet, init_cfg) + assert not torch.equal(foonet.linear.weight, + torch.full(foonet.linear.weight.shape, 5.)) + assert not torch.equal(foonet.linear.bias, + torch.full(foonet.linear.bias.shape, 6.)) + assert not torch.equal(foonet.conv2d.weight, + torch.full(foonet.conv2d.weight.shape, 5.)) + assert not torch.equal(foonet.conv2d.bias, + torch.full(foonet.conv2d.bias.shape, 6.)) + assert torch.equal(foonet.conv2d_2.weight, + torch.full(foonet.conv2d_2.weight.shape, 5.)) + assert torch.equal(foonet.conv2d_2.bias, + torch.full(foonet.conv2d_2.bias.shape, 6.)) + assert init_cfg == dict( + type='Constant', val=5, bias=6, override=dict(name='conv2d_2')) init_cfg = dict( type='Pretrained', @@ -325,6 +359,11 @@ def test_initialize(): torch.full(foonet.conv2d_2.weight.shape, 3.)) assert torch.equal(foonet.conv2d_2.bias, torch.full(foonet.conv2d_2.bias.shape, 4.)) + assert init_cfg == dict( + type='Pretrained', + 
checkpoint='modelA.pth', + override=dict(type='Constant', name='conv2d_2', val=3, bias=4)) + # test init_cfg type with pytest.raises(TypeError): init_cfg = 'init_cfg' @@ -362,3 +401,21 @@ def test_initialize(): dict(type='Constant', name='conv2d_3', val=5, bias=6) ]) initialize(foonet, init_cfg) + + # test override with args except type key + with pytest.raises(ValueError): + init_cfg = dict( + type='Constant', + val=1, + bias=2, + override=dict(name='conv2d_2', val=3, bias=4)) + initialize(foonet, init_cfg) + + # test override without name + with pytest.raises(ValueError): + init_cfg = dict( + type='Constant', + val=1, + bias=2, + override=dict(type='Constant', val=3, bias=4)) + initialize(foonet, init_cfg)