From 7dca9ba7c77a33dbb2a7fb53b294572e82d1df16 Mon Sep 17 00:00:00 2001 From: qiufeng <44188071+wutongshenqiu@users.noreply.github.com> Date: Wed, 10 Aug 2022 16:07:03 +0800 Subject: [PATCH] [Feature] Add `DerivedMutable` & `MutableValue` (#215) * fix lint * complement unittest for derived mutable * add docstring for derived mutable * add unittest for mutable value * fix logger error * fix according to comments * not dump derived mutable when export * add warning in `export_fix_subnet` * fix __mul__ in mutable value --- mmrazor/models/mutables/__init__.py | 5 +- mmrazor/models/mutables/derived_mutable.py | 382 ++++++++++++++++++ .../mutable_channel/mutable_channel.py | 6 +- .../one_shot_mutable_channel.py | 93 ++++- .../slimmable_mutable_channel.py | 1 + .../models/mutables/mutable_value/__init__.py | 4 + .../mutables/mutable_value/mutable_value.py | 236 +++++++++++ mmrazor/models/utils/__init__.py | 5 +- mmrazor/models/utils/make_divisible.py | 31 ++ mmrazor/structures/subnet/fix_subnet.py | 24 +- .../test_mutables/test_derived_mutable.py | 250 ++++++++++++ .../test_mutables/test_mutable_value.py | 136 +++++++ .../test_subnet/test_fix_subnet.py | 38 +- 13 files changed, 1196 insertions(+), 15 deletions(-) create mode 100644 mmrazor/models/mutables/derived_mutable.py create mode 100644 mmrazor/models/mutables/mutable_value/__init__.py create mode 100644 mmrazor/models/mutables/mutable_value/mutable_value.py create mode 100644 mmrazor/models/utils/make_divisible.py create mode 100644 tests/test_models/test_mutables/test_derived_mutable.py create mode 100644 tests/test_models/test_mutables/test_mutable_value.py diff --git a/mmrazor/models/mutables/__init__.py b/mmrazor/models/mutables/__init__.py index 94dce2a7d..123e597ae 100644 --- a/mmrazor/models/mutables/__init__.py +++ b/mmrazor/models/mutables/__init__.py @@ -1,12 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from .derived_mutable import DerivedMutable from .mutable_channel import (MutableChannel, OneShotMutableChannel, SlimmableMutableChannel) from .mutable_manage_mixin import MutableManageMixIn from .mutable_module import (DiffChoiceRoute, DiffMutableModule, DiffMutableOP, OneShotMutableModule, OneShotMutableOP) +from .mutable_value import MutableValue, OneShotMutableValue __all__ = [ 'OneShotMutableOP', 'OneShotMutableModule', 'DiffMutableOP', 'DiffChoiceRoute', 'DiffMutableModule', 'MutableManageMixIn', - 'OneShotMutableChannel', 'SlimmableMutableChannel', 'MutableChannel' + 'OneShotMutableChannel', 'SlimmableMutableChannel', 'MutableChannel', + 'DerivedMutable', 'MutableValue', 'OneShotMutableValue' ] diff --git a/mmrazor/models/mutables/derived_mutable.py b/mmrazor/models/mutables/derived_mutable.py new file mode 100644 index 000000000..7cef44dcd --- /dev/null +++ b/mmrazor/models/mutables/derived_mutable.py @@ -0,0 +1,382 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +import logging +from collections.abc import Iterable +from typing import Any, Callable, Dict, Optional, Protocol, Set, Union + +import torch +from mmengine.logging import print_log +from torch import Tensor + +from ..utils import make_divisible +from .base_mutable import CHOICE_TYPE, BaseMutable + + +class MutableProtocol(Protocol): # pragma: no cover + """Protocol for Mutable.""" + + @property + def current_choice(self) -> Any: + """Current choice.""" + + def derive_expand_mutable(self, expand_ratio: int) -> Any: + """Derive expand mutable.""" + + def derive_divide_mutable(self, ratio: int, divisor: int) -> Any: + """Derive divide mutable.""" + + +class MutableChannelProtocol(MutableProtocol): # pragma: no cover + """Protocol for MutableChannel.""" + + @property + def current_mask(self) -> Tensor: + """Current mask.""" + + +def _expand_choice_fn(mutable: MutableProtocol, expand_ratio: int) -> Callable: + """Helper function to build `choice_fn` for expand derived mutable.""" + 
+ def fn(): + return mutable.current_choice * expand_ratio + + return fn + + +def _expand_mask_fn(mutable: MutableProtocol, + expand_ratio: int) -> Callable: # pragma: no cover + """Helper function to build `mask_fn` for expand derived mutable.""" + if not hasattr(mutable, 'current_mask'): + raise ValueError('mutable must have attribute `currnet_mask`') + + def fn(): + mask = mutable.current_mask + expand_num_channels = mask.size(0) * expand_ratio + expand_choice = mutable.current_choice * expand_ratio + expand_mask = torch.zeros(expand_num_channels).bool() + expand_mask[:expand_choice] = True + + return expand_mask + + return fn + + +def _divide_and_divise(x: int, ratio: int, divisor: int = 8) -> int: + """Helper function for divide and divise.""" + new_x = x // ratio + + return make_divisible(new_x, divisor) + + +def _divide_choice_fn(mutable: MutableProtocol, + ratio: int, + divisor: int = 8) -> Callable: + """Helper function to build `choice_fn` for divide derived mutable.""" + + def fn(): + return _divide_and_divise(mutable.current_choice, ratio, divisor) + + return fn + + +def _divide_mask_fn(mutable: MutableProtocol, + ratio: int, + divisor: int = 8) -> Callable: # pragma: no cover + """Helper function to build `mask_fn` for divide derived mutable.""" + if not hasattr(mutable, 'current_mask'): + raise ValueError('mutable must have attribute `currnet_mask`') + + def fn(): + mask = mutable.current_mask + divide_num_channels = _divide_and_divise(mask.size(0), ratio, divisor) + divide_choice = _divide_and_divise(mutable.current_choice, ratio, + divisor) + divide_mask = torch.zeros(divide_num_channels).bool() + divide_mask[:divide_choice] = True + + return divide_mask + + return fn + + +def _concat_choice_fn(mutables: Iterable[MutableChannelProtocol]) -> Callable: + """Helper function to build `choice_fn` for concat derived mutable.""" + + def fn(): + return sum((m.current_choice for m in mutables)) + + return fn + + +def _concat_mask_fn(mutables: 
Iterable[MutableChannelProtocol]) -> Callable: + """Helper function to build `mask_fn` for concat derived mutable.""" + + def fn(): + return torch.cat([m.current_mask for m in mutables]) + + return fn + + +class DerivedMethodMixin: + """A mixin that provides some useful method to derive mutable.""" + + def derive_same_mutable(self: MutableProtocol) -> 'DerivedMutable': + """Derive same mutable as the source.""" + return self.derive_expand_mutable(expand_ratio=1) + + def derive_expand_mutable(self: MutableProtocol, + expand_ratio: int) -> 'DerivedMutable': + """Derive expand mutable, usually used with `expand_ratio`.""" + choice_fn = _expand_choice_fn(self, expand_ratio=expand_ratio) + + mask_fn: Optional[Callable] = None + if hasattr(self, 'current_mask'): + mask_fn = _expand_mask_fn(self, expand_ratio=expand_ratio) + + return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn) + + def derive_divide_mutable(self: MutableProtocol, + ratio: int, + divisor: int = 8) -> 'DerivedMutable': + """Derive divide mutable, usually used with `make_divisable`.""" + choice_fn = _divide_choice_fn(self, ratio=ratio, divisor=divisor) + + mask_fn: Optional[Callable] = None + if hasattr(self, 'current_mask'): + mask_fn = _divide_mask_fn(self, ratio=ratio, divisor=divisor) + + return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn) + + @staticmethod + def derive_concat_mutable( + mutables: Iterable[MutableChannelProtocol]) -> 'DerivedMutable': + """Derive concat mutable, usually used with `torch.cat`.""" + for mutable in mutables: + if not hasattr(mutable, 'current_mask'): + raise RuntimeError('Source mutable of concat derived mutable ' + 'must have attribute `currnet_mask`') + + choice_fn = _concat_choice_fn(mutables) + mask_fn = _concat_mask_fn(mutables) + + return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn) + + +class DerivedMutable(BaseMutable[CHOICE_TYPE, CHOICE_TYPE], + DerivedMethodMixin): + """Class for derived mutable. 
+ + A derived mutable is a mutable derived from other mutables that has + `current_choice` and `current_mask` attributes (if any). + + Note: + A derived mutable does not have its own search space, so it is + not legal to modify its `current_choice` or `current_mask` directly. + And the only way to modify them is by modifying `current_choice` or + `current_mask` in corresponding source mutables. + + Args: + choice_fn (callable): A closure that controls how to generate + `current_choice`. + mask_fn (callable, optional): A closure that controls how to generate + `current_mask`. Defaults to None. + source_mutables (iterable, optional): Specify source mutables for this + derived mutable. If the argument is None, source mutables will be + traced automatically by parsing mutables in closure variables. + Defaults to None. + alias (str, optional): alias of the `MUTABLE`. Defaults to None. + init_cfg (dict, optional): initialization configuration dict for + ``BaseModule``. OpenMMLab has implement 5 initializer including + `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, + and `Pretrained`. Defaults to None. + + Examples: + >>> from mmrazor.models.mutables import OneShotMutableChannel + >>> mutable_channel = OneShotMutableChannel( + ... num_channels=3, + ... candidate_choices=[1, 2, 3], + ... 
candidate_mode='number') + >>> # derive expand mutable + >>> derived_mutable_channel = mutable_channel * 2 + >>> # source mutables will be traced automatically + >>> derived_mutable_channel.source_mutables + {OneShotMutableChannel(name=unbind, num_channels=3, current_choice=3, choices=[1, 2, 3], activated_channels=3, concat_mutable_name=[])} # noqa: E501 + >>> # modify `current_choice` of `mutable_channel` + >>> mutable_channel.current_choice = 2 + >>> # `current_choice` and `current_mask` of derived mutable will be modified automatically # noqa: E501 + >>> derived_mutable_channel + DerivedMutable(current_choice=4, activated_channels=4, source_mutables={OneShotMutableChannel(name=unbind, num_channels=3, current_choice=2, choices=[1, 2, 3], activated_channels=2, concat_mutable_name=[])}, is_fixed=False) # noqa: E501 + """ + + def __init__(self, + choice_fn: Callable, + mask_fn: Optional[Callable] = None, + source_mutables: Optional[Iterable[BaseMutable]] = None, + alias: Optional[str] = None, + init_cfg: Optional[Dict] = None) -> None: + super().__init__(alias, init_cfg) + + self.choice_fn = choice_fn + self.mask_fn = mask_fn + + if source_mutables is None: + source_mutables = self._trace_source_mutables() + if len(source_mutables) == 0: + raise RuntimeError( + 'Can not find source mutables automatically, ' + 'please provide manually.') + else: + source_mutables = set(source_mutables) + for mutable in source_mutables: + if not self.is_source_mutable(mutable): + raise ValueError('Expect all mutable to be source mutable, ' + f'but {mutable} is not') + self.source_mutables = source_mutables + + # TODO + # has no effect + def fix_chosen(self, chosen: CHOICE_TYPE) -> None: + """Fix mutable with subnet config. + + Warning: + Fix derived mutable will have no actually effect. + """ + print_log( + 'Trying to fix chosen for derived mutable, ' + 'which will have no effect.', + level=logging.WARNING) + + def dump_chosen(self) -> CHOICE_TYPE: + """Dump information of chosen. 
+ + Returns: + Dict: Dumped information. + """ + print_log( + 'Trying to dump chosen for derived mutable, ' + 'but its value depend on the source mutables.', + level=logging.WARNING) + return self.current_choice + + @property + def is_fixed(self) -> bool: + """Whether the derived mutable is fixed. + + Note: + Depends on whether all source mutables are already fixed. + """ + return all(m.is_fixed for m in self.source_mutables) + + @is_fixed.setter + def is_fixed(self, is_fixed: bool) -> bool: + """Setter of is fixed.""" + raise RuntimeError( + '`is_fixed` of derived mutable should not be modified directly') + + @property + def num_choices(self) -> int: + """Number of all choices. + + Note: + Since derive mutable does not have its own search space, the number + of choices will always be `1`. + + Returns: + int: Number of choices. + """ + return 1 + + @property + def current_choice(self) -> CHOICE_TYPE: + """Current choice of derived mutable.""" + return self.choice_fn() + + @current_choice.setter + def current_choice(self, choice: CHOICE_TYPE) -> None: + """Setter of current choice. + + Raises: + RuntimeError: Error when `current_choice` of derived mutable + is modified directly. + """ + raise RuntimeError('Choice of drived mutable can not be set.') + + @property + def current_mask(self) -> Tensor: + """Current mask of derived mutable.""" + if self.mask_fn is None: + raise RuntimeError( + '`mask_fn` must be set before access `current_mask`.') + return self.mask_fn() + + @current_mask.setter + def current_mask(self, mask: Tensor) -> None: + """Setter of current mask. + + Raises: + RuntimeError: Error when `current_mask` of derived mutable + is modified directly. 
+ """ + raise RuntimeError('Mask of drived mutable can not be set.') + + @staticmethod + def _trace_source_mutables_from_closure( + closure: Callable) -> Set[BaseMutable]: + """Trace source mutables from closure.""" + source_mutables: Set[BaseMutable] = set() + + def add_mutables_dfs( + mutable: Union[Iterable, BaseMutable, Dict]) -> None: + nonlocal source_mutables + if isinstance(mutable, BaseMutable): + if isinstance(mutable, DerivedMutable): + source_mutables |= mutable.source_mutables + else: + source_mutables.add(mutable) + # dict is also iterable, should parse first + elif isinstance(mutable, dict): + add_mutables_dfs(mutable.values()) + add_mutables_dfs(mutable.keys()) + elif isinstance(mutable, Iterable): + for m in mutable: + add_mutables_dfs(m) + + noncolcal_pars = inspect.getclosurevars(closure).nonlocals + add_mutables_dfs(noncolcal_pars.values()) + + return source_mutables + + def _trace_source_mutables(self) -> Set[BaseMutable]: + """Trace source mutables.""" + source_mutables = self._trace_source_mutables_from_closure( + self.choice_fn) + if self.mask_fn is not None: + source_mutables |= self._trace_source_mutables_from_closure( + self.mask_fn) + + return source_mutables + + @staticmethod + def is_source_mutable(mutable: object) -> bool: + """Judge whether an object is source mutable(not derived mutable). + + Args: + mutable (object): An object. + + Returns: + bool: Indicate whether the object is source mutable or not. + """ + return isinstance(mutable, BaseMutable) and \ + not isinstance(mutable, DerivedMutable) + + # TODO + # should be __str__? 
but can not provide info when debug + def __repr__(self) -> str: # pragma: no cover + s = f'{self.__class__.__name__}(' + s += f'current_choice={self.current_choice}, ' + if self.mask_fn is not None: + s += f'activated_channels={self.current_mask.sum().item()}, ' + s += f'source_mutables={self.source_mutables}, ' + s += f'is_fixed={self.is_fixed})' + + return s diff --git a/mmrazor/models/mutables/mutable_channel/mutable_channel.py b/mmrazor/models/mutables/mutable_channel/mutable_channel.py index e0bbf62d9..f3ba2063e 100644 --- a/mmrazor/models/mutables/mutable_channel/mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/mutable_channel.py @@ -5,9 +5,11 @@ import torch from ..base_mutable import CHOICE_TYPE, CHOSEN_TYPE, BaseMutable +from ..derived_mutable import DerivedMethodMixin -class MutableChannel(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE]): +class MutableChannel(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE], + DerivedMethodMixin): """A type of ``MUTABLES`` for single path supernet such as AutoSlim. In single path supernet, each module only has one choice invoked at the same time. A path is obtained by sampling all the available choices. It is the @@ -31,6 +33,7 @@ def __init__(self, num_channels: int, **kwargs): # outputs, we add the mutable out of these modules to the # `concat_parent_mutables` of this module. self.concat_parent_mutables: List[MutableChannel] = list() + self.name = 'unbind' @property def same_mutables(self): @@ -104,7 +107,6 @@ def fix_chosen(self, chosen: CHOSEN_TYPE) -> None: # TODO # should fixed op still have candidate_choices? 
- self._candidate_choices = [chosen] self._chosen = chosen self.is_fixed = True diff --git a/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py index 58327ecdf..7f6eea3ad 100644 --- a/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py @@ -1,15 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional +from typing import Callable, Dict, List, Optional, Union import numpy as np import torch from mmrazor.registry import MODELS +from ..derived_mutable import DerivedMutable from .mutable_channel import MutableChannel @MODELS.register_module() -class OneShotMutableChannel(MutableChannel[int, int]): +class OneShotMutableChannel(MutableChannel[int, Dict]): """A type of ``MUTABLES`` for single path supernet such as AutoSlim. In single path supernet, each module only has one choice invoked at the same time. A path is obtained by sampling all the available choices. It is the @@ -36,7 +37,7 @@ class OneShotMutableChannel(MutableChannel[int, int]): def __init__(self, num_channels: int, - candidate_choices: List, + candidate_choices: List[Union[int, float]], candidate_mode: str = 'ratio', init_cfg: Optional[Dict] = None): super(OneShotMutableChannel, self).__init__( @@ -108,7 +109,7 @@ def current_choice(self, choice: int): self._current_choice = choice @property - def choices(self) -> List[int]: + def choices(self) -> List: """list: all choices. 
""" if self._candidate_mode == 'number': return self._candidate_choices @@ -129,5 +130,85 @@ def convert_choice_to_mask(self, choice: int) -> torch.Tensor: mask[:num_channels] = True return mask - def dump_chosen(self) -> int: - return self.current_choice + def dump_chosen(self) -> Dict: + assert self.current_choice is not None + + return dict( + current_choice=self.current_choice, + origin_channels=self.num_channels) + + def fix_chosen(self, dumped_chosen: Dict) -> None: + if self.is_fixed: + raise RuntimeError('OneShotMutableChannel can not be fixed twice') + + current_choice = dumped_chosen['current_choice'] + origin_channels = dumped_chosen['origin_channels'] + + assert current_choice <= origin_channels + assert origin_channels == self.num_channels + + self.current_choice = current_choice + self.is_fixed = True + + def __repr__(self): + concat_mutable_name = [ + mutable.name for mutable in self.concat_parent_mutables + ] + repr_str = self.__class__.__name__ + repr_str += f'(name={self.name}, ' + repr_str += f'num_channels={self.num_channels}, ' + repr_str += f'current_choice={self.current_choice}, ' + repr_str += f'choices={self.choices}, ' + repr_str += f'activated_channels={self.current_mask.sum().item()}, ' + repr_str += f'concat_mutable_name={concat_mutable_name})' + return repr_str + + def __rmul__(self, other) -> DerivedMutable: + return self * other + + def __mul__(self, other) -> DerivedMutable: + if isinstance(other, int): + return self.derive_expand_mutable(other) + + from ..mutable_value import OneShotMutableValue + + def expand_choice_fn(mutable1: 'OneShotMutableChannel', + mutable2: OneShotMutableValue) -> Callable: + + def fn(): + return mutable1.current_choice * mutable2.current_choice + + return fn + + def expand_mask_fn(mutable1: 'OneShotMutableChannel', + mutable2: OneShotMutableValue) -> Callable: + + def fn(): + mask = mutable1.current_mask + max_expand_ratio = mutable2.max_choice + current_expand_ratio = mutable2.current_choice + 
expand_num_channels = mask.size(0) * max_expand_ratio + + expand_choice = mutable1.current_choice * current_expand_ratio + expand_mask = torch.zeros(expand_num_channels).bool() + expand_mask[:expand_choice] = True + + return expand_mask + + return fn + + if isinstance(other, OneShotMutableValue): + return DerivedMutable( + choice_fn=expand_choice_fn(self, other), + mask_fn=expand_mask_fn(self, other)) + + raise TypeError(f'Unsupported type {type(other)} for mul!') + + def __floordiv__(self, other) -> DerivedMutable: + if isinstance(other, int): + return self.derive_divide_mutable(other) + if isinstance(other, tuple): + assert len(other) == 2 + return self.derive_divide_mutable(*other) + + raise TypeError(f'Unsupported type {type(other)} for div!') diff --git a/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py index dda61814a..ebf8b41ef 100644 --- a/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py @@ -73,6 +73,7 @@ def fix_chosen(self, dumped_chosen: Dict) -> None: # TODO # remove after remove `current_choice` self.current_choice = self._candidate_choices.index(chosen) + self._candidate_choices = [chosen] super().fix_chosen(chosen) diff --git a/mmrazor/models/mutables/mutable_value/__init__.py b/mmrazor/models/mutables/mutable_value/__init__.py new file mode 100644 index 000000000..f83c93fe9 --- /dev/null +++ b/mmrazor/models/mutables/mutable_value/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .mutable_value import MutableValue, OneShotMutableValue + +__all__ = ['MutableValue', 'OneShotMutableValue'] diff --git a/mmrazor/models/mutables/mutable_value/mutable_value.py b/mmrazor/models/mutables/mutable_value/mutable_value.py new file mode 100644 index 000000000..748d83e78 --- /dev/null +++ b/mmrazor/models/mutables/mutable_value/mutable_value.py @@ -0,0 +1,236 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +from typing import Any, Dict, List, Optional, Tuple, Union + +from mmrazor.registry import MODELS +from ..base_mutable import BaseMutable +from ..derived_mutable import DerivedMethodMixin, DerivedMutable + + +@MODELS.register_module() +class MutableValue(BaseMutable[Any, Dict], DerivedMethodMixin): + """Base class for mutable value. + + A mutable value is actually a mutable that adds some functionality to a + list containing objects of the same type. + + Args: + value_list (list): List of value, each value must have the same type. + default_value (any, optional): Default value, must be one in + `value_list`. Default to None. + alias (str, optional): alias of the `MUTABLE`. + init_cfg (dict, optional): initialization configuration dict for + ``BaseModule``. OpenMMLab has implement 5 initializer including + `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, + and `Pretrained`. 
+ """ + + def __init__(self, + value_list: List[Any], + default_value: Optional[Any] = None, + alias: Optional[str] = None, + init_cfg: Optional[Dict] = None) -> None: + super().__init__(alias, init_cfg) + + self._check_is_same_type(value_list) + self._value_list = value_list + + if default_value is None: + default_value = value_list[0] + self.current_choice = default_value + + @staticmethod + def _check_is_same_type(value_list: List[Any]) -> None: + """Check whether value in `value_list` has the same type.""" + if len(value_list) == 1: + return + + for i in range(1, len(value_list)): + is_same_type = type(value_list[i - 1]) is \ + type(value_list[i]) # noqa: E721 + if not is_same_type: + raise TypeError( + 'All elements in `value_list` must have same ' + f'type, but both types {type(value_list[i-1])} ' + f'and type {type(value_list[i])} exist.') + + @property + def choices(self) -> List[Any]: + """List of choices.""" + return self._value_list + + def fix_chosen(self, chosen: Dict[str, Any]) -> None: + """Fix mutable value with subnet config. + + Args: + chosen (dict): the information of chosen. + """ + if self.is_fixed: + raise RuntimeError('MutableValue can not be fixed twice') + + all_choices = chosen['all_choices'] + current_choice = chosen['current_choice'] + + assert all_choices == self.choices, \ + f'Expect choices to be: {self.choices}, but got: {all_choices}' + assert current_choice in self.choices + + self.current_choice = current_choice + self.is_fixed = True + + def dump_chosen(self) -> Dict[str, Any]: + """Dump information of chosen. + + Returns: + Dict[str, Any]: Dumped information. + """ + return dict( + current_choice=self.current_choice, all_choices=self.choices) + + @property + def num_choices(self) -> int: + """Number of all choices. + + Returns: + int: Number of choices. 
+ """ + return len(self.choices) + + @property + def current_choice(self) -> Optional[Any]: + """Current choice of mutable value.""" + return self._current_choice + + @current_choice.setter + def current_choice(self, choice: Any) -> Any: + """Setter of current choice.""" + if choice not in self.choices: + raise ValueError(f'Expected choice in: {self.choices}, ' + f'but got: {choice}') + + self._current_choice = choice + + def __rmul__(self, other) -> DerivedMutable: + """Please refer to method :func:`__mul__`.""" + return self * other + + def __mul__(self, other: int) -> DerivedMutable: + """Overload `*` operator. + + Args: + other (int): Expand ratio. + + Returns: + DerivedMutable: Derived expand mutable. + """ + if isinstance(other, int): + return self.derive_expand_mutable(other) + + raise TypeError(f'Unsupported type {type(other)} for mul!') + + def __floordiv__(self, other: Union[int, Tuple[int, + int]]) -> DerivedMutable: + """Overload `//` operator. + + Args: + other: (int, tuple): divide ratio for int or + (divide ratio, divisor) for tuple. + + Returns: + DerivedMutable: Derived divide mutable. + """ + if isinstance(other, int): + return self.derive_divide_mutable(other) + if isinstance(other, tuple): + assert len(other) == 2 + return self.derive_divide_mutable(*other) + + raise TypeError(f'Unsupported type {type(other)} for div!') + + def __repr__(self) -> str: + s = self.__class__.__name__ + s += f'(value_list={self._value_list}, ' + s += f'current_choice={self.current_choice})' + + return s + + +# TODO +# 1. use comparable for type hint +# 2. use mixin +@MODELS.register_module() +class OneShotMutableValue(MutableValue): + """Class for one-shot mutable value. + + one-shot mutable value provides `sample_choice` method and `min_choice`, + `max_choice` properties on the top of mutable value. + + Args: + value_list (list): List of value, each value must have the same type. + default_value (any, optional): Default value, must be one in + `value_list`. 
Default to None. + alias (str, optional): alias of the `MUTABLE`. + init_cfg (dict, optional): initialization configuration dict for + ``BaseModule``. OpenMMLab has implement 5 initializer including + `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, + and `Pretrained`. + """ + + def __init__(self, + value_list: List[Any], + default_value: Optional[Any] = None, + alias: Optional[str] = None, + init_cfg: Optional[Dict] = None) -> None: + value_list = sorted(value_list) + # set default value as max value + if default_value is None: + default_value = value_list[-1] + + super().__init__( + value_list=value_list, + default_value=default_value, + alias=alias, + init_cfg=init_cfg) + + def sample_choice(self) -> Any: + """Random sampling from choices. + + Returns: + Any: Selected choice. + """ + return random.choice(self.choices) + + @property + def max_choice(self) -> Any: + """Max choice of all choices. + + Returns: + Any: Max choice. + """ + return self.choices[-1] + + @property + def min_choice(self) -> Any: + """Min choice of all choices. + + Returns: + Any: Min choice. + """ + return self.choices[0] + + def __mul__(self, other) -> DerivedMutable: + """Overload `*` operator. + + Args: + other (int, OneShotMutableChannel): Expand ratio or + OneShotMutableChannel. + + Returns: + DerivedMutable: Derived expand mutable. + """ + from ..mutable_channel import OneShotMutableChannel + + if isinstance(other, OneShotMutableChannel): + return other * self + + return super().__mul__(other) diff --git a/mmrazor/models/utils/__init__.py b/mmrazor/models/utils/__init__.py index 7a477f6dd..fd83be434 100644 --- a/mmrazor/models/utils/__init__.py +++ b/mmrazor/models/utils/__init__.py @@ -1,5 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from .make_divisible import make_divisible from .misc import add_prefix from .optim_wrapper import reinitialize_optim_wrapper_count_status -__all__ = ['add_prefix', 'reinitialize_optim_wrapper_count_status'] +__all__ = [ + 'add_prefix', 'reinitialize_optim_wrapper_count_status', 'make_divisible' +] diff --git a/mmrazor/models/utils/make_divisible.py b/mmrazor/models/utils/make_divisible.py new file mode 100644 index 000000000..5056aeb15 --- /dev/null +++ b/mmrazor/models/utils/make_divisible.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + + +def make_divisible(value: int, + divisor: int, + min_value: Optional[int] = None, + min_ratio: float = 0.9) -> int: + """Make divisible function. + + This function rounds the channel number down to the nearest value that can + be divisible by the divisor. + + Args: + value (int): The original channel number. + divisor (int): The divisor to fully divide the channel number. + min_value (int, optional): The minimum value of the output channel. + Default: None, means that the minimum value equal to the divisor. + min_ratio (float): The minimum ratio of the rounded channel + number to the original channel number. Default: 0.9. + Returns: + int: The modified output channel number + """ + + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than (1-min_ratio). + if new_value < min_ratio * value: + new_value += divisor + return new_value diff --git a/mmrazor/structures/subnet/fix_subnet.py b/mmrazor/structures/subnet/fix_subnet.py index 147a26e7b..ead21413e 100644 --- a/mmrazor/structures/subnet/fix_subnet.py +++ b/mmrazor/structures/subnet/fix_subnet.py @@ -1,5 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import logging + import mmcv +from mmengine.logging import print_log from torch import nn from mmrazor.utils import FixMutable, ValidFixMutable @@ -31,6 +34,7 @@ def load_fix_subnet(model: nn.Module, raise TypeError('fix_mutable should be a `str` or `dict`' f'but got {type(fix_mutable)}') # Avoid circular import + from mmrazor.models.mutables import DerivedMutable from mmrazor.models.mutables.base_mutable import BaseMutable for name, module in model.named_modules(): @@ -46,9 +50,11 @@ def load_fix_subnet(model: nn.Module, chosen = fix_mutable.get(alias, None) else: mutable_name = name.lstrip(prefix) - assert mutable_name in fix_mutable, \ - f'The module name {mutable_name} is not in ' \ - 'fix_mutable, please check your `fix_mutable`.' + if mutable_name not in fix_mutable and \ + not isinstance(module, DerivedMutable): + raise RuntimeError( + f'The module name {mutable_name} is not in ' + 'fix_mutable, please check your `fix_mutable`.') chosen = fix_mutable.get(mutable_name, None) module.fix_chosen(chosen) @@ -56,15 +62,25 @@ def load_fix_subnet(model: nn.Module, _dynamic_to_static(model) -def export_fix_subnet(model: nn.Module) -> FixMutable: +def export_fix_subnet(model: nn.Module, + dump_derived_mutable: bool = False) -> FixMutable: """Export subnet that can be loaded by :func:`load_fix_subnet`.""" + if dump_derived_mutable: + print_log( + 'Trying to dump information of all derived mutables, ' + 'this might harm readability of the exported configurations.', + level=logging.WARNING) # Avoid circular import + from mmrazor.models.mutables import DerivedMutable from mmrazor.models.mutables.base_mutable import BaseMutable fix_subnet = dict() for name, module in model.named_modules(): if isinstance(module, BaseMutable): + if isinstance(module, DerivedMutable) and not dump_derived_mutable: + continue + assert not module.is_fixed if module.alias: fix_subnet[module.alias] = module.dump_chosen() diff --git a/tests/test_models/test_mutables/test_derived_mutable.py 
b/tests/test_models/test_mutables/test_derived_mutable.py new file mode 100644 index 000000000..99da8dc71 --- /dev/null +++ b/tests/test_models/test_mutables/test_derived_mutable.py @@ -0,0 +1,250 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import pytest +import torch + +from mmrazor.models.mutables import (DerivedMutable, OneShotMutableChannel, + OneShotMutableValue) +from mmrazor.models.mutables.base_mutable import BaseMutable + + +class TestDerivedMutable(TestCase): + + def test_is_fixed(self) -> None: + mc = OneShotMutableChannel( + num_channels=10, + candidate_choices=[2, 8, 10], + candidate_mode='number') + mc.current_choice = 2 + + mv = OneShotMutableValue(value_list=[2, 3, 4]) + mv.current_choice = 3 + + derived_mutable = mc * mv + assert not derived_mutable.is_fixed + + with pytest.raises(RuntimeError): + derived_mutable.is_fixed = True + + mc.fix_chosen(mc.dump_chosen()) + assert not derived_mutable.is_fixed + mv.fix_chosen(mv.dump_chosen()) + assert derived_mutable.is_fixed + + def test_fix_dump_chosen(self) -> None: + mv = OneShotMutableValue(value_list=[2, 3, 4]) + mv.current_choice = 3 + + derived_mutable = mv * 2 + assert derived_mutable.dump_chosen() == 6 + + mv.current_choice = 4 + assert derived_mutable.dump_chosen() == 8 + + # nothing will happen + derived_mutable.fix_chosen(derived_mutable.dump_chosen()) + + def test_derived_same_mutable(self) -> None: + mc = OneShotMutableChannel( + num_channels=3, + candidate_choices=[1, 2, 3], + candidate_mode='number') + mc_derived = mc.derive_same_mutable() + assert mc_derived.source_mutables == {mc} + + mc.current_choice = 2 + assert mc_derived.current_choice == 2 + assert torch.equal(mc_derived.current_mask, + torch.tensor([1, 1, 0], dtype=torch.bool)) + + def test_mutable_concat_derived(self) -> None: + mc1 = OneShotMutableChannel( + num_channels=3, candidate_choices=[1, 3], candidate_mode='number') + mc2 = OneShotMutableChannel( + num_channels=4, 
candidate_choices=[1, 4], candidate_mode='number') + ms = [mc1, mc2] + + mc_derived = DerivedMutable.derive_concat_mutable(ms) + assert mc_derived.source_mutables == set(ms) + + mc1.current_choice = 1 + mc2.current_choice = 4 + assert mc_derived.current_choice == 5 + assert torch.equal( + mc_derived.current_mask, + torch.tensor([1, 0, 0, 1, 1, 1, 1], dtype=torch.bool)) + + mc1.current_choice = 1 + mc2.current_choice = 1 + assert mc_derived.current_choice == 2 + assert torch.equal( + mc_derived.current_mask, + torch.tensor([1, 0, 0, 1, 0, 0, 0], dtype=torch.bool)) + + mv = OneShotMutableValue(value_list=[1, 2, 3]) + ms = [mc1, mv] + with pytest.raises(RuntimeError): + _ = DerivedMutable.derive_concat_mutable(ms) + + def test_mutable_channel_derived(self) -> None: + mc = OneShotMutableChannel( + num_channels=3, + candidate_choices=[1, 2, 3], + candidate_mode='number') + mc_derived = mc * 3 + assert mc_derived.source_mutables == {mc} + + mc.current_choice = 1 + assert mc_derived.current_choice == 3 + assert torch.equal( + mc_derived.current_mask, + torch.tensor([1, 1, 1, 0, 0, 0, 0, 0, 0], dtype=torch.bool)) + + mc.current_choice = 2 + assert mc_derived.current_choice == 6 + assert torch.equal( + mc_derived.current_mask, + torch.tensor([1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=torch.bool)) + + with pytest.raises(RuntimeError): + mc_derived.current_mask = torch.ones( + mc_derived.current_mask.size()) + + def test_mutable_divide(self) -> None: + mc = OneShotMutableChannel( + num_channels=128, + candidate_choices=[112, 120, 128], + candidate_mode='number') + mc_derived = mc // 8 + assert mc_derived.source_mutables == {mc} + + mc.current_choice = 128 + assert mc_derived.current_choice == 16 + assert torch.equal(mc_derived.current_mask, + torch.ones(16, dtype=torch.bool)) + mc.current_choice = 120 + assert mc_derived.current_choice == 16 + assert torch.equal(mc_derived.current_mask, + torch.ones(16, dtype=torch.bool)) + + mv = OneShotMutableValue(value_list=[112, 120, 128]) + 
mv_derived = mv // 8 + assert mv_derived.source_mutables == {mv} + + mv.current_choice = 128 + assert mv_derived.current_choice == 16 + mv.current_choice = 120 + assert mv_derived.current_choice == 16 + + def test_source_mutables(self) -> None: + useless_fn = lambda x: x # noqa: E731 + with pytest.raises(RuntimeError): + _ = DerivedMutable(choice_fn=useless_fn) + + mc1 = OneShotMutableChannel( + num_channels=3, candidate_choices=[1, 3], candidate_mode='number') + mc2 = OneShotMutableChannel( + num_channels=4, candidate_choices=[1, 4], candidate_mode='number') + ms = [mc1, mc2] + + mc_derived1 = DerivedMutable.derive_concat_mutable(ms) + + from mmrazor.models.mutables.derived_mutable import (_concat_choice_fn, + _concat_mask_fn) + mc_derived2 = DerivedMutable( + choice_fn=_concat_choice_fn(ms), + mask_fn=_concat_mask_fn(ms), + source_mutables=ms) + assert mc_derived1.source_mutables == mc_derived2.source_mutables + + dd_mutable = mc_derived1.derive_same_mutable() + assert dd_mutable.source_mutables == mc_derived1.source_mutables + + with pytest.raises(ValueError): + _ = DerivedMutable( + choice_fn=lambda x: x, source_mutables=[mc_derived1]) + + def dict_closure_fn(x, y): + + def fn(): + nonlocal x, y + + return fn + + ddd_mutable = DerivedMutable( + choice_fn=dict_closure_fn({ + mc1: [2, 3], + mc2: 2 + }, None), + mask_fn=dict_closure_fn({2: [mc1, mc2]}, {3: dd_mutable})) + assert ddd_mutable.source_mutables == mc_derived1.source_mutables + + mc3 = OneShotMutableChannel( + num_channels=4, candidate_choices=[2, 4], candidate_mode='number') + dddd_mutable = DerivedMutable( + choice_fn=dict_closure_fn({ + mc1: [2, 3], + mc2: 2 + }, []), + mask_fn=dict_closure_fn({2: [mc1, mc2, mc3]}, {3: dd_mutable})) + assert dddd_mutable.source_mutables == {mc1, mc2, mc3} + + def test_nested_mutables(self) -> None: + source_a = OneShotMutableChannel( + num_channels=2, candidate_choices=[1, 2], candidate_mode='number') + source_b = OneShotMutableChannel( + num_channels=3, 
candidate_choices=[2, 3], candidate_mode='number') + + # derive from + derived_c = source_a * 1 + concat_mutables = [source_b, derived_c] + derived_d = DerivedMutable.derive_concat_mutable(concat_mutables) + concat_mutables = [derived_c, derived_d] + derived_e = DerivedMutable.derive_concat_mutable(concat_mutables) + + assert derived_c.source_mutables == {source_a} + assert derived_d.source_mutables == {source_a, source_b} + assert derived_e.source_mutables == {source_a, source_b} + + source_a.current_choice = 1 + source_b.current_choice = 3 + + assert derived_c.current_choice == 1 + assert torch.equal(derived_c.current_mask, + torch.tensor([1, 0], dtype=torch.bool)) + + assert derived_d.current_choice == 4 + assert torch.equal(derived_d.current_mask, + torch.tensor([1, 1, 1, 1, 0], dtype=torch.bool)) + + assert derived_e.current_choice == 5 + assert torch.equal( + derived_e.current_mask, + torch.tensor([1, 0, 1, 1, 1, 1, 0], dtype=torch.bool)) + + +@pytest.mark.parametrize('expand_ratio', [1, 2, 3]) +def test_derived_expand_mutable(expand_ratio: int) -> None: + mv = OneShotMutableValue(value_list=[3, 5, 7]) + + mv_derived = mv * expand_ratio + assert mv_derived.source_mutables == {mv} + + assert isinstance(mv_derived, BaseMutable) + assert isinstance(mv_derived, DerivedMutable) + assert not mv_derived.is_fixed + assert mv_derived.num_choices == 1 + + mv.current_choice = mv.max_choice + assert mv_derived.current_choice == mv.current_choice * expand_ratio + mv.current_choice = mv.min_choice + assert mv_derived.current_choice == mv.current_choice * expand_ratio + + with pytest.raises(RuntimeError): + mv_derived.current_choice = 123 + with pytest.raises(RuntimeError): + _ = mv_derived.current_mask + + mv.current_choice = 5 + assert mv_derived.current_choice == 5 * expand_ratio diff --git a/tests/test_models/test_mutables/test_mutable_value.py b/tests/test_models/test_mutables/test_mutable_value.py new file mode 100644 index 000000000..0b5ed7947 --- /dev/null +++ 
b/tests/test_models/test_mutables/test_mutable_value.py @@ -0,0 +1,136 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from unittest import TestCase + +import pytest +import torch + +from mmrazor.models.mutables import (MutableValue, OneShotMutableChannel, + OneShotMutableValue) + + +class TestMutableValue(TestCase): + + def test_init_mutable_value(self) -> None: + value_list = [2, 4, 6] + mv = MutableValue(value_list=value_list) + assert mv.current_choice == 2 + assert mv.num_choices == 3 + + mv = MutableValue(value_list=value_list, default_value=4) + assert mv.current_choice == 4 + + with pytest.raises(ValueError): + mv = MutableValue(value_list=value_list, default_value=5) + + mv = MutableValue(value_list=[2]) + assert mv.current_choice == 2 + assert mv.choices == [2] + + with pytest.raises(TypeError): + mv = MutableValue(value_list=[2, 3.2]) + + def test_init_one_shot_mutable_value(self) -> None: + value_list = [6, 4, 2] + mv = OneShotMutableValue(value_list=value_list) + assert mv.current_choice == 6 + assert mv.choices == [2, 4, 6] + + mv = OneShotMutableValue(value_list=value_list, default_value=4) + assert mv.current_choice == 4 + + def test_fix_chosen(self) -> None: + mv = MutableValue([2, 3, 4]) + chosen = mv.dump_chosen() + assert chosen == { + 'current_choice': mv.current_choice, + 'all_choices': mv.choices + } + + chosen['current_choice'] = 5 + with pytest.raises(AssertionError): + mv.fix_chosen(chosen) + + chosen_copied = copy.deepcopy(chosen) + chosen_copied['all_choices'] = [1, 2, 3] + with pytest.raises(AssertionError): + mv.fix_chosen(chosen_copied) + + chosen['current_choice'] = 3 + mv.fix_chosen(chosen) + assert mv.current_choice == 3 + + with pytest.raises(RuntimeError): + mv.fix_chosen(chosen) + + def test_one_shot_mutable_value_sample(self) -> None: + mv = OneShotMutableValue(value_list=[2, 3, 4]) + assert mv.max_choice == 4 + assert mv.min_choice == 2 + + for _ in range(100): + assert mv.sample_choice() in mv.choices + + 
def test_mul(self) -> None: + mv = MutableValue(value_list=[1, 2, 3], default_value=3) + mul_derived_mv = mv * 2 + rmul_derived_mv = 2 * mv + + assert mul_derived_mv.current_choice == 6 + assert rmul_derived_mv.current_choice == 6 + + mv.current_choice = 2 + assert mul_derived_mv.current_choice == 4 + assert rmul_derived_mv.current_choice == 4 + + with pytest.raises(TypeError): + _ = mv * 1.2 + + mv = MutableValue(value_list=[1, 2, 3], default_value=3) + mc = OneShotMutableChannel( + num_channels=4, candidate_choices=[2, 4], candidate_mode='number') + + with pytest.raises(TypeError): + _ = mc * mv + with pytest.raises(TypeError): + _ = mv * mc + + mv = OneShotMutableValue(value_list=[1, 2, 3], default_value=3) + mc.current_choice = 2 + + derived1 = mc * mv + derived2 = mv * mc + + assert derived1.current_choice == 6 + assert derived2.current_choice == 6 + assert torch.equal(derived1.current_mask, derived2.current_mask) + + mv.current_choice = 2 + assert derived1.current_choice == 4 + assert derived2.current_choice == 4 + assert torch.equal(derived1.current_mask, derived2.current_mask) + + def test_floordiv(self) -> None: + mv = MutableValue(value_list=[120, 128, 136]) + derived_mv = mv // 8 + + mv.current_choice = 120 + assert derived_mv.current_choice == 16 + mv.current_choice = 128 + assert derived_mv.current_choice == 16 + + derived_mv = mv // (8, 3) + mv.current_choice = 120 + assert derived_mv.current_choice == 15 + mv.current_choice = 136 + assert derived_mv.current_choice == 18 + + with pytest.raises(TypeError): + _ = mv // 1.2 + + def test_repr(self) -> None: + value_list = [2, 4, 6] + mv = MutableValue(value_list=value_list) + + assert repr(mv) == \ + f'MutableValue(value_list={value_list}, current_choice=2)' diff --git a/tests/test_models/test_subnet/test_fix_subnet.py b/tests/test_models/test_subnet/test_fix_subnet.py index 28e691bd7..010372212 100644 --- a/tests/test_models/test_subnet/test_fix_subnet.py +++ 
b/tests/test_models/test_subnet/test_fix_subnet.py @@ -5,7 +5,7 @@ import torch.nn as nn from mmrazor.models import * # noqa:F403,F401 -from mmrazor.models.mutables import OneShotMutableOP +from mmrazor.models.mutables import OneShotMutableOP, OneShotMutableValue from mmrazor.registry import MODELS from mmrazor.structures import export_fix_subnet, load_fix_subnet from mmrazor.utils import FixMutable @@ -37,6 +37,15 @@ def forward(self, x): return x +class MockModelWithDerivedMutable(nn.Module): + + def __init__(self) -> None: + super().__init__() + + self.source_mutable = OneShotMutableValue([2, 3, 4], default_value=3) + self.derived_mutable = self.source_mutable * 2 + + class TestFixSubnet(TestCase): def test_load_fix_subnet(self): @@ -63,6 +72,11 @@ def test_load_fix_subnet(self): model = MockModel() load_fix_subnet(model, fix_subnet=10) + model = MockModel() + fix_subnet.pop('mutable1') + with pytest.raises(RuntimeError): + load_fix_subnet(model, fix_subnet) + def test_export_fix_subnet(self): # get FixSubnet fix_subnet = { @@ -82,3 +96,25 @@ def test_export_fix_subnet(self): exported_fix_subnet = export_fix_subnet(model) self.assertDictEqual(fix_subnet, exported_fix_subnet) + + def test_export_fix_subnet_with_derived_mutable(self) -> None: + model = MockModelWithDerivedMutable() + fix_subnet = export_fix_subnet(model) + self.assertDictEqual( + fix_subnet, {'source_mutable': model.source_mutable.dump_chosen()}) + fix_subnet['source_mutable']['current_choice'] = 4 + load_fix_subnet(model, fix_subnet) + assert model.source_mutable.current_choice == 4 + assert model.derived_mutable.current_choice == 8 + + model = MockModelWithDerivedMutable() + fix_subnet = export_fix_subnet(model, dump_derived_mutable=True) + self.assertDictEqual( + fix_subnet, { + 'source_mutable': model.source_mutable.dump_chosen(), + 'derived_mutable': model.derived_mutable.dump_chosen() + }) + fix_subnet['source_mutable']['current_choice'] = 2 + load_fix_subnet(model, fix_subnet) + assert 
model.source_mutable.current_choice == 2 + assert model.derived_mutable.current_choice == 4