This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

fix pruner bugs and add model compression README #1624

Merged
merged 14 commits into from
Oct 21, 2019
47 changes: 47 additions & 0 deletions examples/model_compress/README.md
@@ -0,0 +1,47 @@
# Run model compression examples

You can run these examples easily. Take torch pruning as an example:

```bash
python main_torch_pruner.py
```

Model compression can be configured in two ways:
Contributor:
This example uses the AGP Pruner. Initializing a pruner needs a user-provided configuration, which can be supplied in two ways:


- By reading ```configure_example.yaml```; this keeps your code clean when the configuration is complicated (a loading sketch follows the example below)
Contributor:
codes -> code

- By configuring it directly in your code

In our example, we simply configure model compression in our code like this:

```python
configure_list = [{
    'initial_sparsity': 0,
    'final_sparsity': 0.8,
    'start_epoch': 1,
    'end_epoch': 11,
    'frequency': 1,
    'op_type': 'default'
}]
pruner = AGP_Pruner(configure_list)
```
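
Equivalently, if you prefer the YAML route, the same list could live in ```configure_example.yaml``` and be loaded yourself before constructing the pruner. Below is a minimal sketch, assuming PyYAML is installed and that the file simply contains the same list of dicts as ```configure_list``` above:

```python
import yaml

# Sketch: load the pruning configuration from YAML instead of hard-coding it.
# Assumes configure_example.yaml holds the same list of dicts as configure_list above.
with open('configure_example.yaml') as f:
    configure_list = yaml.safe_load(f)
pruner = AGP_Pruner(configure_list)
```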

Please note that when ```pruner(model)``` is called, our model compression code is **automatically injected** into the model, so you can fine-tune your model **without** any modifications; masked weights **won't** be updated any more during fine-tuning.
Contributor:
When pruner(model) is called, your model is injected with masks as embedded operations. For example, for a layer that takes a weight as input, we insert an operation between the weight and the layer; this operation takes the weight as input and outputs a new weight with the mask applied. Thus, the masks are applied whenever the computation goes through those operations. You can fine-tune your model without any modifications.
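
A rough conceptual sketch of that injection (illustrative only; the function and variable names are made up, and the real NNI pruner keeps additional bookkeeping):

```python
def wrap_forward(module, calc_mask):
    """Sketch: wrap a module's forward so its weight is masked on every call.
    Not the actual NNI implementation; names are illustrative."""
    original_forward = module.forward

    def new_forward(*inputs):
        mask = calc_mask(module.weight.data)   # look up or recompute the mask
        module.weight.data.mul_(mask)          # zero out pruned weights in place
        return original_forward(*inputs)       # run the original computation

    module.forward = new_forward
```

Because the mask is applied inside the wrapped forward, the pruned weights stay at zero throughout fine-tuning.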


```python
for epoch in range(10):
    print('# Epoch {} #'.format(epoch))
    train(model, device, train_loader, optimizer)
    test(model, device, test_loader)
    pruner.update_epoch(epoch + 1)
```
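
The ```update_epoch``` calls drive AGP's gradual schedule: sparsity ramps from ```initial_sparsity``` at ```start_epoch``` to ```final_sparsity``` at ```end_epoch```, with ```frequency``` controlling how often the mask is recomputed. Below is a minimal sketch of the cubic ramp from the AGP paper (Zhu & Gupta, 2017); NNI's ```compute_target_sparsity``` follows the same idea, though its exact implementation may differ:

```python
def agp_target_sparsity(epoch, initial_sparsity=0.0, final_sparsity=0.8,
                        start_epoch=1, end_epoch=11):
    # Cubic ramp from "To prune, or not to prune" (Zhu & Gupta, 2017).
    # Sketch only; the built-in compute_target_sparsity may differ in details.
    if epoch <= start_epoch:
        return initial_sparsity
    if epoch >= end_epoch:
        return final_sparsity
    progress = (epoch - start_epoch) / (end_epoch - start_epoch)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** 3
```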

When fine-tuning has finished, the pruned weights are all masked, and you can retrieve the masks like this:

```
masks = pruner.mask_list
layer_name = xxx
mask = masks[layer_name]
```
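
For instance, you could check the achieved sparsity of one layer from its mask (a sketch; the key used below is illustrative, since the actual keys in ```mask_list``` depend on your model's op names):

```python
# Sketch: compute achieved sparsity from a mask (the key name is illustrative).
mask = pruner.mask_list['fc1']
sparsity = 1.0 - mask.sum().item() / mask.numel()
print('achieved sparsity: {:.2%}'.format(sparsity))
```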



38 changes: 20 additions & 18 deletions examples/model_compress/main_torch_pruner.py
@@ -20,7 +20,7 @@ def forward(self, x):
x = x.view(-1, 4 * 4 * 50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim = 1)
return F.log_softmax(x, dim=1)


def train(model, device, train_loader, optimizer):
@@ -35,6 +35,7 @@ def train(model, device, train_loader, optimizer):
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
model.eval()
test_loss = 0
@@ -43,52 +44,53 @@ def test(model, device, test_loader):
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction = 'sum').item()
pred = output.argmax(dim = 1, keepdim = True)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)

print('Loss: {} Accuracy: {}%)\n'.format(
test_loss, 100 * correct / len(test_loader.dataset)))


def main():
torch.manual_seed(0)
device = torch.device('cpu')

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train = True, download = True, transform = trans),
batch_size = 64, shuffle = True)
datasets.MNIST('data', train=True, download=True, transform=trans),
batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train = False, transform = trans),
batch_size = 1000, shuffle = True)
datasets.MNIST('data', train=False, transform=trans),
batch_size=1000, shuffle=True)

model = Mnist()

'''you can change this to SensitivityPruner to implement it
pruner = SensitivityPruner(configure_list)
'''
configure_list = [{
'initial_sparsity': 0,
'final_sparsity': 0.8,
'start_epoch': 1,
'end_epoch': 10,
'frequency': 1,
'op_type': 'default'
}]
'initial_sparsity': 0,
'final_sparsity': 0.8,
'start_epoch': 1,
'end_epoch': 11,
'frequency': 1,
'op_type': 'default'
}]

pruner = AGP_Pruner(configure_list)
pruner(model)
# you can also use compress(model) method
# like that pruner.compress(model)

optimizer = torch.optim.SGD(model.parameters(), lr = 0.01, momentum = 0.5)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(10):
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
pruner.update_epoch(epoch)

pruner.update_epoch(epoch + 1)
Contributor:
Let's discuss whether epoch should start from 0 or 1 next time.



if __name__ == '__main__':
83 changes: 33 additions & 50 deletions src/sdk/pynni/nni/compression/torch/builtin_pruners.py
@@ -2,7 +2,7 @@
import torch
from .compressor import Pruner

__all__ = ['LevelPruner', 'AGP_Pruner', 'SensitivityPruner']
__all__ = ['LevelPruner', 'AGP_Pruner']
Contributor:
You removed SensitivityPruner, so please update the doc accordingly.


logger = logging.getLogger('torch pruner')

@@ -17,14 +17,22 @@ def __init__(self, config_list):
- sparsity
"""
super().__init__(config_list)
self.mask_list = {}
self.if_init_list = {}

def calc_mask(self, weight, config, **kwargs):
w_abs = weight.abs()
k = int(weight.numel() * config['sparsity'])
if k == 0:
return torch.ones(weight.shape).type_as(weight)
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
return torch.gt(w_abs, threshold).type_as(weight)
def calc_mask(self, weight, config, op_name, **kwargs):
if self.if_init_list.get(op_name, True):
w_abs = weight.abs()
k = int(weight.numel() * config['sparsity'])
if k == 0:
return torch.ones(weight.shape).type_as(weight)
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_list.update({op_name: mask})
self.if_init_list.update({op_name: False})
else:
mask = self.mask_list[op_name]
return mask


class AGP_Pruner(Pruner):
@@ -49,18 +57,23 @@ def __init__(self, config_list):
super().__init__(config_list)
self.mask_list = {}
self.now_epoch = 1
self.if_init_list = {}

def calc_mask(self, weight, config, op_name, **kwargs):
mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
return mask
# if we want to generate new mask, we should update weight first
w_abs = weight.abs() * mask
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_list[op_name] = new_mask
if self.if_init_list.get(op_name, True):
mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
return mask
# if we want to generate new mask, we should update weight first
w_abs = weight.abs() * mask
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_list.update({op_name: new_mask})
self.if_init_list.update({op_name: False})
else:
new_mask = self.mask_list[op_name]
return new_mask

def compute_target_sparsity(self, config):
@@ -86,35 +99,5 @@ def compute_target_sparsity(self, config):
def update_epoch(self, epoch):
if epoch > 0:
self.now_epoch = epoch


class SensitivityPruner(Pruner):
"""Use algorithm from "Learning both Weights and Connections for Efficient Neural Networks"
https://arxiv.org/pdf/1506.02626v3.pdf

I.e.: "The pruning threshold is chosen as a quality parameter multiplied
by the standard deviation of a layers weights."
"""

def __init__(self, config_list):
"""
config_list: supported keys:
- sparsity: chosen pruning sparsity
"""
super().__init__(config_list)
self.mask_list = {}

def calc_mask(self, weight, config, op_name, **kwargs):
mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
# if we want to generate new mask, we should update weight first
weight = weight * mask
target_sparsity = config['sparsity'] * torch.std(weight).item()
k = int(weight.numel() * target_sparsity)
if k == 0:
return mask

w_abs = weight.abs()
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_list[op_name] = new_mask
return new_mask
for k in self.if_init_list.keys():
self.if_init_list[k] = True
11 changes: 4 additions & 7 deletions src/sdk/pynni/nni/compression/torch/compressor.py
@@ -38,25 +38,23 @@ def compress(self, model):
if config is not None:
self._instrument_layer(layer, config)


def bind_model(self, model):
"""This method is called when a model is bound to the compressor.
Users can optionally overload this method to do model-specific initialization.
It is guaranteed that only one model will be bound to each compressor instance.
"""
pass

def update_epoch(self, epoch):
"""if user want to update model every epoch, user can override this method
"""
pass

def step(self):
"""if user want to update model every step, user can override this method
"""
pass


def _instrument_layer(self, layer, config):
raise NotImplementedError()

@@ -90,7 +88,6 @@ def calc_mask(self, weight, config, op, op_type, op_name):
"""
raise NotImplementedError("Pruners must overload calc_mask()")


def _instrument_layer(self, layer, config):
# TODO: support multiple weight tensors
# create a wrapper forward function to replace the original one
@@ -112,7 +109,7 @@ def new_forward(*input):
return ret

layer.module.forward = new_forward


class Quantizer(Compressor):
"""Base quantizer for pytorch quantizer"""
@@ -123,7 +120,7 @@ def __init__(self, config_list):
def __call__(self, model):
self.compress(model)
return model

def quantize_weight(self, weight, config, op, op_type, op_name):
"""user should know where dequantize goes and implement it in quantize method
we now do not provide dequantize method