[Model Compression] MixedMaskerPruner #3627
Conversation
class MixedPrunerMasker(WeightMasker):
    def __init__(self, model, pruner, maskers_config_dict):
A naive question: in class `OneshotPruner`, the argument `maskers_config_dict` is passed to `MixedPrunerMasker` as a dictionary, but here we take the arguments directly. Would anything go wrong?
No. As you can see in this line, `algo_kwargs` is expanded, i.e. it becomes `key1=value1, key2=value2, ...`.
model, self, **algo_kwargs)
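For reference, a minimal sketch of this dict-expansion behavior in plain Python (the constructor name `masker_ctor` is illustrative, not NNI's actual API):

```python
def masker_ctor(model, pruner, **algo_kwargs):
    # the expanded key=value pairs are re-collected into a dict here
    return algo_kwargs

algo_kwargs = {'statistics_batch_num': 1}
# **algo_kwargs expands the dict to statistics_batch_num=1 at the call
# site, so the constructor receives the same contents it started with
assert masker_ctor(None, None, **algo_kwargs) == {'statistics_batch_num': 1}
```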
counter = {}
for config in config_list:
    assert 'masker_name' not in config, 'maskers_config_dict should be set if use masker_name'
    if 'pruning_algo' not in config:
It seems that if `pruning_algo` is not set by the user, `LevelPrunerMasker` is used by default. Users won't notice this without reading the code. Maybe we should document this corner case.
Yes, this should be pointed out; I will add it to the doc.
updated.
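For readers of this thread, a hypothetical `config_list` illustrating the now-documented default (the sparsity values and the `'l1'` algorithm name are only examples):

```python
config_list = [
    # masker chosen explicitly
    {'sparsity': 0.5, 'op_types': ['Conv2d'], 'pruning_algo': 'l1'},
    # no 'pruning_algo' given -> falls back to LevelPrunerMasker
    {'sparsity': 0.3, 'op_names': ['fc1']},
]
```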
for mask_type in masks[layer]:
    assert hasattr(
        name2wrapper[layer], mask_type), "there is no attribute '%s' in wrapper on %s" % (mask_type, layer)
    setattr(name2wrapper[layer], mask_type, masks[layer][mask_type])
During forward, `wrapper.weight_mask` is multiplied into the computation; the related code is here. Why doesn't `wrapper.weight_mask` need to be set here? Is this correct? If it is, where do we set `wrapper.weight_mask` with the calculated masks?
Yes. In fact, `mask_type` covers both `weight_mask` and `bias_mask`, so the `setattr` above does set `wrapper.weight_mask`.
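A minimal sketch of the masks layout this loop implies, with a stand-in wrapper class (`DummyWrapper` is not NNI code; the shapes are arbitrary):

```python
import torch

class DummyWrapper:  # stand-in for NNI's module wrapper
    weight_mask = None
    bias_mask = None

# each layer maps mask_type -> tensor; mask_type can be
# 'weight_mask' or 'bias_mask', so the loop sets both
masks = {'conv1': {'weight_mask': torch.ones(16, 3, 3, 3),
                   'bias_mask': torch.ones(16)}}
name2wrapper = {'conv1': DummyWrapper()}

for layer in masks:
    for mask_type in masks[layer]:
        setattr(name2wrapper[layer], mask_type, masks[layer][mask_type])

assert name2wrapper['conv1'].weight_mask is not None
```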
'sparsity': And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str],
Optional('masker_name'): str
"masker_name"? in your example "pruning_algo" is used
Yes. We convert `config_list` internally: `pruning_algo` is popped and `masker_name` is added.
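A rough sketch of that conversion; the helper name `convert_config_list` and the generated-name scheme are assumptions, not the PR's exact code:

```python
def convert_config_list(config_list):
    maskers_config_dict = {}
    counter = {}
    for config in config_list:
        if 'pruning_algo' in config:
            algo = config.pop('pruning_algo')            # e.g. 'l1'
            counter[algo] = counter.get(algo, 0) + 1
            masker_name = '%s_%d' % (algo, counter[algo])
            maskers_config_dict[masker_name] = (algo, {})
            config['masker_name'] = masker_name          # replaces 'pruning_algo'
    return config_list, maskers_config_dict
```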
MixedMaskerPruner supports configuring a different masker per operation.
"""
def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None, maskers_config_dict=None):
What is `maskers_config_dict` used for?
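Judging from the `algo_kwargs` expansion discussed earlier in this thread, a hypothetical shape would be masker name mapped to masker type plus constructor kwargs; the exact structure in the PR may differ:

```python
maskers_config_dict = {
    'level_1': ('level', {}),                # masker type, no extra kwargs
    'custom_1': ('l1', {'dummy_kwarg': 1}),  # kwargs expanded via **algo_kwargs
}
```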
_logger = logging.getLogger('torch pruner')

class MixedPrunerMasker(WeightMasker): |
The class names are confusingly similar: `MixedPrunerMasker`, `MixedMaskerPruner`, ...
config['masker_name'] = masker_name
return config_list, maskers_config_dict

def _dependency_calc_mask(self, wrappers, channel_dsets, wrappers_idx=None, origin_wrapper=None): |
How does MixedMaskerPruner deal with dependency groups?
Depends on #3507.