
add pruner unit test #1771

Merged 3 commits on Nov 25, 2019
2 changes: 1 addition & 1 deletion docs/en_US/Compressor/SlimPruner.md
@@ -34,6 +34,6 @@ We implemented one of the experiments in ['Learning Efficient Convolutional Netw
| Model | Error(paper/ours) | Parameters | Pruned |
| ------------- | ----------------- | ---------- | --------- |
| VGGNet | 6.34/6.40 | 20.04M | |
- | Pruned-VGGNet | 6.20/6.39 | 2.03M | 88.5% |
+ | Pruned-VGGNet | 6.20/6.26 | 2.03M | 88.5% |
Contributor: Could you provide the settings used in these experiments, e.g. the hyper-parameters?


The experiments code can be found at [examples/model_compress]( https://github.com/microsoft/nni/tree/master/examples/model_compress/)
2 changes: 1 addition & 1 deletion examples/model_compress/slim_pruner_torch_vgg19.py
@@ -169,7 +169,7 @@ def main():
new_model.to(device)
new_model.load_state_dict(torch.load('pruned_vgg19_cifar10.pth'))
test(new_model, device, test_loader)
- # top1 = 93.61%
+ # top1 = 93.74%
Contributor: Will this value change with the random seed? If so, we should not assume top1 takes a specific value.
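A sketch (not part of this PR) of the two usual fixes: seed everything up front so the run is reproducible, or assert within a tolerance band instead of pinning an exact value. `EXPECTED_TOP1` and `TOLERANCE` below are illustrative numbers, not measured results.

```python
import random

import numpy as np
import torch

def set_seed(seed=0):
    # Pin all RNGs the training loop touches so top1 is reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True  # trades speed for determinism
    torch.backends.cudnn.benchmark = False

# Alternatively, accept any accuracy in a band around the expected value.
EXPECTED_TOP1, TOLERANCE = 93.7, 0.5  # illustrative, not from this PR

def check_top1(top1):
    assert abs(top1 - EXPECTED_TOP1) <= TOLERANCE, \
        "top1 %.2f%% outside %.1f +/- %.1f%%" % (top1, EXPECTED_TOP1, TOLERANCE)
```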



if __name__ == '__main__':
10 changes: 5 additions & 5 deletions src/sdk/pynni/nni/compression/torch/builtin_pruners.py
@@ -47,7 +47,7 @@ def calc_mask(self, layer, config):
k = int(weight.numel() * config['sparsity'])
if k == 0:
return torch.ones(weight.shape).type_as(weight)
- threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
+ threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_dict.update({op_name: mask})
self.if_init_list.update({op_name: False})
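For context on the `.values` to `[0]` change: on recent PyTorch, `torch.topk` returns a `(values, indices)` namedtuple, so the two spellings select the same tensor, while `[0]` also works on older releases where the result was a plain tuple. A minimal standalone sketch with illustrative values:

```python
import torch

w_abs = torch.tensor([0.3, 0.1, 0.4, 0.2])
k = 2  # prune the two smallest magnitudes

res = torch.topk(w_abs.view(-1), k, largest=False)
assert torch.equal(res[0], res.values)  # same tensor on recent PyTorch

threshold = res[0].max()           # largest of the k smallest, i.e. 0.2
mask = torch.gt(w_abs, threshold)  # keeps only 0.3 and 0.4
```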
@@ -108,7 +108,7 @@ def calc_mask(self, layer, config):
return mask
# if we want to generate a new mask, we should update the weight first
w_abs = weight.abs() * mask
- threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
+ threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_dict.update({op_name: new_mask})
self.if_init_list.update({op_name: False})
@@ -336,7 +336,7 @@ def calc_mask(self, layer, config):
if k == 0:
return torch.ones(weight.shape).type_as(weight)
w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
- threshold = torch.topk(w_abs_structured.view(-1), k, largest=False).values.max()
+ threshold = torch.topk(w_abs_structured.view(-1), k, largest=False)[0].max()
mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
finally:
self.mask_dict.update({layer.name: mask})
@@ -370,10 +370,10 @@ def __init__(self, model, config_list):
config = config_list[0]
for (layer, config) in self.detect_modules_to_compress():
assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
- weight_list.append(layer.module.weight.data.clone())
+ weight_list.append(layer.module.weight.data.abs().clone())
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
- self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False).values.max()
+ self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()

def calc_mask(self, layer, config):
"""
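Why the `.abs()` added in `SlimPruner.__init__` matters: BN scale factors can be negative, and Network Slimming ranks channels by |gamma|. A minimal standalone sketch, with values mirroring the new unit test below:

```python
import torch

# Scale factors for two BN layers; the second is negated, as in the test.
bn1_gamma = torch.tensor([0., 1., 2., 3., 4.])
bn2_gamma = torch.tensor([-0., -1., -2., -3., -4.])

sparsity = 0.2
all_bn_weights = torch.cat([bn1_gamma.abs(), bn2_gamma.abs()])
k = int(all_bn_weights.shape[0] * sparsity)  # the 2 smallest |gamma| go
global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()

# Without .abs(), bn2's channels would all fall below the threshold and be
# pruned wholesale, despite matching bn1's magnitudes.
mask1 = torch.gt(bn1_gamma.abs(), global_threshold)  # [0, 1, 1, 1, 1]
mask2 = torch.gt(bn2_gamma.abs(), global_threshold)  # [0, 1, 1, 1, 1]
```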
95 changes: 85 additions & 10 deletions src/sdk/pynni/tests/test_compressor.py
@@ -8,6 +8,7 @@
if tf.__version__ >= '2.0':
import nni.compression.tensorflow as tf_compressor


def get_tf_model():
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(filters=5, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"),
@@ -20,42 +21,49 @@ def get_tf_model():
tf.keras.layers.Dense(units=10, activation='softmax'),
])
model.compile(loss="sparse_categorical_crossentropy",
-       optimizer=tf.keras.optimizers.SGD(lr=1e-3),
-       metrics=["accuracy"])
+           optimizer=tf.keras.optimizers.SGD(lr=1e-3),
+           metrics=["accuracy"])
return model


class TorchModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
self.bn1 = torch.nn.BatchNorm2d(5)
self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
self.bn2 = torch.nn.BatchNorm2d(10)
self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
self.fc2 = torch.nn.Linear(100, 10)

def forward(self, x):
- x = F.relu(self.conv1(x))
+ x = F.relu(self.bn1(self.conv1(x)))
x = F.max_pool2d(x, 2, 2)
- x = F.relu(self.conv2(x))
+ x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 10)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)


def tf2(func):
def test_tf2_func(*args):
if tf.__version__ >= '2.0':
func(*args)

return test_tf2_func

- k1 = [[1]*3]*3
- k2 = [[2]*3]*3
- k3 = [[3]*3]*3
- k4 = [[4]*3]*3
- k5 = [[5]*3]*3

+ k1 = [[1] * 3] * 3
+ k2 = [[2] * 3] * 3
+ k3 = [[3] * 3] * 3
+ k4 = [[4] * 3] * 3
+ k5 = [[5] * 3] * 3

w = [[k1, k2, k3, k4, k5]] * 10


class CompressorTestCase(TestCase):
def test_torch_level_pruner(self):
model = TorchModel()
@@ -74,7 +82,7 @@ def test_torch_naive_quantizer(self):
'quant_bits': {
'weight': 8,
},
- 'op_types':['Conv2d', 'Linear']
+ 'op_types': ['Conv2d', 'Linear']
}]
torch_compressor.NaiveQuantizer(model, configure_list).compress()

@@ -133,6 +141,73 @@ def test_tf_fpgm_pruner(self):

assert all(masks.sum((0, 2, 3)) == np.array([90., 0., 0., 0., 90.]))

def test_torch_l1filter_pruner(self):
"""
Filters with the smallest sum of absolute weights (L1 norm) are pruned, following the paper:
PRUNING FILTERS FOR EFFICIENT CONVNETS,
https://arxiv.org/abs/1608.08710

If sparsity is 0.2, the expected mask should zero out filter 0; this can be verified with:
`all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))`

If sparsity is 0.6, the expected mask should zero out filters 0, 1, and 2; this can be verified with:
`all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))`
"""
w = np.array([np.zeros((3, 3, 3)), np.ones((3, 3, 3)), np.ones((3, 3, 3)) * 2,
np.ones((3, 3, 3)) * 3, np.ones((3, 3, 3)) * 4])
model = TorchModel()
config_list = [{'sparsity': 0.2, 'op_names': ['conv1']}, {'sparsity': 0.6, 'op_names': ['conv2']}]
pruner = torch_compressor.L1FilterPruner(model, config_list)

model.conv1.weight.data = torch.tensor(w).float()
model.conv2.weight.data = torch.tensor(w).float()
layer1 = torch_compressor.compressor.LayerInfo('conv1', model.conv1)
mask1 = pruner.calc_mask(layer1, config_list[0])
layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
mask2 = pruner.calc_mask(layer2, config_list[1])
assert all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
assert all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))

def test_torch_slim_pruner(self):
"""
BN scale factors with the smallest absolute value (L1 norm) are pruned, following the paper:
Learning Efficient Convolutional Networks through Network Slimming,
https://arxiv.org/pdf/1708.06519.pdf

If sparsity is 0.2, the expected masks should zero out channel 0; this can be verified with:
`all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))`
`all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))`

If sparsity is 0.6, the expected masks should zero out channels 0, 1, and 2; this can be verified with:
`all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))`
`all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))`
"""
w = np.array([0, 1, 2, 3, 4])
model = TorchModel()
config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}]
model.bn1.weight.data = torch.tensor(w).float()
model.bn2.weight.data = torch.tensor(-w).float()
pruner = torch_compressor.SlimPruner(model, config_list)

layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
mask1 = pruner.calc_mask(layer1, config_list[0])
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0])
assert all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))

config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
model.bn1.weight.data = torch.tensor(w).float()
model.bn2.weight.data = torch.tensor(w).float()
pruner = torch_compressor.SlimPruner(model, config_list)

layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
mask1 = pruner.calc_mask(layer1, config_list[0])
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0])
assert all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))


if __name__ == '__main__':
main()
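Since the module invokes `main()` under the `__main__` guard, the new tests can be run directly with `python src/sdk/pynni/tests/test_compressor.py` (assuming torch is installed, plus tensorflow >= 2.0 for the tf cases).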