-
Notifications
You must be signed in to change notification settings - Fork 231
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Add DerivedMutable & MutableValue
#215
Changes from all commits
b897cfd
2313ba2
20a06ce
1504bf9
eaf57a1
953dfb7
fd7e4f9
e9d1064
dff284e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,15 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .derived_mutable import DerivedMutable | ||
from .mutable_channel import (MutableChannel, OneShotMutableChannel, | ||
SlimmableMutableChannel) | ||
from .mutable_manage_mixin import MutableManageMixIn | ||
from .mutable_module import (DiffChoiceRoute, DiffMutableModule, DiffMutableOP, | ||
OneShotMutableModule, OneShotMutableOP) | ||
from .mutable_value import MutableValue, OneShotMutableValue | ||
|
||
# Public API of the mutables package. Every name is listed exactly once and
# every entry is comma-terminated, so adjacent string literals can never be
# silently concatenated into a single bogus name.
__all__ = [
    'OneShotMutableOP', 'OneShotMutableModule', 'DiffMutableOP',
    'DiffChoiceRoute', 'DiffMutableModule', 'MutableManageMixIn',
    'OneShotMutableChannel', 'SlimmableMutableChannel', 'MutableChannel',
    'DerivedMutable', 'MutableValue', 'OneShotMutableValue'
]
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,382 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import inspect | ||
import logging | ||
from collections.abc import Iterable | ||
from typing import Any, Callable, Dict, Optional, Protocol, Set, Union | ||
|
||
import torch | ||
from mmengine.logging import print_log | ||
from torch import Tensor | ||
|
||
from ..utils import make_divisible | ||
from .base_mutable import CHOICE_TYPE, BaseMutable | ||
|
||
|
||
class MutableProtocol(Protocol):  # pragma: no cover
    """Structural interface of a mutable that other mutables derive from.

    Only the members declared here are relied upon by the helper functions
    in this module; the bodies are intentionally empty stubs.
    """

    @property
    def current_choice(self) -> Any:
        """Choice that is currently selected by the mutable."""
        ...

    def derive_expand_mutable(self, expand_ratio: int) -> Any:
        """Derive an expand mutable from this mutable.

        Args:
            expand_ratio (int): Ratio used for the expansion.
        """
        ...

    def derive_divide_mutable(self, ratio: int, divisor: int) -> Any:
        """Derive a divide mutable from this mutable.

        Args:
            ratio (int): Ratio used for the division.
            divisor (int): Divisor the derived value is aligned to.
        """
        ...
|
||
|
||
class MutableChannelProtocol(MutableProtocol,
                             Protocol):  # pragma: no cover
    """Protocol for MutableChannel.

    ``Protocol`` is listed explicitly as a base: a class that merely
    subclasses a protocol is treated as an ordinary nominal class by type
    checkers, so without it channel mutables would not match this
    interface structurally.
    """

    @property
    def current_mask(self) -> Tensor:
        """Current mask."""
        ...
|
||
|
||
def _expand_choice_fn(mutable: MutableProtocol, expand_ratio: int) -> Callable: | ||
"""Helper function to build `choice_fn` for expand derived mutable.""" | ||
|
||
def fn(): | ||
return mutable.current_choice * expand_ratio | ||
|
||
return fn | ||
|
||
|
||
def _expand_mask_fn(mutable: MutableProtocol, | ||
expand_ratio: int) -> Callable: # pragma: no cover | ||
"""Helper function to build `mask_fn` for expand derived mutable.""" | ||
if not hasattr(mutable, 'current_mask'): | ||
raise ValueError('mutable must have attribute `currnet_mask`') | ||
|
||
def fn(): | ||
mask = mutable.current_mask | ||
expand_num_channels = mask.size(0) * expand_ratio | ||
expand_choice = mutable.current_choice * expand_ratio | ||
expand_mask = torch.zeros(expand_num_channels).bool() | ||
expand_mask[:expand_choice] = True | ||
|
||
return expand_mask | ||
|
||
return fn | ||
|
||
|
||
def _divide_and_divise(x: int, ratio: int, divisor: int = 8) -> int:
    """Floor-divide ``x`` by ``ratio``, then apply ``make_divisible``.

    Args:
        x (int): Value to divide.
        ratio (int): Denominator of the floor division.
        divisor (int): Second argument forwarded to ``make_divisible``.
            Defaults to 8.

    Returns:
        int: Result of ``make_divisible(x // ratio, divisor)``.
    """
    return make_divisible(x // ratio, divisor)
|
||
|
||
def _divide_choice_fn(mutable: 'MutableProtocol',
                      ratio: int,
                      divisor: int = 8) -> Callable:
    """Create the ``choice_fn`` closure of a divide derived mutable.

    Args:
        mutable: Source mutable whose ``current_choice`` is read on every
            call of the returned closure.
        ratio (int): Ratio forwarded to ``_divide_and_divise``.
        divisor (int): Divisor forwarded to ``_divide_and_divise``.
            Defaults to 8.

    Returns:
        Callable: Zero-argument function returning the divided choice.
    """

    def divided_choice():
        # Read the source lazily so later changes of the source are seen.
        source_choice = mutable.current_choice
        return _divide_and_divise(source_choice, ratio, divisor)

    return divided_choice
|
||
|
||
def _divide_mask_fn(mutable: MutableProtocol, | ||
ratio: int, | ||
divisor: int = 8) -> Callable: # pragma: no cover | ||
"""Helper function to build `mask_fn` for divide derived mutable.""" | ||
if not hasattr(mutable, 'current_mask'): | ||
raise ValueError('mutable must have attribute `currnet_mask`') | ||
|
||
def fn(): | ||
mask = mutable.current_mask | ||
divide_num_channels = _divide_and_divise(mask.size(0), ratio, divisor) | ||
divide_choice = _divide_and_divise(mutable.current_choice, ratio, | ||
divisor) | ||
divide_mask = torch.zeros(divide_num_channels).bool() | ||
divide_mask[:divide_choice] = True | ||
|
||
return divide_mask | ||
|
||
return fn | ||
|
||
|
||
def _concat_choice_fn(mutables: Iterable[MutableChannelProtocol]) -> Callable: | ||
"""Helper function to build `choice_fn` for concat derived mutable.""" | ||
|
||
def fn(): | ||
return sum((m.current_choice for m in mutables)) | ||
|
||
return fn | ||
|
||
|
||
def _concat_mask_fn(mutables: Iterable[MutableChannelProtocol]) -> Callable: | ||
"""Helper function to build `mask_fn` for concat derived mutable.""" | ||
|
||
def fn(): | ||
return torch.cat([m.current_mask for m in mutables]) | ||
|
||
return fn | ||
|
||
|
||
class DerivedMethodMixin:
    """A mixin that provides some useful method to derive mutable."""

    def derive_same_mutable(self: 'MutableProtocol') -> 'DerivedMutable':
        """Derive same mutable as the source.

        Implemented as an expand derivation with ``expand_ratio=1``.
        """
        return self.derive_expand_mutable(expand_ratio=1)

    def derive_expand_mutable(self: 'MutableProtocol',
                              expand_ratio: int) -> 'DerivedMutable':
        """Derive expand mutable, usually used with `expand_ratio`.

        Args:
            expand_ratio (int): Factor applied to the source choice (and to
                the source mask, when the source has one).

        Returns:
            DerivedMutable: Mutable whose choice follows this mutable.
        """
        choice_fn = _expand_choice_fn(self, expand_ratio=expand_ratio)

        # A mask closure is only built for sources that carry a mask.
        mask_fn: Optional[Callable] = None
        if hasattr(self, 'current_mask'):
            mask_fn = _expand_mask_fn(self, expand_ratio=expand_ratio)

        return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn)

    def derive_divide_mutable(self: 'MutableProtocol',
                              ratio: int,
                              divisor: int = 8) -> 'DerivedMutable':
        """Derive divide mutable, usually used with `make_divisable`.

        Args:
            ratio (int): Ratio the source choice is divided by.
            divisor (int): Divisor the divided value is aligned to.
                Defaults to 8.

        Returns:
            DerivedMutable: Mutable whose choice follows this mutable.
        """
        choice_fn = _divide_choice_fn(self, ratio=ratio, divisor=divisor)

        # A mask closure is only built for sources that carry a mask.
        mask_fn: Optional[Callable] = None
        if hasattr(self, 'current_mask'):
            mask_fn = _divide_mask_fn(self, ratio=ratio, divisor=divisor)

        return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn)

    @staticmethod
    def derive_concat_mutable(
            mutables: 'Iterable[MutableChannelProtocol]'
    ) -> 'DerivedMutable':
        """Derive concat mutable, usually used with `torch.cat`.

        Args:
            mutables: Source mutables to concatenate. Each of them must have
                a ``current_mask`` attribute. The iterable is consumed
                exactly once.

        Returns:
            DerivedMutable: Mutable following the concatenation of all
            source mutables.

        Raises:
            RuntimeError: If a source mutable has no ``current_mask``.
        """
        # Materialize first: the validation loop below would otherwise
        # exhaust a one-shot iterator and leave the closures with nothing
        # to read.
        mutables = tuple(mutables)

        for mutable in mutables:
            if not hasattr(mutable, 'current_mask'):
                raise RuntimeError('Source mutable of concat derived mutable '
                                   'must have attribute `current_mask`')

        choice_fn = _concat_choice_fn(mutables)
        mask_fn = _concat_mask_fn(mutables)

        return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn)
|
||
|
||
class DerivedMutable(BaseMutable[CHOICE_TYPE, CHOICE_TYPE],
                     DerivedMethodMixin):
    """Class for derived mutable.

    A derived mutable is a mutable derived from other mutables that has
    `current_choice` and `current_mask` attributes (if any).

    Note:
        A derived mutable does not have its own search space, so it is
        not legal to modify its `current_choice` or `current_mask` directly.
        And the only way to modify them is by modifying `current_choice` or
        `current_mask` in corresponding source mutables.

    Args:
        choice_fn (callable): A closure that controls how to generate
            `current_choice`.
        mask_fn (callable, optional): A closure that controls how to generate
            `current_mask`. Defaults to None.
        source_mutables (iterable, optional): Specify source mutables for this
            derived mutable. If the argument is None, source mutables will be
            traced automatically by parsing mutables in closure variables.
            Defaults to None.
        alias (str, optional): alias of the `MUTABLE`. Defaults to None.
        init_cfg (dict, optional): initialization configuration dict for
            ``BaseModule``. OpenMMLab has implement 5 initializer including
            `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
            and `Pretrained`. Defaults to None.

    Examples:
        >>> from mmrazor.models.mutables import OneShotMutableChannel
        >>> mutable_channel = OneShotMutableChannel(
        ...     num_channels=3,
        ...     candidate_choices=[1, 2, 3],
        ...     candidate_mode='number')
        >>> # derive expand mutable
        >>> derived_mutable_channel = mutable_channel * 2
        >>> # source mutables will be traced automatically
        >>> derived_mutable_channel.source_mutables
        {OneShotMutableChannel(name=unbind, num_channels=3, current_choice=3, choices=[1, 2, 3], activated_channels=3, concat_mutable_name=[])}  # noqa: E501
        >>> # modify `current_choice` of `mutable_channel`
        >>> mutable_channel.current_choice = 2
        >>> # `current_choice` and `current_mask` of derived mutable will be modified automatically  # noqa: E501
        >>> derived_mutable_channel
        DerivedMutable(current_choice=4, activated_channels=4, source_mutables={OneShotMutableChannel(name=unbind, num_channels=3, current_choice=2, choices=[1, 2, 3], activated_channels=2, concat_mutable_name=[])}, is_fixed=False)  # noqa: E501
    """

    def __init__(self,
                 choice_fn: Callable,
                 mask_fn: Optional[Callable] = None,
                 source_mutables: Optional[Iterable[BaseMutable]] = None,
                 alias: Optional[str] = None,
                 init_cfg: Optional[Dict] = None) -> None:
        super().__init__(alias, init_cfg)

        self.choice_fn = choice_fn
        self.mask_fn = mask_fn

        if source_mutables is None:
            # Discover the sources from the variables captured by the
            # closures; an empty result means nothing can drive this mutable.
            source_mutables = self._trace_source_mutables()
            if len(source_mutables) == 0:
                raise RuntimeError(
                    'Can not find source mutables automatically, '
                    'please provide manually.')
        else:
            source_mutables = set(source_mutables)
        for mutable in source_mutables:
            if not self.is_source_mutable(mutable):
                raise ValueError('Expect all mutable to be source mutable, '
                                 f'but {mutable} is not')
        self.source_mutables = source_mutables

    # TODO
    # has no effect
    def fix_chosen(self, chosen: CHOICE_TYPE) -> None:
        """Fix mutable with subnet config.

        Warning:
            Fix derived mutable will have no actually effect.
        """
        print_log(
            'Trying to fix chosen for derived mutable, '
            'which will have no effect.',
            level=logging.WARNING)

    def dump_chosen(self) -> CHOICE_TYPE:
        """Dump information of chosen.

        A warning is logged because the value is fully determined by the
        source mutables.

        Returns:
            CHOICE_TYPE: The current choice of this derived mutable.
        """
        print_log(
            'Trying to dump chosen for derived mutable, '
            'but its value depend on the source mutables.',
            level=logging.WARNING)
        return self.current_choice

    @property
    def is_fixed(self) -> bool:
        """Whether the derived mutable is fixed.

        Note:
            Depends on whether all source mutables are already fixed.
        """
        return all(m.is_fixed for m in self.source_mutables)

    @is_fixed.setter
    def is_fixed(self, is_fixed: bool) -> None:
        """Setter of is fixed.

        Raises:
            RuntimeError: Always, `is_fixed` is derived from the sources.
        """
        raise RuntimeError(
            '`is_fixed` of derived mutable should not be modified directly')

    @property
    def num_choices(self) -> int:
        """Number of all choices.

        Note:
            Since derive mutable does not have its own search space, the number
            of choices will always be `1`.

        Returns:
            int: Number of choices.
        """
        return 1

    @property
    def current_choice(self) -> CHOICE_TYPE:
        """Current choice of derived mutable."""
        return self.choice_fn()

    @current_choice.setter
    def current_choice(self, choice: CHOICE_TYPE) -> None:
        """Setter of current choice.

        Raises:
            RuntimeError: Error when `current_choice` of derived mutable
                is modified directly.
        """
        raise RuntimeError('Choice of drived mutable can not be set.')

    @property
    def current_mask(self) -> Tensor:
        """Current mask of derived mutable.

        Raises:
            RuntimeError: If no `mask_fn` was provided.
        """
        if self.mask_fn is None:
            raise RuntimeError(
                '`mask_fn` must be set before access `current_mask`.')
        return self.mask_fn()

    @current_mask.setter
    def current_mask(self, mask: Tensor) -> None:
        """Setter of current mask.

        Raises:
            RuntimeError: Error when `current_mask` of derived mutable
                is modified directly.
        """
        raise RuntimeError('Mask of drived mutable can not be set.')

    @staticmethod
    def _trace_source_mutables_from_closure(
            closure: Callable) -> Set[BaseMutable]:
        """Trace source mutables from closure.

        The nonlocal variables captured by ``closure`` are walked
        depth-first. Plain mutables are collected as they are; a derived
        mutable contributes its own ``source_mutables`` instead of itself.

        Args:
            closure (callable): Function whose captured variables are
                inspected.

        Returns:
            set: All source mutables reachable from the captured variables.
        """
        source_mutables: Set[BaseMutable] = set()

        def add_mutables_dfs(
                mutable: Union[Iterable, BaseMutable, Dict]) -> None:
            if isinstance(mutable, BaseMutable):
                if isinstance(mutable, DerivedMutable):
                    source_mutables.update(mutable.source_mutables)
                else:
                    source_mutables.add(mutable)
            # Iterable leaves that can never contain a mutable. Descending
            # into them is harmful: a one-character `str` iterates to
            # itself (unbounded recursion) and iterating a tensor ends in
            # 0-d tensors, which raise `TypeError` on iteration.
            elif isinstance(mutable, (str, bytes, bytearray, Tensor)):
                return
            # dict is also iterable, should parse first
            elif isinstance(mutable, dict):
                add_mutables_dfs(mutable.values())
                add_mutables_dfs(mutable.keys())
            elif isinstance(mutable, Iterable):
                for m in mutable:
                    add_mutables_dfs(m)

        nonlocal_vars = inspect.getclosurevars(closure).nonlocals
        add_mutables_dfs(nonlocal_vars.values())

        return source_mutables

    def _trace_source_mutables(self) -> Set[BaseMutable]:
        """Trace source mutables.

        Returns:
            set: Union of the mutables captured by `choice_fn` and, when it
            is set, by `mask_fn`.
        """
        source_mutables = self._trace_source_mutables_from_closure(
            self.choice_fn)
        if self.mask_fn is not None:
            source_mutables |= self._trace_source_mutables_from_closure(
                self.mask_fn)

        return source_mutables

    @staticmethod
    def is_source_mutable(mutable: object) -> bool:
        """Judge whether an object is source mutable(not derived mutable).

        Args:
            mutable (object): An object.

        Returns:
            bool: Indicate whether the object is source mutable or not.
        """
        return isinstance(mutable, BaseMutable) and \
            not isinstance(mutable, DerivedMutable)

    # TODO
    # should be __str__? but can not provide info when debug
    def __repr__(self) -> str:  # pragma: no cover
        s = f'{self.__class__.__name__}('
        s += f'current_choice={self.current_choice}, '
        # The mask is only reported when a `mask_fn` exists, because
        # accessing `current_mask` without one raises.
        if self.mask_fn is not None:
            s += f'activated_channels={self.current_mask.sum().item()}, '
        s += f'source_mutables={self.source_mutables}, '
        s += f'is_fixed={self.is_fixed})'

        return s
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo? ‘DerivedMutable’ -> DerivedMutable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason for using a string here is that class `DerivedMutable` is defined after `DerivedMethodMixin`. To use `DerivedMutable` directly as a type hint, the Python version must be >= 3.7. More details can be found at this link