Signed-off-by: Ningxin <Ningxin.Zheng@microsoft.com>
def update_mask(self):
    if not self.dependency_aware:
        # if we use the normal way to update the mask,
        # then call the updata_mask of the father class
updata -> update
def _dependency_update_mask(self):
    """
    In the original update_mask, the wraper of each layer will update its
    mask own mask according to the sparsity specified in the config_list. However, in
mask own mask -> own mask
    The list of the wrappers that are in the same channel dependency
    set.
wrappers_idx : list
    The list of the indexes of the wrappers.
please also write "Returns" here in docstring
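For illustration, a hedged sketch of the requested section (the function name `_dependency_calc_mask` and the return type are assumptions, not taken from the diff):

```python
def _dependency_calc_mask(self, wrappers, wrappers_idx=None):
    """
    Calculate the masks for all the layers in the same channel dependency set.

    Parameters
    ----------
    wrappers : list
        The list of the wrappers that are in the same channel dependency set.
    wrappers_idx : list
        The list of the indexes of the wrappers.

    Returns
    -------
    masks : dict
        Assumed return value: the calculated masks of the wrappers in this
        dependency set, keyed by the name of each wrapped layer.
    """
```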
@@ -47,6 +79,9 @@ def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
        layer wrapper of this layer
    wrapper_idx: int
        index of this wrapper in pruner's all wrappers
    channel_masks: Tensor
        channel_mask indicates the channels that we should at least mask.
        the finnal masked channels should include these channels.
finnal -> final
num_prune = num_total - num_preserve

if num_total < 2 or num_prune < 1:
    return mask
# weight*mask_weight: apply base mask for iterative pruning
- return self.get_mask(mask, weight*mask_weight, num_prune, wrapper, wrapper_idx)
+ return self.get_mask(mask, weight*mask_weight, num_prune, wrapper, wrapper_idx, channel_masks)
what is the difference between `mask` and `channel_masks`?
`channel_masks` indicates the output channels that we should at least mask in this conv layer. `channel_masks` are actually the common channels that all the layers in the dependency group should prune. `mask` is just the mask of this conv layer; maybe we can change the name to make this clearer?
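A minimal sketch of how the two interact, assuming `channel_masks` is a 1-D 0/1 tensor over the output channels (the helper name `apply_channel_masks` is illustrative, not the actual NNI code):

```python
import torch

def apply_channel_masks(layer_mask, channel_masks):
    # layer_mask: per-layer weight mask, shape [out_channels, in_channels, k, k]
    # channel_masks: 1-D 0/1 tensor of shape [out_channels]; 0 marks the output
    # channels that every layer in the dependency group has agreed to prune
    if channel_masks is None:
        return layer_mask
    # broadcast the group-level decision over the remaining dimensions so that a
    # channel pruned by the group is also pruned in this layer's own mask
    common = channel_masks[:, None, None, None].type_as(layer_mask)
    return layer_mask * common
```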
yes, you can directly remove `channel_masks` in this call, and also remove it from `_normal_calc_mask`'s argument list
# find the max number of the filter groups of the dependent
# layers. The group constraint of this dependency set is decided
# by the layer with the max groups.
max_group = max(groups)
why should we choose the maximum value of the groups?
## Evaluation
In order to compare the performance of the pruner with and without the dependency-aware mode, we use the L1FilterPruner to prune Mobilenet_v2 separately with the dependency-aware mode turned on and off. To simplify the experiment, we use uniform pruning, which means we allocate the same sparsity to all convolutional layers in the model.
We trained a Mobilenet_v2 model on the cifar10 dataset and pruned the model based on this pretrained checkpoint. The following figure shows the accuracy and FLOPs of the model pruned by the different pruners.
![](../../img/mobilev2_l1_cifar.jpg)
this figure looks great!
```python
from nni.compression.torch import L1FilterPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
# dummy_input is necessary for the dependency_aware mode
dummy_input = torch.ones(1, 3, 224, 224).cuda()
pruner = L1FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
```

To enable the dependency-aware mode for `L2FilterPruner`:
```python
from nni.compression.torch import L2FilterPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
# dummy_input is necessary for the dependency_aware mode
dummy_input = torch.ones(1, 3, 224, 224).cuda()
pruner = L2FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
```

To enable the dependency-aware mode for `FPGMPruner`:
```python
from nni.compression.torch import FPGMPruner
config_list = [{
    'sparsity': 0.5,
    'op_types': ['Conv2d']
}]
# dummy_input is necessary for the dependency_aware mode
dummy_input = torch.ones(1, 3, 224, 224).cuda()
pruner = FPGMPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
```

To enable the dependency-aware mode for `ActivationAPoZRankFilterPruner`:
```python
from nni.compression.torch import ActivationAPoZRankFilterPruner
config_list = [{
    'sparsity': 0.5,
    'op_types': ['Conv2d']
}]
# dummy_input is necessary for the dependency_aware mode
dummy_input = torch.ones(1, 3, 224, 224).cuda()
pruner = ActivationAPoZRankFilterPruner(model, config_list, statistics_batch_num=1, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
```

To enable the dependency-aware mode for `ActivationMeanRankFilterPruner`:
```python
from nni.compression.torch import ActivationMeanRankFilterPruner
config_list = [{
    'sparsity': 0.5,
    'op_types': ['Conv2d']
}]
# dummy_input is necessary for the dependency-aware mode and the
# dummy_input should be on the same device with the model
dummy_input = torch.ones(1, 3, 224, 224).cuda()
pruner = ActivationMeanRankFilterPruner(model, config_list, statistics_batch_num=1, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
```

To enable the dependency-aware mode for `TaylorFOWeightFilterPruner`:
```python
from nni.compression.torch import TaylorFOWeightFilterPruner
config_list = [{
    'sparsity': 0.5,
    'op_types': ['Conv2d']
}]
dummy_input = torch.ones(1, 3, 224, 224).cuda()
pruner = TaylorFOWeightFilterPruner(model, config_list, statistics_batch_num=1, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
```
we can simplify this example code
Sure, how about we just keep the example of `L1FilterPruner` and remove the others?
good point, and above this line `pruner = L1FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)` you can use comments to show the usage of other pruners, one comment line for each pruner
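Something along those lines might look like the following (a sketch of the suggestion, reusing only the constructors that already appear in the examples above, not the final doc text):

```python
from nni.compression.torch import L1FilterPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
# dummy_input is necessary for the dependency_aware mode
dummy_input = torch.ones(1, 3, 224, 224).cuda()
# The usage is the same for the other dependency-aware pruners, for example:
# pruner = L2FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
# pruner = FPGMPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
# pruner = ActivationAPoZRankFilterPruner(model, config_list, statistics_batch_num=1, dependency_aware=True, dummy_input=dummy_input)
# pruner = ActivationMeanRankFilterPruner(model, config_list, statistics_batch_num=1, dependency_aware=True, dummy_input=dummy_input)
# pruner = TaylorFOWeightFilterPruner(model, config_list, statistics_batch_num=1, dependency_aware=True, dummy_input=dummy_input)
pruner = L1FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
```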
threshold = torch.topk(w_abs_structured.view(-1), num_prune, largest=False)[0].max()
mask_weight = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
mask_bias = torch.gt(w_abs_structured, threshold).type_as(weight).detach() if base_mask['bias_mask'] is not None else None
return w_abs_structured
here we only support output channel, right?
Yes, the `L1FilterPruner` only prunes the filters (output channels).
@chicm-ms , I didn't fully check the modifications of
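For context, a minimal sketch of the standard per-filter L1 statistic that `w_abs_structured` presumably holds (an illustration of the L1-filter criterion, not the exact NNI implementation):

```python
import torch

def filter_l1_norms(weight):
    # weight: conv2d weight of shape [out_channels, in_channels, k, k]
    # L1 norm of each output filter -> 1-D tensor of shape [out_channels]
    return weight.abs().view(weight.size(0), -1).sum(dim=1)
```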
return None
mean_activation = self._cal_mean_activation(activations)
if channel_masks is not None:
Let's remove this special handling for mean activation as discussed.
Updated~
In this PR, we add three constraint-aware one-shot pruners to NNI: Constrained_L1FilterPruner, Constrained_L2FilterPruner, and ConstrainedActivationMeanRankFilterPruner.
These constraint-aware pruners are aware of the channel-dependency and group-dependency constraints and prune the model under these constraints, so that we can better harvest the speed benefit from model pruning. In the original version, the L1FilterPruner prunes the model based only on the L1 norm values, and many pruned models violate the aforementioned constraints (channel dependency / group dependency). Therefore, the benefits of model pruning cannot be obtained through the speedup module.
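As an illustration of the channel-dependency constraint mentioned above (an assumed toy example, not code from this PR): two convolution layers whose outputs are element-wise added must keep exactly the same output channels, otherwise the add can no longer be executed after speedup.

```python
import torch.nn as nn

class ToyAddBlock(nn.Module):
    # conv1 and conv2 feed the same element-wise add, so they form one channel
    # dependency set: a dependency-aware pruner must prune the same output
    # channels in both layers for the pruned model to remain valid.
    def __init__(self, channels=16):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv1(x) + self.conv2(x)
```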