This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Constraint-aware one-shot pruners #2657

Merged
merged 111 commits into microsoft:master on Sep 21, 2020

Conversation

zheng-ningxin
Contributor

In this PR, we add three constraint-aware one-shot pruners to NNI: Constrained_L1FilterPruner, Constrained_L2FilterPruner, and ConstrainedActivationMeanRankFilterPruner.
These pruners are aware of the channel-dependency and group-dependency constraints and prune the model under those constraints, so that we can better harvest the speed benefit from model pruning. In the original version, L1FilterPruner prunes the model based only on L1 norm values, and many of the resulting pruned models violate these constraints (channel dependency / group dependency). As a result, the benefits of model pruning cannot be obtained through the speedup module.
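For readers unfamiliar with the constraint, here is a minimal sketch (not code from this PR) of the kind of structure that creates a channel dependency: the two conv outputs below are added element-wise, so if one conv prunes output channel k, the other must prune the same channel, otherwise the speedup module cannot physically remove the channels.

```python
import torch
import torch.nn as nn

class TwoBranchBlock(nn.Module):
    """Hypothetical block used only to illustrate channel dependency."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

    def forward(self, x):
        # The element-wise add requires conv1 and conv2 to keep the same set of
        # output channels -- exactly the constraint the new pruners respect.
        return self.conv1(x) + self.conv2(x)
```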

Ningxin added 12 commits July 5, 2020 07:05
@zheng-ningxin
Contributor Author

#2616

Ningxin added 4 commits July 8, 2020 03:23
Ningxin added 11 commits July 20, 2020 06:57
@QuanluZhang mentioned this pull request Sep 14, 2020
def update_mask(self):
    if not self.dependency_aware:
        # if we use the normal way to update the mask,
        # then call the updata_mask of the father class
Contributor

updata -> update

def _dependency_update_mask(self):
    """
    In the original update_mask, the wraper of each layer will update its
    mask own mask according to the sparsity specified in the config_list. However, in
Contributor

mask own mask -> own mask

    The list of the wrappers that are in the same channel dependency
    set.
wrappers_idx : list
    The list of the indexes of the wrappers.
Contributor

please also write "Returns" here in docstring

@@ -47,6 +79,9 @@ def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
layer wrapper of this layer
wrapper_idx: int
index of this wrapper in pruner's all wrappers
channel_masks: Tensor
channel_mask indicates the channels that we should at least mask.
the finnal masked channels should include these channels.
Contributor

finnal -> final

    num_prune = num_total - num_preserve

    if num_total < 2 or num_prune < 1:
        return mask
    # weight*mask_weight: apply base mask for iterative pruning
-   return self.get_mask(mask, weight*mask_weight, num_prune, wrapper, wrapper_idx)
+   return self.get_mask(mask, weight*mask_weight, num_prune, wrapper, wrapper_idx, channel_masks)
Contributor

what is the difference between mask and channel_masks?

Contributor Author

channel_masks indicates the output channels that we should at least mask in this conv layer; it actually holds the common channels that all the layers in the dependency group should prune. mask is just the mask of this conv layer. Maybe we can change the name to make this clearer?
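(For illustration only, a minimal sketch of how such a common channel mask could be folded into a layer's own mask; the helper name and the 0/1 convention are assumptions, not the PR's actual implementation.)

```python
import torch

def apply_common_channel_mask(layer_mask: torch.Tensor, channel_masks: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: force the output channels that the whole dependency
    group must prune (channel_masks == 0) to also be pruned in this layer's mask.

    layer_mask:    (out_channels, in_channels, kh, kw), 1 = keep, 0 = prune
    channel_masks: (out_channels,),                     1 = keep, 0 = prune
    """
    return layer_mask * channel_masks[:, None, None, None]
```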

Contributor

yes, you can directly remove channel_masks in this call, and also remove it from _normal_calc_mask's argument list

# find the max number of the filter groups of the dependent
# layers. The group constraint of this dependency set is decided
# by the layer with the max groups.
max_group = max(groups)
Contributor

why should we choose the maximum value of the groups?

## Evaluation
To compare the performance of the pruner with and without the dependency-aware mode, we use L1FilterPruner to prune Mobilenet_v2 with the dependency-aware mode turned on and off. To simplify the experiment, we use uniform pruning, which means we allocate the same sparsity to all convolutional layers in the model.
We trained a Mobilenet_v2 model on the CIFAR-10 dataset and pruned the model based on this pretrained checkpoint. The following figure shows the accuracy and FLOPs of the models pruned by the different pruners.
![](../../img/mobilev2_l1_cifar.jpg)
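(A minimal sketch of the uniform-sparsity setup described above; the sparsity value of 0.5 and the surrounding `model` / `dummy_input` variables are assumptions, not taken from the PR.)

```python
from nni.compression.torch import L1FilterPruner

# uniform pruning: the same sparsity is assigned to every convolutional layer
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]

# baseline: dependency-aware mode off
pruner = L1FilterPruner(model, config_list)
# dependency-aware mode on (dummy_input is required in this mode):
# pruner = L1FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
```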
Contributor

this figure looks great!

Comment on lines 30 to 100
```python
from nni.compression.torch import L1FilterPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
# dummy_input is necessary for the dependency_aware mode
dummy_input = torch.ones(1, 3, 224, 224).cuda()
pruner = L1FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
```

To enable the dependency-aware mode for `L2FilterPruner`:
```python
from nni.compression.torch import L2FilterPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
# dummy_input is necessary for the dependency_aware mode
dummy_input = torch.ones(1, 3, 224, 224).cuda()
pruner = L2FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
```

To enable the dependency-aware mode for `FPGMPruner`:
```python
from nni.compression.torch import FPGMPruner
config_list = [{
'sparsity': 0.5,
'op_types': ['Conv2d']
}]
# dummy_input is necessary for the dependency_aware mode
dummy_input = torch.ones(1, 3, 224, 224).cuda()
pruner = FPGMPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
```

To enable the dependency-aware mode for `ActivationAPoZRankFilterPruner`:
```python
from nni.compression.torch import ActivationAPoZRankFilterPruner
config_list = [{
'sparsity': 0.5,
'op_types': ['Conv2d']
}]
# dummy_input is necessary for the dependency_aware mode
dummy_input = torch.ones(1, 3, 224, 224).cuda()
pruner = ActivationAPoZRankFilterPruner(model, config_list, statistics_batch_num=1, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
```

To enable the dependency-aware mode for `ActivationMeanRankFilterPruner`:

```python
from nni.compression.torch import ActivationMeanRankFilterPruner
config_list = [{
'sparsity': 0.5,
'op_types': ['Conv2d']
}]
# dummy_input is necessary for the dependency-aware mode and
# should be on the same device as the model
dummy_input = torch.ones(1, 3, 224, 224).cuda()
pruner = ActivationMeanRankFilterPruner(model, config_list, statistics_batch_num=1, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
```

To enable the dependency-aware mode for `TaylorFOWeightFilterPruner`:
```python
from nni.compression.torch import TaylorFOWeightFilterPruner
config_list = [{
'sparsity': 0.5,
'op_types': ['Conv2d']
}]
dummy_input = torch.ones(1, 3, 224, 224).cuda()
pruner = TaylorFOWeightFilterPruner(model, config_list, statistics_batch_num=1, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
```
Contributor

we can simplify this example code

Contributor Author

Sure, how about we just keep the example of L1FilterPruner and remove the others?

Contributor

good point. Above the line `pruner = L1FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)` you can use comments to show the usage of the other pruners, one comment line for each pruner
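A sketch of what the simplified documentation example could look like after this suggestion (hypothetical, assembled from the existing snippets above; not the final merged text):

```python
from nni.compression.torch import L1FilterPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
# dummy_input is necessary for the dependency_aware mode and should be
# on the same device as the model
dummy_input = torch.ones(1, 3, 224, 224).cuda()
# The other filter pruners are enabled the same way, for example:
# pruner = L2FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
# pruner = FPGMPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
# pruner = ActivationMeanRankFilterPruner(model, config_list, statistics_batch_num=1,
#                                         dependency_aware=True, dummy_input=dummy_input)
pruner = L1FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
```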

threshold = torch.topk(w_abs_structured.view(-1), num_prune, largest=False)[0].max()
mask_weight = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
mask_bias = torch.gt(w_abs_structured, threshold).type_as(weight).detach() if base_mask['bias_mask'] is not None else None
return w_abs_structured
Contributor

here we only support output channel, right?

Contributor Author

Yes, the L1FilterPruner only prunes the filters (output channel).
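(A minimal sketch, not the PR's code, of what pruning filters by their L1 norm along the output-channel dimension means:)

```python
import torch

# Conv2d weight layout: (out_channels, in_channels, kh, kw)
weight = torch.randn(32, 16, 3, 3)
# L1 norm of each filter, i.e. of each output channel
filter_l1_norm = weight.abs().view(weight.size(0), -1).sum(dim=1)
# the filters with the smallest norms are the pruning candidates
prune_idx = torch.topk(filter_l1_norm, k=8, largest=False).indices
```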

@QuanluZhang
Contributor

QuanluZhang commented Sep 18, 2020

@chicm-ms , I didn't fully check the modifications of XXXPrunerMasker, as I know little about them.

return None
mean_activation = self._cal_mean_activation(activations)
if channel_masks is not None:
Contributor

Let's remove this special handling for mean activation as discussed.

Contributor Author

Updated~

@chicm-ms chicm-ms merged commit ec5af41 into microsoft:master Sep 21, 2020
zheng-ningxin added a commit to zheng-ningxin/nni that referenced this pull request Nov 18, 2020