Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Fix pruners #2153

Merged
merged 48 commits into from
Mar 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
3a45961
Merge pull request #31 from microsoft/master
chicm-ms Aug 6, 2019
633db43
Merge pull request #32 from microsoft/master
chicm-ms Sep 9, 2019
3e926f1
Merge pull request #33 from microsoft/master
chicm-ms Oct 8, 2019
f173789
Merge pull request #34 from microsoft/master
chicm-ms Oct 9, 2019
508850a
Merge pull request #35 from microsoft/master
chicm-ms Oct 9, 2019
5a0e9c9
Merge pull request #36 from microsoft/master
chicm-ms Oct 10, 2019
e7df061
Merge pull request #37 from microsoft/master
chicm-ms Oct 23, 2019
2175cef
Merge pull request #38 from microsoft/master
chicm-ms Oct 29, 2019
2ccbfbb
Merge pull request #39 from microsoft/master
chicm-ms Oct 30, 2019
b29cb0b
Merge pull request #40 from microsoft/master
chicm-ms Oct 30, 2019
4a3ba83
Merge pull request #41 from microsoft/master
chicm-ms Nov 4, 2019
c8a1148
Merge pull request #42 from microsoft/master
chicm-ms Nov 4, 2019
73c6101
Merge pull request #43 from microsoft/master
chicm-ms Nov 5, 2019
6a518a9
Merge pull request #44 from microsoft/master
chicm-ms Nov 11, 2019
a0d587f
Merge pull request #45 from microsoft/master
chicm-ms Nov 12, 2019
e905bfe
Merge pull request #46 from microsoft/master
chicm-ms Nov 14, 2019
4b266f3
Merge pull request #47 from microsoft/master
chicm-ms Nov 15, 2019
237ff4b
Merge pull request #48 from microsoft/master
chicm-ms Nov 21, 2019
682be01
Merge pull request #49 from microsoft/master
chicm-ms Nov 25, 2019
133af82
Merge pull request #50 from microsoft/master
chicm-ms Nov 25, 2019
71a8a25
Merge pull request #51 from microsoft/master
chicm-ms Nov 26, 2019
d2a73bc
Merge pull request #52 from microsoft/master
chicm-ms Nov 26, 2019
198cf5e
Merge pull request #53 from microsoft/master
chicm-ms Dec 5, 2019
cdbfaf9
Merge pull request #54 from microsoft/master
chicm-ms Dec 6, 2019
7e9b29e
Merge pull request #55 from microsoft/master
chicm-ms Dec 10, 2019
d00c46d
Merge pull request #56 from microsoft/master
chicm-ms Dec 10, 2019
de7d1fa
Merge pull request #57 from microsoft/master
chicm-ms Dec 11, 2019
1835ab0
Merge pull request #58 from microsoft/master
chicm-ms Dec 12, 2019
24fead6
Merge pull request #59 from microsoft/master
chicm-ms Dec 20, 2019
0b7321e
Merge pull request #60 from microsoft/master
chicm-ms Dec 23, 2019
60058d4
Merge pull request #61 from microsoft/master
chicm-ms Dec 23, 2019
b111a55
Merge pull request #62 from microsoft/master
chicm-ms Dec 24, 2019
611c337
Merge pull request #63 from microsoft/master
chicm-ms Dec 30, 2019
4a1f14a
Merge pull request #64 from microsoft/master
chicm-ms Jan 10, 2020
7a9e604
Merge pull request #65 from microsoft/master
chicm-ms Jan 14, 2020
b8035b0
Merge pull request #66 from microsoft/master
chicm-ms Feb 4, 2020
47567d3
Merge pull request #67 from microsoft/master
chicm-ms Feb 10, 2020
614d427
Merge pull request #68 from microsoft/master
chicm-ms Feb 10, 2020
a0d9ed6
Merge pull request #69 from microsoft/master
chicm-ms Feb 11, 2020
22dc1ad
Merge pull request #70 from microsoft/master
chicm-ms Feb 19, 2020
0856813
Merge pull request #71 from microsoft/master
chicm-ms Feb 22, 2020
9e97bed
Merge pull request #72 from microsoft/master
chicm-ms Feb 25, 2020
16a1b27
Merge pull request #73 from microsoft/master
chicm-ms Mar 3, 2020
e246633
Merge pull request #74 from microsoft/master
chicm-ms Mar 4, 2020
477520b
Add pruner UT
chicm-ms Mar 16, 2020
78d041f
updates
chicm-ms Mar 16, 2020
f615a06
updates
chicm-ms Mar 16, 2020
9300d4c
updates
chicm-ms Mar 16, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ def get_mask(self, base_mask, weight, num_prune):
w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
threshold = torch.topk(w_abs_structured.view(-1), num_prune, largest=False)[0].max()
mask_weight = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
mask_bias = torch.gt(w_abs_structured, threshold).type_as(weight)
mask_bias = torch.gt(w_abs_structured, threshold).type_as(weight).detach() if base_mask['bias_mask'] is not None else None

return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias.detach()}
return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}


class L2FilterPruner(WeightRankFilterPruner):
Expand Down Expand Up @@ -166,9 +166,9 @@ def get_mask(self, base_mask, weight, num_prune):
w_l2_norm = torch.sqrt((w ** 2).sum(dim=1))
threshold = torch.topk(w_l2_norm.view(-1), num_prune, largest=False)[0].max()
mask_weight = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
mask_bias = torch.gt(w_l2_norm, threshold).type_as(weight)
mask_bias = torch.gt(w_l2_norm, threshold).type_as(weight).detach() if base_mask['bias_mask'] is not None else None

return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias.detach()}
return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}


class FPGMPruner(WeightRankFilterPruner):
Expand Down
162 changes: 162 additions & 0 deletions src/sdk/pynni/tests/test_pruners.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from unittest import TestCase, main
from nni.compression.torch import LevelPruner, SlimPruner, FPGMPruner, L1FilterPruner, \
L2FilterPruner, AGP_Pruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner

def validate_sparsity(wrapper, sparsity, bias=False):
masks = [wrapper.weight_mask]
if bias and wrapper.bias_mask is not None:
masks.append(wrapper.bias_mask)
for m in masks:
actual_sparsity = (m == 0).sum().item() / m.numel()
msg = 'actual sparsity: {:.2f}, target sparsity: {:.2f}'.format(actual_sparsity, sparsity)
assert math.isclose(actual_sparsity, sparsity, abs_tol=0.1), msg

prune_config = {
'level': {
'pruner_class': LevelPruner,
'config_list': [{
'sparsity': 0.5,
'op_types': ['default'],
}],
'validators': [
lambda model: validate_sparsity(model.conv1, 0.5, False),
lambda model: validate_sparsity(model.fc, 0.5, False)
]
},
'agp': {
'pruner_class': AGP_Pruner,
'config_list': [{
'initial_sparsity': 0,
'final_sparsity': 0.8,
'start_epoch': 0,
'end_epoch': 10,
'frequency': 1,
'op_types': ['default']
}],
'validators': []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, for iterative pruner, we should think about how to test them.

},
'slim': {
'pruner_class': SlimPruner,
'config_list': [{
'sparsity': 0.7,
'op_types': ['BatchNorm2d']
}],
'validators': [
lambda model: validate_sparsity(model.bn1, 0.7, model.bias)
]
},
'fpgm': {
'pruner_class': FPGMPruner,
'config_list':[{
'sparsity': 0.5,
'op_types': ['Conv2d']
}],
'validators': [
lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
]
},
'l1': {
'pruner_class': L1FilterPruner,
'config_list': [{
'sparsity': 0.5,
'op_types': ['Conv2d'],
}],
'validators': [
lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
]
},
'l2': {
'pruner_class': L2FilterPruner,
'config_list': [{
'sparsity': 0.5,
'op_types': ['Conv2d'],
}],
'validators': [
lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
]
},
'mean_activation': {
'pruner_class': ActivationMeanRankFilterPruner,
'config_list': [{
'sparsity': 0.5,
'op_types': ['Conv2d'],
}],
'validators': [
lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
]
},
'apoz': {
'pruner_class': ActivationAPoZRankFilterPruner,
'config_list': [{
'sparsity': 0.5,
'op_types': ['Conv2d'],
}],
'validators': [
lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
]
}
}

class Model(nn.Module):
def __init__(self, bias=True):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1, bias=bias)
self.bn1 = nn.BatchNorm2d(8)
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(8, 2, bias=bias)
self.bias = bias
def forward(self, x):
return self.fc(self.pool(self.bn1(self.conv1(x))).view(x.size(0), -1))

def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'mean_activation', 'apoz'], bias=True):
for pruner_name in pruner_names:
print('testing {}...'.format(pruner_name))
model = Model(bias=bias)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
config_list = prune_config[pruner_name]['config_list']

x = torch.randn(2, 1, 28, 28)
y = torch.tensor([0, 1]).long()
out = model(x)
loss = F.cross_entropy(out, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

pruner = prune_config[pruner_name]['pruner_class'](model, config_list, optimizer)
pruner.compress()

x = torch.randn(2, 1, 28, 28)
y = torch.tensor([0, 1]).long()
out = model(x)
loss = F.cross_entropy(out, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

pruner.export_model('./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', input_shape=(2,1,28,28))

for v in prune_config[pruner_name]['validators']:
v(model)

os.remove('./model_tmp.pth')
os.remove('./mask_tmp.pth')
os.remove('./onnx_tmp.pth')

class PrunerTestCase(TestCase):
def test_pruners(self):
pruners_test(bias=True)

def test_pruners_no_bias(self):
pruners_test(bias=False)

if __name__ == '__main__':
main()