Skip to content

Commit

Permalink
[Refactor] Move build_arch_param from DiffMutableModule to `DiffM…
Browse files Browse the repository at this point in the history
…oduleMutator` (#221)

* move build_arch_param from mutable to mutator

* fix UT of diff mutable and mutator

* modify based on shiguang's comments

* remove mutator from the unittest of mutable
  • Loading branch information
pprp authored Aug 10, 2022
1 parent e4305f3 commit 696191e
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 28 deletions.
8 changes: 4 additions & 4 deletions mmrazor/models/mutables/mutable_module/diff_mutable_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,6 @@ def forward(self,
else:
return self.forward_arch_param(x, arch_param=arch_param)

def build_arch_param(self) -> nn.Parameter:
"""Build learnable architecture parameters."""
return nn.Parameter(torch.randn(self.num_choices) * 1e-3)

def compute_arch_probs(self, arch_param: nn.Parameter) -> Tensor:
"""compute chosen probs according to architecture params."""
return F.softmax(arch_param, -1)
Expand Down Expand Up @@ -232,9 +228,11 @@ def fix_chosen(self, chosen: Union[str, List[str]]) -> None:
self.is_fixed = True

def sample_choice(self, arch_param):
"""Sample choice based on arch_parameters."""
return self.choices[torch.argmax(arch_param).item()]

def dump_chosen(self):
"""Dump current choice."""
assert self.current_choice is not None
return self.current_choice

Expand Down Expand Up @@ -406,10 +404,12 @@ def choices(self) -> List[CHOSEN_TYPE]:
return list(self._candidates.keys())

def dump_chosen(self):
"""dump current choice."""
assert self.current_choice is not None
return self.current_choice

def sample_choice(self, arch_param):
"""sample choice based on `arch_param`."""
sort_idx = torch.argsort(-arch_param).cpu().numpy().tolist()
choice_idx = sort_idx[:self.num_chosen]
choice = [self.choices[i] for i in choice_idx]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn

from mmrazor.registry import MODELS
Expand Down Expand Up @@ -28,6 +29,10 @@ def __init__(self,
init_cfg: Optional[Dict] = None) -> None:
super().__init__(custom_group=custom_group, init_cfg=init_cfg)

def build_arch_param(self, num_choices) -> nn.Parameter:
"""Build learnable architecture parameters."""
return nn.Parameter(torch.randn(num_choices) * 1e-3)

def prepare_from_supernet(self, supernet: nn.Module) -> None:
"""Inherit from ``BaseMutator``'s, generate `arch_params` in DARTS.
Expand All @@ -53,7 +58,7 @@ def build_arch_params(self):
arch_params = nn.ParameterDict()

for group_id, modules in self.search_groups.items():
group_arch_param = modules[0].build_arch_param()
group_arch_param = self.build_arch_param(modules[0].num_choices)
arch_params[str(group_id)] = group_arch_param

return arch_params
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_algorithms/test_autoslim.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
import torch
import torch.distributed as dist
from mmcls.data import ClsDataSample
from mmcls.structures import ClsDataSample
from mmengine.optim import build_optim_wrapper

from mmrazor import digit_version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
import torch
import torch.distributed as dist
from mmcls.data import ClsDataSample
from mmcls.structures import ClsDataSample
from mmcv import fileio
from mmengine.optim import build_optim_wrapper

Expand Down
16 changes: 10 additions & 6 deletions tests/test_models/test_mutables/test_diffchoiceroute.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,9 @@ def test_forward_arch_param(self):

# test with_arch_param = True
diffchoiceroute = MODELS.build(diff_choice_route_cfg)

arch_param = diffchoiceroute.build_arch_param()
assert len(arch_param) == 5
arch_param = nn.Parameter(torch.randn(len(edges_dict)))

x = [torch.randn(4, 32, 64, 64) for _ in range(5)]

output = diffchoiceroute.forward_arch_param(x=x, arch_param=arch_param)
assert output is not None

Expand All @@ -43,14 +40,21 @@ def test_forward_arch_param(self):
new_diff_choice_route_cfg['with_arch_param'] = False

new_diff_choice_route = MODELS.build(new_diff_choice_route_cfg)

arch_param = new_diff_choice_route.build_arch_param()
arch_param = nn.Parameter(torch.randn(len(edges_dict)))
output = new_diff_choice_route.forward_arch_param(
x=x, arch_param=arch_param)
assert output is not None

new_diff_choice_route.fix_chosen(chosen=['first_edge'])

# test sample choice
arch_param = nn.Parameter(torch.randn(len(edges_dict)))
new_diff_choice_route.sample_choice(arch_param)

# test dump_chosen
with pytest.raises(AssertionError):
new_diff_choice_route.dump_chosen()

def test_forward_fixed(self):
edges_dict = nn.ModuleDict({
'first_edge': nn.Conv2d(32, 32, 3, 1, 1),
Expand Down
9 changes: 6 additions & 3 deletions tests/test_models/test_mutables/test_diffop.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,14 @@ def test_forward_arch_param(self):
op = MODELS.build(op_cfg)
input = torch.randn(4, 32, 64, 64)

arch_param = op.build_arch_param()
arch_param = nn.Parameter(torch.randn(len(op_cfg['candidates'])))
output = op.forward_arch_param(input, arch_param=arch_param)
assert output is not None

output = op.forward_arch_param(input, arch_param=None)
assert output is not None

# test when some element of arch_param is 0
arch_param = op.build_arch_param()
arch_param = nn.Parameter(torch.ones(op.num_choices))
output = op.forward_arch_param(input, arch_param=arch_param)
assert output is not None
Expand Down Expand Up @@ -107,11 +106,15 @@ def test_forward(self):
input = torch.randn(4, 32, 64, 64)

# test set_forward_args
arch_param = op.build_arch_param()
arch_param = nn.Parameter(torch.randn(len(op_cfg['candidates'])))
op.set_forward_args(arch_param=arch_param)
output = op.forward(input)
assert output is not None

# test dump_chosen
with pytest.raises(AssertionError):
op.dump_chosen()

# test forward when is_fixed is True
op.fix_chosen('torch_conv2d_7x7')
output = op.forward(input)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_models/test_mutables/test_gumbelchoiceroute.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_forward_arch_param(self):
# test with_arch_param = True
GumbelChoiceRoute = MODELS.build(gumbel_choice_route_cfg)

arch_param = GumbelChoiceRoute.build_arch_param()
arch_param = nn.Parameter(torch.randn(len(edges_dict)))
assert len(arch_param) == 5
GumbelChoiceRoute.set_temperature(1.0)

Expand All @@ -49,7 +49,7 @@ def test_forward_arch_param(self):

new_gumbel_choice_route = MODELS.build(new_gumbel_choice_route_cfg)

arch_param = new_gumbel_choice_route.build_arch_param()
arch_param = nn.Parameter(torch.randn(len(edges_dict)))
output = new_gumbel_choice_route.forward_arch_param(
x=x, arch_param=arch_param)
assert output is not None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from os.path import dirname

import torch
from mmcls.data import ClsDataSample
from mmcls.models import * # noqa: F401,F403
from mmcls.structures import ClsDataSample

from mmrazor import digit_version
from mmrazor.models.mutables import SlimmableMutableChannel
Expand Down
34 changes: 25 additions & 9 deletions tests/test_models/test_mutators/test_diff_mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torch.nn as nn

from mmrazor.models import * # noqa: F401,F403
from mmrazor.models.mutables import DiffMutableModule, DiffMutableOP
from mmrazor.models.mutables import DiffMutableModule
from mmrazor.models.mutators import DiffModuleMutator
from mmrazor.registry import MODELS

MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)
Expand Down Expand Up @@ -101,7 +102,7 @@ def setUp(self):

def test_diff_mutator_diffop_layer(self) -> None:
model = SearchableLayer(self.MUTABLE_CFG)
mutator: DiffMutableOP = MODELS.build(self.MUTATOR_CFG)
mutator: DiffModuleMutator = MODELS.build(self.MUTATOR_CFG)

mutator.prepare_from_supernet(model)
assert list(mutator.search_groups.keys()) == [0, 1, 2]
Expand All @@ -115,7 +116,7 @@ def test_diff_mutator_diffop_model(self) -> None:
['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
['slayer1.op3', 'slayer2.op3', 'slayer3.op3'],
]
mutator: DiffMutableOP = MODELS.build(mutator_cfg)
mutator: DiffModuleMutator = MODELS.build(mutator_cfg)

mutator.prepare_from_supernet(model)
assert list(mutator.search_groups.keys()) == [0, 1, 2]
Expand All @@ -132,7 +133,7 @@ def test_diff_mutator_diffop_model_error(self) -> None:
['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
['slayer1.op3', 'slayer2.op3', 'slayer3.op3_error_key'],
]
mutator: DiffMutableOP = MODELS.build(mutator_cfg)
mutator: DiffModuleMutator = MODELS.build(mutator_cfg)

with pytest.raises(AssertionError):
mutator.prepare_from_supernet(model)
Expand All @@ -142,7 +143,7 @@ def test_diff_mutator_diffop_alias(self) -> None:

mutator_cfg = self.MUTATOR_CFG.copy()
mutator_cfg['custom_group'] = [['op1'], ['op2'], ['op3']]
mutator: DiffMutableOP = MODELS.build(mutator_cfg)
mutator: DiffModuleMutator = MODELS.build(mutator_cfg)

mutator.prepare_from_supernet(model)

Expand All @@ -161,7 +162,7 @@ def test_diff_mutator_alias_module_name(self) -> None:
'slayer1.op2', 'slayer2.op2',
'slayer3.op2'
], ['slayer1.op3', 'slayer2.op3']]
mutator: DiffMutableOP = MODELS.build(mutator_cfg)
mutator: DiffModuleMutator = MODELS.build(mutator_cfg)

mutator.prepare_from_supernet(model)

Expand All @@ -179,7 +180,7 @@ def test_diff_mutator_duplicate_keys(self) -> None:
['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
['slayer1.op3', 'slayer2.op3', 'slayer2.op3'],
]
mutator: DiffMutableOP = MODELS.build(mutator_cfg)
mutator: DiffModuleMutator = MODELS.build(mutator_cfg)

with pytest.raises(AssertionError):
mutator.prepare_from_supernet(model)
Expand All @@ -193,7 +194,7 @@ def test_diff_mutator_duplicate_key_alias(self) -> None:
['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
['slayer1.op3', 'slayer2.op3', 'slayer3.op3'],
]
mutator: DiffMutableOP = MODELS.build(mutator_cfg)
mutator: DiffModuleMutator = MODELS.build(mutator_cfg)

with pytest.raises(AssertionError):
mutator.prepare_from_supernet(model)
Expand All @@ -207,11 +208,26 @@ def test_diff_mutator_illegal_key(self) -> None:
['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
['slayer1.op3', 'slayer2.op3', 'slayer3.op3'],
]
mutator: DiffMutableOP = MODELS.build(mutator_cfg)
mutator: DiffModuleMutator = MODELS.build(mutator_cfg)

with pytest.raises(AssertionError):
mutator.prepare_from_supernet(model)

def test_sample_and_set_choices(self):
model = SearchableModel(self.MUTABLE_CFG)

mutator_cfg = self.MUTATOR_CFG.copy()
mutator_cfg['custom_group'] = [
['slayer1.op1', 'slayer2.op1', 'slayer3.op1'],
['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
['slayer1.op3', 'slayer2.op3', 'slayer3.op3'],
]
mutator: DiffModuleMutator = MODELS.build(mutator_cfg)
mutator.prepare_from_supernet(model)
choices = mutator.sample_choices()
mutator.set_choices(choices)
self.assertTrue(len(choices) == 3)


if __name__ == '__main__':
import unittest
Expand Down

0 comments on commit 696191e

Please sign in to comment.