fix pruner bugs and add model compression README #1624
model compression README (new file)

@@ -0,0 +1,47 @@
# Run model compression examples

You can run these examples easily. Take torch pruning for example:

```bash
python main_torch_pruner.py
```

Model compression can be configured in two ways:

- By reading `configure_example.yaml`; this keeps your code clean when the configuration is complicated
> Review comment: codes -> code
- By configuring directly in your code

In our example, we simply configure model compression in code like this:

```python
configure_list = [{
    'initial_sparsity': 0,
    'final_sparsity': 0.8,
    'start_epoch': 1,
    'end_epoch': 11,
    'frequency': 1,
    'op_type': 'default'
}]
pruner = AGP_Pruner(configure_list)
```
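For readers new to gradual-pruning schedules, here is the same configuration with comments. The glosses reflect the usual meaning of these keys in AGP-style pruning and are our reading, not official documentation:

```python
configure_list = [{
    'initial_sparsity': 0,   # sparsity when pruning starts
    'final_sparsity': 0.8,   # sparsity reached at the end of the schedule
    'start_epoch': 1,        # first epoch at which pruning is applied
    'end_epoch': 11,         # epoch by which final_sparsity should be reached
    'frequency': 1,          # prune every `frequency` epochs
    'op_type': 'default'     # which operation types this entry applies to
}]
```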

Please note that when `pruner(model)` is called, our model compression code is **automatically injected**, so you can fine-tune your model **without** any modification; masked weights **won't** be updated any more during fine-tuning.

```python
for epoch in range(10):
    print('# Epoch {} #'.format(epoch))
    train(model, device, train_loader, optimizer)
    test(model, device, test_loader)
    pruner.update_epoch(epoch + 1)
```

When fine-tuning is finished, pruned weights are all masked, and you can get the masks like this:

```python
masks = pruner.mask_list
layer_name = xxx
mask = masks[layer_name]
```
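As a quick sanity check after fine-tuning, you could measure how sparse a mask actually is. A minimal sketch, assuming masks are 0/1 tensors shaped like their weights; `'fc1.weight'` is a hypothetical key, so use a name that really appears in `pruner.mask_list`:

```python
layer_name = 'fc1.weight'  # hypothetical key; check pruner.mask_list.keys()
mask = pruner.mask_list[layer_name]

# Fraction of weights zeroed out by the mask.
sparsity = 1.0 - mask.float().mean().item()
print('{}: {:.1%} of weights pruned'.format(layer_name, sparsity))
```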

main_torch_pruner.py
@@ -20,7 +20,7 @@ def forward(self, x):
         x = x.view(-1, 4 * 4 * 50)
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
-        return F.log_softmax(x, dim = 1)
+        return F.log_softmax(x, dim=1)
 
 
 def train(model, device, train_loader, optimizer):
@@ -35,6 +35,7 @@ def train(model, device, train_loader, optimizer):
         if batch_idx % 100 == 0:
             print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
 
+
 def test(model, device, test_loader):
     model.eval()
     test_loss = 0
@@ -43,52 +44,53 @@ def test(model, device, test_loader):
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
             output = model(data)
-            test_loss += F.nll_loss(output, target, reduction = 'sum').item()
-            pred = output.argmax(dim = 1, keepdim = True)
+            test_loss += F.nll_loss(output, target, reduction='sum').item()
+            pred = output.argmax(dim=1, keepdim=True)
             correct += pred.eq(target.view_as(pred)).sum().item()
     test_loss /= len(test_loader.dataset)
 
     print('Loss: {} Accuracy: {}%)\n'.format(
         test_loss, 100 * correct / len(test_loader.dataset)))
 
 
 def main():
     torch.manual_seed(0)
     device = torch.device('cpu')
 
     trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
     train_loader = torch.utils.data.DataLoader(
-        datasets.MNIST('data', train = True, download = True, transform = trans),
-        batch_size = 64, shuffle = True)
+        datasets.MNIST('data', train=True, download=True, transform=trans),
+        batch_size=64, shuffle=True)
     test_loader = torch.utils.data.DataLoader(
-        datasets.MNIST('data', train = False, transform = trans),
-        batch_size = 1000, shuffle = True)
+        datasets.MNIST('data', train=False, transform=trans),
+        batch_size=1000, shuffle=True)
 
     model = Mnist()
 
     '''you can change this to SensitivityPruner to implement it
     pruner = SensitivityPruner(configure_list)
     '''
     configure_list = [{
-        'initial_sparsity': 0,
-        'final_sparsity': 0.8,
-        'start_epoch': 1,
-        'end_epoch': 10,
-        'frequency': 1,
-        'op_type': 'default'
-    }]
+        'initial_sparsity': 0,
+        'final_sparsity': 0.8,
+        'start_epoch': 1,
+        'end_epoch': 11,
+        'frequency': 1,
+        'op_type': 'default'
+    }]
 
     pruner = AGP_Pruner(configure_list)
     pruner(model)
     # you can also use the compress(model) method,
     # i.e. pruner.compress(model)
 
-    optimizer = torch.optim.SGD(model.parameters(), lr = 0.01, momentum = 0.5)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
     for epoch in range(10):
         print('# Epoch {} #'.format(epoch))
         train(model, device, train_loader, optimizer)
         test(model, device, test_loader)
-        pruner.update_epoch(epoch)
+        pruner.update_epoch(epoch + 1)

> Review comment (on `pruner.update_epoch(epoch + 1)`): Let's discuss whether epoch should start from 0 or 1 next time.

 
 
 if __name__ == '__main__':
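To make the reviewer's question above concrete: the training-loop index and the epoch number the pruner sees are offset by one. A standalone illustration (plain Python, nothing NNI-specific assumed):

```python
# update_epoch(epoch + 1) means the pruner sees epochs 1..10
# while the training loop iterates over indices 0..9.
for epoch in range(10):
    print('loop index {} -> pruner epoch {}'.format(epoch, epoch + 1))
```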

torch pruner implementations
@@ -2,7 +2,7 @@
 import torch
 from .compressor import Pruner
 
-__all__ = ['LevelPruner', 'AGP_Pruner', 'SensitivityPruner']
+__all__ = ['LevelPruner', 'AGP_Pruner']

> Review comment: you removed sensitivitypruner, so please update doc accordingly

 logger = logging.getLogger('torch pruner')
@@ -17,14 +17,22 @@ def __init__(self, config_list):
             - sparsity
         """
         super().__init__(config_list)
+        self.mask_list = {}
+        self.if_init_list = {}
 
-    def calc_mask(self, weight, config, **kwargs):
-        w_abs = weight.abs()
-        k = int(weight.numel() * config['sparsity'])
-        if k == 0:
-            return torch.ones(weight.shape).type_as(weight)
-        threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
-        return torch.gt(w_abs, threshold).type_as(weight)
+    def calc_mask(self, weight, config, op_name, **kwargs):
+        if self.if_init_list.get(op_name, True):
+            w_abs = weight.abs()
+            k = int(weight.numel() * config['sparsity'])
+            if k == 0:
+                return torch.ones(weight.shape).type_as(weight)
+            threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
+            mask = torch.gt(w_abs, threshold).type_as(weight)
+            self.mask_list.update({op_name: mask})
+            self.if_init_list.update({op_name: False})
+        else:
+            mask = self.mask_list[op_name]
+        return mask
 
 
 class AGP_Pruner(Pruner):
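The new `if_init_list` guard makes LevelPruner compute each layer's mask once and reuse it on later calls. The masking rule itself is plain magnitude pruning; a self-contained sketch of the same idea outside the class (illustrative only, not the NNI code):

```python
import torch

def magnitude_mask(weight, sparsity):
    # Zero out the `sparsity` fraction of weights with smallest magnitude.
    w_abs = weight.abs()
    k = int(weight.numel() * sparsity)
    if k == 0:
        return torch.ones_like(weight)
    # Threshold = largest magnitude among the k smallest entries.
    threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
    return torch.gt(w_abs, threshold).type_as(weight)

w = torch.randn(4, 4)
print(magnitude_mask(w, 0.5))  # roughly half the entries are 0
```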
@@ -49,18 +57,23 @@ def __init__(self, config_list):
         super().__init__(config_list)
         self.mask_list = {}
         self.now_epoch = 1
+        self.if_init_list = {}
 
     def calc_mask(self, weight, config, op_name, **kwargs):
-        mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
-        target_sparsity = self.compute_target_sparsity(config)
-        k = int(weight.numel() * target_sparsity)
-        if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
-            return mask
-        # if we want to generate new mask, we should update weigth first
-        w_abs = weight.abs() * mask
-        threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
-        new_mask = torch.gt(w_abs, threshold).type_as(weight)
-        self.mask_list[op_name] = new_mask
+        if self.if_init_list.get(op_name, True):
+            mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
+            target_sparsity = self.compute_target_sparsity(config)
+            k = int(weight.numel() * target_sparsity)
+            if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
+                return mask
+            # if we want to generate a new mask, we should update the weight first
+            w_abs = weight.abs() * mask
+            threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
+            new_mask = torch.gt(w_abs, threshold).type_as(weight)
+            self.mask_list.update({op_name: new_mask})
+            self.if_init_list.update({op_name: False})
+        else:
+            new_mask = self.mask_list[op_name]
         return new_mask
 
     def compute_target_sparsity(self, config):
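The body of `compute_target_sparsity` is cut off in this diff. For orientation, AGP-style gradual pruning (Zhu & Gupta, "To prune, or not to prune", 2017) interpolates cubically between the initial and final sparsity; a sketch of that schedule under the assumption that this method follows the paper, not the exact NNI implementation:

```python
def agp_target_sparsity(config, now_epoch):
    # Cubic interpolation from initial_sparsity to final_sparsity,
    # as in Zhu & Gupta (2017); a sketch, not the exact NNI code.
    si, sf = config['initial_sparsity'], config['final_sparsity']
    start, end = config['start_epoch'], config['end_epoch']
    if now_epoch <= start:
        return si
    if now_epoch >= end:
        return sf
    progress = (now_epoch - start) / (end - start)
    return sf + (si - sf) * (1.0 - progress) ** 3
```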
@@ -86,35 +99,5 @@ def compute_target_sparsity(self, config):
     def update_epoch(self, epoch):
         if epoch > 0:
             self.now_epoch = epoch
+            for k in self.if_init_list.keys():
+                self.if_init_list[k] = True
-
-
-class SensitivityPruner(Pruner):
-    """Use algorithm from "Learning both Weights and Connections for Efficient Neural Networks"
-    https://arxiv.org/pdf/1506.02626v3.pdf
-
-    I.e.: "The pruning threshold is chosen as a quality parameter multiplied
-    by the standard deviation of a layer's weights."
-    """
-
-    def __init__(self, config_list):
-        """
-        config_list: supported keys:
-            - sparsity: chosen pruning sparsity
-        """
-        super().__init__(config_list)
-        self.mask_list = {}
-
-    def calc_mask(self, weight, config, op_name, **kwargs):
-        mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
-        # if we want to generate new mask, we should update weight first
-        weight = weight * mask
-        target_sparsity = config['sparsity'] * torch.std(weight).item()
-        k = int(weight.numel() * target_sparsity)
-        if k == 0:
-            return mask
-
-        w_abs = weight.abs()
-        threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
-        new_mask = torch.gt(w_abs, threshold).type_as(weight)
-        self.mask_list[op_name] = new_mask
-        return new_mask
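This reset in `update_epoch` is what separates the two pruners' caching behavior: LevelPruner never resets its flags, so each mask is computed once, while AGP_Pruner flips them back to `True` every epoch so masks are recomputed against the new target sparsity. A toy illustration of the flag mechanics (standalone, nothing NNI-specific assumed):

```python
if_init_list = {}

def need_recompute(op_name):
    # First call for a layer (or first call after a reset) recomputes.
    recompute = if_init_list.get(op_name, True)
    if_init_list[op_name] = False
    return recompute

def reset_flags():
    # AGP-style update_epoch: force all masks to be recomputed.
    for k in if_init_list:
        if_init_list[k] = True

print(need_recompute('conv1'))  # True  (first call)
print(need_recompute('conv1'))  # False (cached mask reused)
reset_flags()
print(need_recompute('conv1'))  # True  (recompute after epoch update)
```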
> Review comment: This example uses AGP Pruner. Initiating a pruner needs a user provided configuration which can be provided in two ways: by reading `configure_example.yaml`, or by configuring directly in code.