[Model Compression] Add global sort for taylor pruner #3896
@@ -33,11 +33,12 @@ class StructuredWeightMasker(WeightMasker):
     """
-    def __init__(self, model, pruner, preserve_round=1, dependency_aware=False):
+    def __init__(self, model, pruner, preserve_round=1, dependency_aware=False, global_sort=False):
         self.model = model
         self.pruner = pruner
         self.preserve_round = preserve_round
         self.dependency_aware = dependency_aware
+        self.global_sort = global_sort

     def calc_mask(self, sparsity, wrapper, wrapper_idx=None, **depen_kwargs):
         """
@@ -60,7 +61,11 @@ def calc_mask(self, sparsity, wrapper, wrapper_idx=None, **depen_kwargs):
         depen_kwargs: dict
             The kw_args for the dependency-aware mode.
         """
-        if not self.dependency_aware:
+        if self.global_sort:
+            # if the global_sort switch is on, calculate the mask based
+            # on global model information
+            return self._global_calc_mask(sparsity, wrapper, wrapper_idx)
+        elif not self.dependency_aware:
             # calculate the mask in the normal way, each layer calculate its
             # own mask separately
             return self._normal_calc_mask(sparsity, wrapper, wrapper_idx)
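For orientation, here is a minimal usage sketch of the new switch from the user's side. The exact constructor arguments of `TaylorFOWeightFilterPruner` (trainer, criterion, etc.) and the import path depend on the NNI version and are assumed here; the only part this PR adds is the `global_sort` flag, and `model`, `optimizer`, `trainer`, `criterion` are the user's own objects:

```python
from nni.algorithms.compression.pytorch.pruning import TaylorFOWeightFilterPruner

config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]

# global_sort=True ranks filter contributions across all Conv2d layers, so the
# per-layer pruning ratio follows a single global threshold instead of applying
# the 0.5 sparsity uniformly to every layer.
pruner = TaylorFOWeightFilterPruner(model, config_list, optimizer,
                                    trainer=trainer, criterion=criterion,
                                    global_sort=True)
pruner.compress()
```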
@@ -127,6 +132,12 @@ def _get_current_state(self, sparsity, wrapper, wrapper_idx=None):
         # weight*mask_weight: apply base mask for iterative pruning
         return mask, weight * mask_weight, num_prune

+    def _global_calc_mask(self, sparsity, wrapper, wrapper_idx=None):
+        num_prune = self._get_global_num_prune(wrapper, wrapper_idx)
+        mask, weight, _ = self._get_current_state(
+            sparsity, wrapper, wrapper_idx)
+        return self.get_mask(mask, weight, num_prune, wrapper, wrapper_idx)
+
     def _normal_calc_mask(self, sparsity, wrapper, wrapper_idx=None):
         """
         Calculate the mask of given layer.
@@ -477,6 +488,31 @@ def __init__(self, model, pruner, statistics_batch_num=1):
         self.pruner.iterations = 0
         self.pruner.set_wrappers_attribute("contribution", None)
         self.pruner.patch_optimizer(self.calc_contributions)
+        self.global_threshold = None
+
+    def _get_global_threshold(self):
+        channel_contribution_list = []
+        for wrapper_idx, wrapper in enumerate(self.pruner.get_modules_wrapper()):
+            channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
+            wrapper_size = wrapper.module.weight.size().numel()
+            channel_size = wrapper.module.weight.size(0)
+            contribution_expand = channel_contribution.expand(int(wrapper_size / channel_size), channel_size).reshape(-1)
+            channel_contribution_list.append(contribution_expand)
+        all_channel_contributions = torch.cat(channel_contribution_list)
+        k = int(all_channel_contributions.shape[0] * self.pruner.config_list[0]['sparsity'])
Review thread on the lines above:
Reviewer: what if the filters' sizes are different?
Author: That's truly a key problem, I will fix it.
Author: Have fixed.
Reviewer: seems we need
Author: That's a problem. Have fixed.
+        self.global_threshold = torch.topk(
+            all_channel_contributions.view(-1), k, largest=False)[0].max()
+
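As an aside, here is a small self-contained sketch (with made-up contribution values and layer shapes) of what `_get_global_threshold` computes, and of why each filter's score is repeated once per parameter it owns. That per-parameter expansion is what keeps a single global sparsity meaningful when layers have different filter sizes, which is the concern raised in the thread above:

```python
import torch

sparsity = 0.5

# Hypothetical per-filter Taylor contributions for two layers:
# layer A has 4 filters with 6 parameters each, layer B has 2 filters with 3 parameters each.
layers = [
    (torch.tensor([0.9, 0.1, 0.5, 0.3]), 6),   # (channel_contribution, params per filter)
    (torch.tensor([0.2, 0.8]), 3),
]

expanded = []
for contribution, params_per_filter in layers:
    # repeat each filter's score once per parameter it owns, mirroring
    # channel_contribution.expand(wrapper_size // channel_size, channel_size).reshape(-1)
    expanded.append(contribution.expand(params_per_filter, contribution.numel()).reshape(-1))

all_contributions = torch.cat(expanded)        # 4*6 + 2*3 = 30 per-parameter scores
k = int(all_contributions.numel() * sparsity)  # the 15 smallest scores form the pruning budget
global_threshold = torch.topk(all_contributions, k, largest=False)[0].max()
print(global_threshold)  # tensor(0.3000): the 15 smallest scores come from filters scored 0.1, 0.2 and 0.3
```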
+    def _get_global_num_prune(self, wrapper, wrapper_idx):
+        if self.global_threshold is None:
+            self._get_global_threshold()
+        weight = wrapper.module.weight.data
+        filters = weight.size(0)
+        channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
+        num_prune = channel_contribution[channel_contribution < self.global_threshold].size()[0]
Review thread on the line above:
Reviewer: Just a dumb question, do we want < or <= here?
Author: It is a key point and I thought about it before, we can choose using '<=' which may cause
+        if num_prune == filters:
+            num_prune -= 1
+        return num_prune
+
     def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
         channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
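Continuing the toy numbers from the sketch above, the per-layer prune count just counts how many of a layer's filters fall under the global threshold, with a guard that always keeps at least one filter. Below is a simplified, self-contained equivalent of `_get_global_num_prune` (using a boolean sum instead of the indexing in the diff, and the strict `<` whose boundary is questioned in the thread above):

```python
import torch

global_threshold = torch.tensor(0.3)  # e.g. the value computed in the previous sketch

def global_num_prune(channel_contribution: torch.Tensor) -> int:
    """Count this layer's filters whose contribution falls under the global
    threshold, but never prune every filter of a layer."""
    filters = channel_contribution.numel()
    num_prune = int((channel_contribution < global_threshold).sum())
    if num_prune == filters:
        num_prune -= 1
    return num_prune

print(global_num_prune(torch.tensor([0.9, 0.1, 0.5, 0.3])))  # 1: only the 0.1 filter is pruned
print(global_num_prune(torch.tensor([0.2, 0.8])))            # 1: only the 0.2 filter is pruned
print(global_num_prune(torch.tensor([0.05, 0.05])))          # 1, not 2: the guard keeps one filter alive
```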
Reviewer: I think it may be better to implement it for Taylor alone without changing the underlying interface, because we don't have another masker that uses these interfaces, and slim naturally uses global sort.
Author: But based on our discussion before, we agreed that we should modify the condition in
Maybe we can add an assert to tell users that we only support Taylor currently?
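If that suggestion were taken, the guard might look roughly like the following hypothetical helper (not part of this PR; `_get_global_num_prune` is only implemented by the Taylor masker, so its presence is what the check keys on):

```python
def check_global_sort_supported(masker) -> None:
    # Hypothetical guard sketching the suggestion above; the real check would
    # sit at the top of StructuredWeightMasker.calc_mask when global_sort is on.
    assert hasattr(masker, '_get_global_num_prune'), \
        'global_sort is currently only supported by TaylorFOWeightFilterPruner'
```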
Reviewer: The current implementation is good for me, just a little opinion, no need to modify. And we need to update the doc for Taylor.
Author: Have updated the doc.
Reviewer: So you put global sort in StructuredWeightMasker but not in TaylorFOWeightFilterPrunerMasker. What is the reason for it? Is it because this would make TaylorFOWeightFilterPrunerMasker have different initial arguments from the other structured weight maskers?
Author: Because here we modify the conditional judgement statement that discriminates between the different situations, including global-sort and dependency-aware. The reason why we do the conditional judgement here is that we think global-sort should be at the same level as dependency-aware. This conditional judgement is done in StructuredWeightMasker, so we put global_sort here.
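Put differently, the base class owns the switch and the dispatch, and the Taylor masker only contributes the global-statistics pieces. A simplified sketch of that class relationship (not the real NNI code; the dependency-aware branch and the mask math are omitted):

```python
class StructuredWeightMasker:
    # Simplified sketch: the base class holds the global_sort switch and decides
    # which mask-calculation path to take.
    def __init__(self, model, pruner, preserve_round=1, dependency_aware=False, global_sort=False):
        self.model = model
        self.pruner = pruner
        self.preserve_round = preserve_round
        self.dependency_aware = dependency_aware
        self.global_sort = global_sort

    def calc_mask(self, sparsity, wrapper, wrapper_idx=None, **depen_kwargs):
        if self.global_sort:
            return self._global_calc_mask(sparsity, wrapper, wrapper_idx)
        elif not self.dependency_aware:
            return self._normal_calc_mask(sparsity, wrapper, wrapper_idx)
        # dependency-aware branch omitted in this sketch


class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
    # Simplified sketch: the subclass keeps its own __init__ signature untouched
    # and only supplies the global-statistics hooks used by the global path.
    def _get_global_num_prune(self, wrapper, wrapper_idx):
        ...  # count this layer's contributions that fall under the global threshold
```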