[Model Compression] Add global sort for taylor pruner #3896
Changes from 6 commits.
First changed file:

```diff
@@ -22,7 +22,7 @@ class IterativePruner(DependencyAwarePruner):
     """

     def __init__(self, model, config_list, optimizer=None, pruning_algorithm='slim', trainer=None, criterion=None,
-                 num_iterations=20, epochs_per_iteration=5, dependency_aware=False, dummy_input=None, **algo_kwargs):
+                 num_iterations=20, epochs_per_iteration=5, dependency_aware=False, dummy_input=None, global_sort=False, **algo_kwargs):
         """
         Parameters
         ----------
```
```diff
@@ -51,6 +51,9 @@ def __init__(self, model, config_list, optimizer=None, pruning_algorithm='slim',
         dummy_input: torch.Tensor
             The dummy input to analyze the topology constraints. Note that,
             the dummy_input should on the same device with the model.
+        global_sort: bool
+            Whether to prune the model in a global-sort way.
+            Only supported by TaylorFOWeightFilterPruner currently.
         algo_kwargs: dict
             Additional parameters passed to pruning algorithm masker class
         """
```

Review comment on the new docstring lines:
Reviewer: if only ...
Author: Done.
```diff
@@ -486,10 +489,15 @@ class TaylorFOWeightFilterPruner(IterativePruner):
         dummy_input : torch.Tensor
             The dummy input to analyze the topology constraints. Note that, the dummy_input
             should on the same device with the model.
+        global_sort: bool
+            Whether to prune the model in a global-sort way (currently only supported by
+            TaylorFOWeightFilterPruner). If `True`, channel contributions are sorted globally
+            across all layers, and whether a specific channel is pruned depends on its global
+            ranking rather than on per-layer information.
         """

     def __init__(self, model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1,
-                 dependency_aware=False, dummy_input=None):
+                 dependency_aware=False, dummy_input=None, global_sort=False):
         super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='taylorfo', trainer=trainer,
                          criterion=criterion, statistics_batch_num=sparsifying_training_batches, num_iterations=1,
                          epochs_per_iteration=1, dependency_aware=dependency_aware,
```
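For orientation, here is a minimal usage sketch of the new flag. The import path, the trainer signature, and the toy model are assumptions based on the NNI 2.x API this PR appears to target, not part of the diff itself:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
# Assumed import path for the NNI 2.x release this PR targets.
from nni.algorithms.compression.pytorch.pruning import TaylorFOWeightFilterPruner

class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.fc = nn.Linear(32, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)
        return self.fc(x)

model = ToyNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

def trainer(model, optimizer, criterion, epoch):
    # Synthetic batch for illustration; use a real data loader in practice.
    data = torch.randn(8, 3, 32, 32)
    target = torch.randint(0, 10, (8,))
    optimizer.zero_grad()
    criterion(model(data), target).backward()
    optimizer.step()

# One sparsity target applied to the global pool of conv channels.
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]

# global_sort=True ranks Taylor channel contributions across all pruned
# layers instead of applying the sparsity layer by layer.
pruner = TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer, criterion,
                                    sparsifying_training_batches=1, global_sort=True)
pruner.compress()
```

Note that the global budget below is taken from `config_list[0]['sparsity']`, so a single sparsity entry appears to be the intended configuration for this mode.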
Second changed file:

```diff
@@ -33,11 +33,12 @@ class StructuredWeightMasker(WeightMasker):
     """

-    def __init__(self, model, pruner, preserve_round=1, dependency_aware=False):
+    def __init__(self, model, pruner, preserve_round=1, dependency_aware=False, global_sort=False):
         self.model = model
         self.pruner = pruner
         self.preserve_round = preserve_round
         self.dependency_aware = dependency_aware
+        self.global_sort = global_sort

     def calc_mask(self, sparsity, wrapper, wrapper_idx=None, **depen_kwargs):
         """
```

Review thread on the new `global_sort` parameter:
Reviewer: I think it may be better to implement it for Taylor alone without changing the underlying interface, because we don't have another masker that uses these interfaces, and ...
Author: But based on our discussion before, we agreed that we should modify the condition in ... Maybe we can add ...
Reviewer: The current implementation is good for me, just a little opinion, no need to modify. And we need to update the doc for Taylor.
Author: Have updated the doc.
Reviewer: So you put global sort in ...?
Author: Because here we modify the conditional judgement that distinguishes the different situations, including global-sort and dependency-aware. We do the conditional judgement here because we think global-sort should be at the same level as dependency-aware. This conditional judgement is done in ...
```diff
@@ -60,7 +61,11 @@ def calc_mask(self, sparsity, wrapper, wrapper_idx=None, **depen_kwargs):
         depen_kwargs: dict
             The kw_args for the dependency-aware mode.
         """
-        if not self.dependency_aware:
+        if self.global_sort:
+            # if the global_sort switch is on, calculate the mask based
+            # on global model information
+            return self._global_calc_mask(sparsity, wrapper, wrapper_idx)
+        elif not self.dependency_aware:
             # calculate the mask in the normal way, each layer calculate its
             # own mask separately
             return self._normal_calc_mask(sparsity, wrapper, wrapper_idx)
```
```diff
@@ -127,6 +132,12 @@ def _get_current_state(self, sparsity, wrapper, wrapper_idx=None):
         # weight*mask_weight: apply base mask for iterative pruning
         return mask, weight * mask_weight, num_prune

+    def _global_calc_mask(self, sparsity, wrapper, wrapper_idx=None):
+        num_prune = self._get_global_num_prune(wrapper, wrapper_idx)
+        mask, weight, _ = self._get_current_state(
+            sparsity, wrapper, wrapper_idx)
+        return self.get_mask(mask, weight, num_prune, wrapper, wrapper_idx)
+
     def _normal_calc_mask(self, sparsity, wrapper, wrapper_idx=None):
         """
         Calculate the mask of given layer.
```
```diff
@@ -477,6 +488,29 @@ def __init__(self, model, pruner, statistics_batch_num=1):
         self.pruner.iterations = 0
         self.pruner.set_wrappers_attribute("contribution", None)
         self.pruner.patch_optimizer(self.calc_contributions)
+        self.global_threshold = None
+
+    def _get_global_threshold(self):
+        channel_contribution_list = []
+        for wrapper_idx, wrapper in enumerate(self.pruner.get_modules_wrapper()):
+            channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
+            channel_contribution_list.append(channel_contribution)
+        all_channel_contributions = torch.cat(channel_contribution_list)
+        k = int(all_channel_contributions.shape[0] * self.pruner.config_list[0]['sparsity'])
+        self.global_threshold = torch.topk(
+            all_channel_contributions.view(-1), k, largest=False)[0].max()
+        print(f'set global threshold to {self.global_threshold}')
+
+    def _get_global_num_prune(self, wrapper, wrapper_idx):
+        if self.global_threshold is None:
+            self._get_global_threshold()
+        weight = wrapper.module.weight.data
+        filters = weight.size(0)
+        channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
+        num_prune = channel_contribution[channel_contribution < self.global_threshold].size()[0]
+        if num_prune == filters:
+            num_prune -= 1
+        return num_prune
+
     def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
         channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
```

Review threads on this hunk:

On the `k = int(all_channel_contributions.shape[0] * ...)` line:
Reviewer: What if the filters' sizes are different?
Author: That's truly a key problem, I will fix it.
Author: Have fixed.
Reviewer: Seems we need ...
Author: That's a problem. Have fixed.

On the `print(...)` line:
Reviewer: Better to remove this.
Author: Have removed.

On the `num_prune = channel_contribution[...]` line:
Reviewer: Just a dumb question, do we want `<` or `<=` here?
Author: It is a key point and I thought about it before; we could choose `<=`, which may cause ...
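To make the thresholding concrete, including the `<` versus `<=` question above, here is a small self-contained sketch of the same selection logic on made-up per-layer contribution scores (the layer names and values are purely illustrative; it mirrors `_get_global_threshold` and `_get_global_num_prune` from the diff):

```python
import torch

# Made-up first-order Taylor contributions, one score per output channel.
contributions = {
    'conv1': torch.tensor([0.90, 0.10, 0.40, 0.05]),            # 4 filters
    'conv2': torch.tensor([0.20, 0.80, 0.02, 0.60, 0.30, 0.70])  # 6 filters
}
sparsity = 0.5  # prune half of all channels, counted globally

# Global threshold: the largest of the k globally smallest scores.
all_scores = torch.cat(list(contributions.values()))
k = int(all_scores.numel() * sparsity)
global_threshold = torch.topk(all_scores, k, largest=False)[0].max()

# Per-layer prune counts derived from the shared threshold.
for name, scores in contributions.items():
    num_prune = int((scores < global_threshold).sum())
    # Keep at least one filter per layer, as the diff does.
    if num_prune == scores.numel():
        num_prune -= 1
    print(f'{name}: prune {num_prune} of {scores.numel()} filters')
```

With the strict `<`, channels whose score ties the threshold survive, so the global count can fall slightly short of `k` (here 4 of the targeted 5). `<=` would hit the budget exactly, but either comparison can still mark every filter of a layer, which is what the `num_prune == filters` guard protects against.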
Review thread on the `global_sort` parameter:
Reviewer: Do all the dependency pruners support `global_sort` mode? If not, I'm a little concerned.
Reviewer: Agree, how many pruners support `global_sort`?
Author: Not all of them support `global_sort`; have deleted the parameters in them.