This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
Fix pruners #2153
Merged
Changes from all commits (48 commits)
chicm-ms 3a45961 Merge pull request #31 from microsoft/master
chicm-ms 633db43 Merge pull request #32 from microsoft/master
chicm-ms 3e926f1 Merge pull request #33 from microsoft/master
chicm-ms f173789 Merge pull request #34 from microsoft/master
chicm-ms 508850a Merge pull request #35 from microsoft/master
chicm-ms 5a0e9c9 Merge pull request #36 from microsoft/master
chicm-ms e7df061 Merge pull request #37 from microsoft/master
chicm-ms 2175cef Merge pull request #38 from microsoft/master
chicm-ms 2ccbfbb Merge pull request #39 from microsoft/master
chicm-ms b29cb0b Merge pull request #40 from microsoft/master
chicm-ms 4a3ba83 Merge pull request #41 from microsoft/master
chicm-ms c8a1148 Merge pull request #42 from microsoft/master
chicm-ms 73c6101 Merge pull request #43 from microsoft/master
chicm-ms 6a518a9 Merge pull request #44 from microsoft/master
chicm-ms a0d587f Merge pull request #45 from microsoft/master
chicm-ms e905bfe Merge pull request #46 from microsoft/master
chicm-ms 4b266f3 Merge pull request #47 from microsoft/master
chicm-ms 237ff4b Merge pull request #48 from microsoft/master
chicm-ms 682be01 Merge pull request #49 from microsoft/master
chicm-ms 133af82 Merge pull request #50 from microsoft/master
chicm-ms 71a8a25 Merge pull request #51 from microsoft/master
chicm-ms d2a73bc Merge pull request #52 from microsoft/master
chicm-ms 198cf5e Merge pull request #53 from microsoft/master
chicm-ms cdbfaf9 Merge pull request #54 from microsoft/master
chicm-ms 7e9b29e Merge pull request #55 from microsoft/master
chicm-ms d00c46d Merge pull request #56 from microsoft/master
chicm-ms de7d1fa Merge pull request #57 from microsoft/master
chicm-ms 1835ab0 Merge pull request #58 from microsoft/master
chicm-ms 24fead6 Merge pull request #59 from microsoft/master
chicm-ms 0b7321e Merge pull request #60 from microsoft/master
chicm-ms 60058d4 Merge pull request #61 from microsoft/master
chicm-ms b111a55 Merge pull request #62 from microsoft/master
chicm-ms 611c337 Merge pull request #63 from microsoft/master
chicm-ms 4a1f14a Merge pull request #64 from microsoft/master
chicm-ms 7a9e604 Merge pull request #65 from microsoft/master
chicm-ms b8035b0 Merge pull request #66 from microsoft/master
chicm-ms 47567d3 Merge pull request #67 from microsoft/master
chicm-ms 614d427 Merge pull request #68 from microsoft/master
chicm-ms a0d9ed6 Merge pull request #69 from microsoft/master
chicm-ms 22dc1ad Merge pull request #70 from microsoft/master
chicm-ms 0856813 Merge pull request #71 from microsoft/master
chicm-ms 9e97bed Merge pull request #72 from microsoft/master
chicm-ms 16a1b27 Merge pull request #73 from microsoft/master
chicm-ms e246633 Merge pull request #74 from microsoft/master
chicm-ms 477520b Add pruner UT
chicm-ms 78d041f updates
chicm-ms f615a06 updates
chicm-ms 9300d4c updates
@@ -0,0 +1,162 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from unittest import TestCase, main
from nni.compression.torch import LevelPruner, SlimPruner, FPGMPruner, L1FilterPruner, \
    L2FilterPruner, AGP_Pruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner


def validate_sparsity(wrapper, sparsity, bias=False):
    # Assert that the fraction of zeros in the wrapper's masks matches the
    # configured sparsity within a 0.1 absolute tolerance.
    masks = [wrapper.weight_mask]
    if bias and wrapper.bias_mask is not None:
        masks.append(wrapper.bias_mask)
    for m in masks:
        actual_sparsity = (m == 0).sum().item() / m.numel()
        msg = 'actual sparsity: {:.2f}, target sparsity: {:.2f}'.format(actual_sparsity, sparsity)
        assert math.isclose(actual_sparsity, sparsity, abs_tol=0.1), msg


prune_config = {
    'level': {
        'pruner_class': LevelPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['default'],
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, False),
            lambda model: validate_sparsity(model.fc, 0.5, False)
        ]
    },
    'agp': {
        'pruner_class': AGP_Pruner,
        'config_list': [{
            'initial_sparsity': 0,
            'final_sparsity': 0.8,
            'start_epoch': 0,
            'end_epoch': 10,
            'frequency': 1,
            'op_types': ['default']
        }],
        'validators': []
    },
    'slim': {
        'pruner_class': SlimPruner,
        'config_list': [{
            'sparsity': 0.7,
            'op_types': ['BatchNorm2d']
        }],
        'validators': [
            lambda model: validate_sparsity(model.bn1, 0.7, model.bias)
        ]
    },
    'fpgm': {
        'pruner_class': FPGMPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d']
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'l1': {
        'pruner_class': L1FilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'l2': {
        'pruner_class': L2FilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'mean_activation': {
        'pruner_class': ActivationMeanRankFilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'apoz': {
        'pruner_class': ActivationAPoZRankFilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    }
}


class Model(nn.Module):
    def __init__(self, bias=True):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1, bias=bias)
        self.bn1 = nn.BatchNorm2d(8)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 2, bias=bias)
        self.bias = bias

    def forward(self, x):
        return self.fc(self.pool(self.bn1(self.conv1(x))).view(x.size(0), -1))


def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'mean_activation', 'apoz'], bias=True):
    for pruner_name in pruner_names:
        print('testing {}...'.format(pruner_name))
        model = Model(bias=bias)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        config_list = prune_config[pruner_name]['config_list']

        # one training step before pruning
        x = torch.randn(2, 1, 28, 28)
        y = torch.tensor([0, 1]).long()
        out = model(x)
        loss = F.cross_entropy(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pruner = prune_config[pruner_name]['pruner_class'](model, config_list, optimizer)
        pruner.compress()

        # one training step after pruning, so activation-based pruners can
        # collect the statistics they need to compute masks
        x = torch.randn(2, 1, 28, 28)
        y = torch.tensor([0, 1]).long()
        out = model(x)
        loss = F.cross_entropy(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pruner.export_model('./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', input_shape=(2, 1, 28, 28))

        for v in prune_config[pruner_name]['validators']:
            v(model)

        os.remove('./model_tmp.pth')
        os.remove('./mask_tmp.pth')
        os.remove('./onnx_tmp.pth')


class PrunerTestCase(TestCase):
    def test_pruners(self):
        pruners_test(bias=True)

    def test_pruners_no_bias(self):
        pruners_test(bias=False)


if __name__ == '__main__':
    main()
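For readers unfamiliar with the masking mechanism these assertions rely on, here is a minimal standalone sketch (not part of the diff; it assumes the same nni.compression.torch v1.x API the test imports, and reuses the Model class defined above). After compress(), each configured layer is replaced by a wrapper module exposing weight_mask (and bias_mask when the layer has a bias), which is exactly what validate_sparsity() inspects.

    # Minimal sketch, not part of this PR; assumes the Model class and the
    # nni.compression.torch v1.x API from the test file above.
    import torch
    from nni.compression.torch import L1FilterPruner

    model = Model(bias=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    pruner = L1FilterPruner(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}], optimizer)
    pruner.compress()
    model(torch.randn(2, 1, 28, 28))  # one forward pass so the masks are populated

    # model.conv1 is now a wrapper; its weight_mask is what validate_sparsity() reads
    mask = model.conv1.weight_mask
    print('pruned fraction: {:.2f}'.format((mask == 0).sum().item() / mask.numel()))  # ~0.5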
Yes, for iterative pruners, we should think about how to test them.
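One possible direction, sketched as an assumption rather than a settled design: drive AGP_Pruner through its epoch schedule with update_epoch() (the iterative API in NNI v1.x) and validate sparsity only after the schedule reaches final_sparsity. The helper below is hypothetical and reuses Model, prune_config and validate_sparsity from the test file above.

    # Hypothetical test for an iterative pruner; assumes AGP_Pruner's
    # update_epoch(epoch) API from NNI v1.x.
    def iterative_pruner_test():
        model = Model(bias=True)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        pruner = AGP_Pruner(model, prune_config['agp']['config_list'], optimizer)
        pruner.compress()
        for epoch in range(11):  # run through end_epoch=10 so sparsity reaches final_sparsity=0.8
            pruner.update_epoch(epoch)
            # a tiny training step per epoch; masks are recomputed on the forward pass
            out = model(torch.randn(2, 1, 28, 28))
            loss = F.cross_entropy(out, torch.tensor([0, 1]).long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # after the schedule completes, the masks should be close to final_sparsity
        validate_sparsity(model.conv1, 0.8)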