-
Notifications
You must be signed in to change notification settings - Fork 231
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* fix lint * complement unittest for derived mutable * add docstring for derived mutable * add unittest for mutable value * fix logger error * fix according to comments * not dump derived mutable when export * add warning in `export_fix_subnet` * fix __mul__ in mutable value
- Loading branch information
1 parent
696191e
commit 7dca9ba
Showing
13 changed files
with
1,196 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,15 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .derived_mutable import DerivedMutable | ||
from .mutable_channel import (MutableChannel, OneShotMutableChannel, | ||
SlimmableMutableChannel) | ||
from .mutable_manage_mixin import MutableManageMixIn | ||
from .mutable_module import (DiffChoiceRoute, DiffMutableModule, DiffMutableOP, | ||
OneShotMutableModule, OneShotMutableOP) | ||
from .mutable_value import MutableValue, OneShotMutableValue | ||
|
||
__all__ = [ | ||
'OneShotMutableOP', 'OneShotMutableModule', 'DiffMutableOP', | ||
'DiffChoiceRoute', 'DiffMutableModule', 'MutableManageMixIn', | ||
'OneShotMutableChannel', 'SlimmableMutableChannel', 'MutableChannel' | ||
'OneShotMutableChannel', 'SlimmableMutableChannel', 'MutableChannel', | ||
'DerivedMutable', 'MutableValue', 'OneShotMutableValue' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,382 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import inspect | ||
import logging | ||
from collections.abc import Iterable | ||
from typing import Any, Callable, Dict, Optional, Protocol, Set, Union | ||
|
||
import torch | ||
from mmengine.logging import print_log | ||
from torch import Tensor | ||
|
||
from ..utils import make_divisible | ||
from .base_mutable import CHOICE_TYPE, BaseMutable | ||
|
||
|
||
class MutableProtocol(Protocol): # pragma: no cover | ||
"""Protocol for Mutable.""" | ||
|
||
@property | ||
def current_choice(self) -> Any: | ||
"""Current choice.""" | ||
|
||
def derive_expand_mutable(self, expand_ratio: int) -> Any: | ||
"""Derive expand mutable.""" | ||
|
||
def derive_divide_mutable(self, ratio: int, divisor: int) -> Any: | ||
"""Derive divide mutable.""" | ||
|
||
|
||
class MutableChannelProtocol(MutableProtocol): # pragma: no cover | ||
"""Protocol for MutableChannel.""" | ||
|
||
@property | ||
def current_mask(self) -> Tensor: | ||
"""Current mask.""" | ||
|
||
|
||
def _expand_choice_fn(mutable: MutableProtocol, expand_ratio: int) -> Callable: | ||
"""Helper function to build `choice_fn` for expand derived mutable.""" | ||
|
||
def fn(): | ||
return mutable.current_choice * expand_ratio | ||
|
||
return fn | ||
|
||
|
||
def _expand_mask_fn(mutable: MutableProtocol, | ||
expand_ratio: int) -> Callable: # pragma: no cover | ||
"""Helper function to build `mask_fn` for expand derived mutable.""" | ||
if not hasattr(mutable, 'current_mask'): | ||
raise ValueError('mutable must have attribute `currnet_mask`') | ||
|
||
def fn(): | ||
mask = mutable.current_mask | ||
expand_num_channels = mask.size(0) * expand_ratio | ||
expand_choice = mutable.current_choice * expand_ratio | ||
expand_mask = torch.zeros(expand_num_channels).bool() | ||
expand_mask[:expand_choice] = True | ||
|
||
return expand_mask | ||
|
||
return fn | ||
|
||
|
||
def _divide_and_divise(x: int, ratio: int, divisor: int = 8) -> int: | ||
"""Helper function for divide and divise.""" | ||
new_x = x // ratio | ||
|
||
return make_divisible(new_x, divisor) | ||
|
||
|
||
def _divide_choice_fn(mutable: MutableProtocol, | ||
ratio: int, | ||
divisor: int = 8) -> Callable: | ||
"""Helper function to build `choice_fn` for divide derived mutable.""" | ||
|
||
def fn(): | ||
return _divide_and_divise(mutable.current_choice, ratio, divisor) | ||
|
||
return fn | ||
|
||
|
||
def _divide_mask_fn(mutable: MutableProtocol, | ||
ratio: int, | ||
divisor: int = 8) -> Callable: # pragma: no cover | ||
"""Helper function to build `mask_fn` for divide derived mutable.""" | ||
if not hasattr(mutable, 'current_mask'): | ||
raise ValueError('mutable must have attribute `currnet_mask`') | ||
|
||
def fn(): | ||
mask = mutable.current_mask | ||
divide_num_channels = _divide_and_divise(mask.size(0), ratio, divisor) | ||
divide_choice = _divide_and_divise(mutable.current_choice, ratio, | ||
divisor) | ||
divide_mask = torch.zeros(divide_num_channels).bool() | ||
divide_mask[:divide_choice] = True | ||
|
||
return divide_mask | ||
|
||
return fn | ||
|
||
|
||
def _concat_choice_fn(mutables: Iterable[MutableChannelProtocol]) -> Callable: | ||
"""Helper function to build `choice_fn` for concat derived mutable.""" | ||
|
||
def fn(): | ||
return sum((m.current_choice for m in mutables)) | ||
|
||
return fn | ||
|
||
|
||
def _concat_mask_fn(mutables: Iterable[MutableChannelProtocol]) -> Callable: | ||
"""Helper function to build `mask_fn` for concat derived mutable.""" | ||
|
||
def fn(): | ||
return torch.cat([m.current_mask for m in mutables]) | ||
|
||
return fn | ||
|
||
|
||
class DerivedMethodMixin: | ||
"""A mixin that provides some useful method to derive mutable.""" | ||
|
||
def derive_same_mutable(self: MutableProtocol) -> 'DerivedMutable': | ||
"""Derive same mutable as the source.""" | ||
return self.derive_expand_mutable(expand_ratio=1) | ||
|
||
def derive_expand_mutable(self: MutableProtocol, | ||
expand_ratio: int) -> 'DerivedMutable': | ||
"""Derive expand mutable, usually used with `expand_ratio`.""" | ||
choice_fn = _expand_choice_fn(self, expand_ratio=expand_ratio) | ||
|
||
mask_fn: Optional[Callable] = None | ||
if hasattr(self, 'current_mask'): | ||
mask_fn = _expand_mask_fn(self, expand_ratio=expand_ratio) | ||
|
||
return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn) | ||
|
||
def derive_divide_mutable(self: MutableProtocol, | ||
ratio: int, | ||
divisor: int = 8) -> 'DerivedMutable': | ||
"""Derive divide mutable, usually used with `make_divisable`.""" | ||
choice_fn = _divide_choice_fn(self, ratio=ratio, divisor=divisor) | ||
|
||
mask_fn: Optional[Callable] = None | ||
if hasattr(self, 'current_mask'): | ||
mask_fn = _divide_mask_fn(self, ratio=ratio, divisor=divisor) | ||
|
||
return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn) | ||
|
||
@staticmethod | ||
def derive_concat_mutable( | ||
mutables: Iterable[MutableChannelProtocol]) -> 'DerivedMutable': | ||
"""Derive concat mutable, usually used with `torch.cat`.""" | ||
for mutable in mutables: | ||
if not hasattr(mutable, 'current_mask'): | ||
raise RuntimeError('Source mutable of concat derived mutable ' | ||
'must have attribute `currnet_mask`') | ||
|
||
choice_fn = _concat_choice_fn(mutables) | ||
mask_fn = _concat_mask_fn(mutables) | ||
|
||
return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn) | ||
|
||
|
||
class DerivedMutable(BaseMutable[CHOICE_TYPE, CHOICE_TYPE], | ||
DerivedMethodMixin): | ||
"""Class for derived mutable. | ||
A derived mutable is a mutable derived from other mutables that has | ||
`current_choice` and `current_mask` attributes (if any). | ||
Note: | ||
A derived mutable does not have its own search space, so it is | ||
not legal to modify its `current_choice` or `current_mask` directly. | ||
And the only way to modify them is by modifying `current_choice` or | ||
`current_mask` in corresponding source mutables. | ||
Args: | ||
choice_fn (callable): A closure that controls how to generate | ||
`current_choice`. | ||
mask_fn (callable, optional): A closure that controls how to generate | ||
`current_mask`. Defaults to None. | ||
source_mutables (iterable, optional): Specify source mutables for this | ||
derived mutable. If the argument is None, source mutables will be | ||
traced automatically by parsing mutables in closure variables. | ||
Defaults to None. | ||
alias (str, optional): alias of the `MUTABLE`. Defaults to None. | ||
init_cfg (dict, optional): initialization configuration dict for | ||
``BaseModule``. OpenMMLab has implement 5 initializer including | ||
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, | ||
and `Pretrained`. Defaults to None. | ||
Examples: | ||
>>> from mmrazor.models.mutables import OneShotMutableChannel | ||
>>> mutable_channel = OneShotMutableChannel( | ||
... num_channels=3, | ||
... candidate_choices=[1, 2, 3], | ||
... candidate_mode='number') | ||
>>> # derive expand mutable | ||
>>> derived_mutable_channel = mutable_channel * 2 | ||
>>> # source mutables will be traced automatically | ||
>>> derived_mutable_channel.source_mutables | ||
{OneShotMutableChannel(name=unbind, num_channels=3, current_choice=3, choices=[1, 2, 3], activated_channels=3, concat_mutable_name=[])} # noqa: E501 | ||
>>> # modify `current_choice` of `mutable_channel` | ||
>>> mutable_channel.current_choice = 2 | ||
>>> # `current_choice` and `current_mask` of derived mutable will be modified automatically # noqa: E501 | ||
>>> derived_mutable_channel | ||
DerivedMutable(current_choice=4, activated_channels=4, source_mutables={OneShotMutableChannel(name=unbind, num_channels=3, current_choice=2, choices=[1, 2, 3], activated_channels=2, concat_mutable_name=[])}, is_fixed=False) # noqa: E501 | ||
""" | ||
|
||
def __init__(self, | ||
choice_fn: Callable, | ||
mask_fn: Optional[Callable] = None, | ||
source_mutables: Optional[Iterable[BaseMutable]] = None, | ||
alias: Optional[str] = None, | ||
init_cfg: Optional[Dict] = None) -> None: | ||
super().__init__(alias, init_cfg) | ||
|
||
self.choice_fn = choice_fn | ||
self.mask_fn = mask_fn | ||
|
||
if source_mutables is None: | ||
source_mutables = self._trace_source_mutables() | ||
if len(source_mutables) == 0: | ||
raise RuntimeError( | ||
'Can not find source mutables automatically, ' | ||
'please provide manually.') | ||
else: | ||
source_mutables = set(source_mutables) | ||
for mutable in source_mutables: | ||
if not self.is_source_mutable(mutable): | ||
raise ValueError('Expect all mutable to be source mutable, ' | ||
f'but {mutable} is not') | ||
self.source_mutables = source_mutables | ||
|
||
# TODO | ||
# has no effect | ||
def fix_chosen(self, chosen: CHOICE_TYPE) -> None: | ||
"""Fix mutable with subnet config. | ||
Warning: | ||
Fix derived mutable will have no actually effect. | ||
""" | ||
print_log( | ||
'Trying to fix chosen for derived mutable, ' | ||
'which will have no effect.', | ||
level=logging.WARNING) | ||
|
||
def dump_chosen(self) -> CHOICE_TYPE: | ||
"""Dump information of chosen. | ||
Returns: | ||
Dict: Dumped information. | ||
""" | ||
print_log( | ||
'Trying to dump chosen for derived mutable, ' | ||
'but its value depend on the source mutables.', | ||
level=logging.WARNING) | ||
return self.current_choice | ||
|
||
@property | ||
def is_fixed(self) -> bool: | ||
"""Whether the derived mutable is fixed. | ||
Note: | ||
Depends on whether all source mutables are already fixed. | ||
""" | ||
return all(m.is_fixed for m in self.source_mutables) | ||
|
||
@is_fixed.setter | ||
def is_fixed(self, is_fixed: bool) -> bool: | ||
"""Setter of is fixed.""" | ||
raise RuntimeError( | ||
'`is_fixed` of derived mutable should not be modified directly') | ||
|
||
@property | ||
def num_choices(self) -> int: | ||
"""Number of all choices. | ||
Note: | ||
Since derive mutable does not have its own search space, the number | ||
of choices will always be `1`. | ||
Returns: | ||
int: Number of choices. | ||
""" | ||
return 1 | ||
|
||
@property | ||
def current_choice(self) -> CHOICE_TYPE: | ||
"""Current choice of derived mutable.""" | ||
return self.choice_fn() | ||
|
||
@current_choice.setter | ||
def current_choice(self, choice: CHOICE_TYPE) -> None: | ||
"""Setter of current choice. | ||
Raises: | ||
RuntimeError: Error when `current_choice` of derived mutable | ||
is modified directly. | ||
""" | ||
raise RuntimeError('Choice of drived mutable can not be set.') | ||
|
||
@property | ||
def current_mask(self) -> Tensor: | ||
"""Current mask of derived mutable.""" | ||
if self.mask_fn is None: | ||
raise RuntimeError( | ||
'`mask_fn` must be set before access `current_mask`.') | ||
return self.mask_fn() | ||
|
||
@current_mask.setter | ||
def current_mask(self, mask: Tensor) -> None: | ||
"""Setter of current mask. | ||
Raises: | ||
RuntimeError: Error when `current_mask` of derived mutable | ||
is modified directly. | ||
""" | ||
raise RuntimeError('Mask of drived mutable can not be set.') | ||
|
||
@staticmethod | ||
def _trace_source_mutables_from_closure( | ||
closure: Callable) -> Set[BaseMutable]: | ||
"""Trace source mutables from closure.""" | ||
source_mutables: Set[BaseMutable] = set() | ||
|
||
def add_mutables_dfs( | ||
mutable: Union[Iterable, BaseMutable, Dict]) -> None: | ||
nonlocal source_mutables | ||
if isinstance(mutable, BaseMutable): | ||
if isinstance(mutable, DerivedMutable): | ||
source_mutables |= mutable.source_mutables | ||
else: | ||
source_mutables.add(mutable) | ||
# dict is also iterable, should parse first | ||
elif isinstance(mutable, dict): | ||
add_mutables_dfs(mutable.values()) | ||
add_mutables_dfs(mutable.keys()) | ||
elif isinstance(mutable, Iterable): | ||
for m in mutable: | ||
add_mutables_dfs(m) | ||
|
||
noncolcal_pars = inspect.getclosurevars(closure).nonlocals | ||
add_mutables_dfs(noncolcal_pars.values()) | ||
|
||
return source_mutables | ||
|
||
def _trace_source_mutables(self) -> Set[BaseMutable]: | ||
"""Trace source mutables.""" | ||
source_mutables = self._trace_source_mutables_from_closure( | ||
self.choice_fn) | ||
if self.mask_fn is not None: | ||
source_mutables |= self._trace_source_mutables_from_closure( | ||
self.mask_fn) | ||
|
||
return source_mutables | ||
|
||
@staticmethod | ||
def is_source_mutable(mutable: object) -> bool: | ||
"""Judge whether an object is source mutable(not derived mutable). | ||
Args: | ||
mutable (object): An object. | ||
Returns: | ||
bool: Indicate whether the object is source mutable or not. | ||
""" | ||
return isinstance(mutable, BaseMutable) and \ | ||
not isinstance(mutable, DerivedMutable) | ||
|
||
# TODO | ||
# should be __str__? but can not provide info when debug | ||
def __repr__(self) -> str: # pragma: no cover | ||
s = f'{self.__class__.__name__}(' | ||
s += f'current_choice={self.current_choice}, ' | ||
if self.mask_fn is not None: | ||
s += f'activated_channels={self.current_mask.sum().item()}, ' | ||
s += f'source_mutables={self.source_mutables}, ' | ||
s += f'is_fixed={self.is_fixed})' | ||
|
||
return s |
Oops, something went wrong.