add BNN quantization algorithm (microsoft#1832)
Cjkkkk authored and chicm-ms committed Dec 24, 2019
1 parent b0c0eb7 commit 0c7f22f
Showing 3 changed files with 284 additions and 15 deletions.
110 changes: 97 additions & 13 deletions docs/en_US/Compressor/Quantizer.md
@@ -1,6 +1,5 @@
Quantizer on NNI Compressor
===

## Naive Quantizer

We provide Naive Quantizer to quantize weights to 8 bits by default; you can use it to test a quantization algorithm without any configuration.
@@ -53,11 +52,24 @@ You can view example for more information

#### User configuration for QAT Quantizer
* **quant_types:** list of string

type of quantization you want to apply; currently 'weight', 'input' and 'output' are supported.

* **op_types:** list of string

specify the types of modules that will be quantized, e.g. 'Conv2d'.

* **op_names:** list of string

specify the names of modules that will be quantized, e.g. 'conv1'.

* **quant_bits:** int or dict of {str : int}

bit length of quantization; the key is the quantization type and the value is the bit length, e.g. {'weight': 8}. When an int is given, all quantization types share the same bit length.

* **quant_start_step:** int

disable quantization until the model has run for a certain number of steps. This allows the network to reach a more stable state, where activation quantization ranges do not exclude a significant fraction of values; the default value is 0. A configuration sketch combining these fields is shown below.
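
Below is a minimal configuration sketch combining these fields. It assumes `model` is your PyTorch model; the specific op types, bit widths and step count are illustrative placeholders, not values prescribed by NNI:

```python
from nni.compression.torch import QAT_Quantizer

# hypothetical config: 8-bit weights for conv/linear layers, 8-bit outputs for activations
config_list = [{
    'quant_types': ['weight'],
    'quant_bits': {'weight': 8},      # dict form: bit length per quantization type
    'op_types': ['Conv2d', 'Linear'],
    'quant_start_step': 1000          # keep full precision for the first 1000 steps
}, {
    'quant_types': ['output'],
    'quant_bits': 8,                  # int form: one bit length for all listed types
    'op_types': ['ReLU6'],
    'quant_start_step': 1000
}]
quantizer = QAT_Quantizer(model, config_list)
quantizer.compress()
```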

@@ -71,22 +83,94 @@ In [DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bit
### Usage
To use DoReFa Quantizer, add the code below before your training code:

PyTorch code
```python
from nni.compression.torch import DoReFaQuantizer
config_list = [{
'quant_types': ['weight'],
'quant_bits': 8,
'op_types': 'default'
}]
quantizer = DoReFaQuantizer(model, config_list)
quantizer.compress()
```

You can view the example for more information.

#### User configuration for DoReFa Quantizer
* **quant_types:** list of string

type of quantization you want to apply; currently 'weight', 'input' and 'output' are supported.

* **op_types:** list of string

specify the types of modules that will be quantized, e.g. 'Conv2d'.

* **op_names:** list of string

specify the names of modules that will be quantized, e.g. 'conv1'.

* **quant_bits:** int or dict of {str : int}

bit length of quantization; the key is the quantization type and the value is the bit length, e.g. {'weight': 8}. When an int is given, all quantization types share the same bit length.
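
For intuition, DoReFa quantizes a value that has already been scaled into [0, 1] by rounding it onto a uniform grid with 2^k - 1 steps. The following standalone sketch mirrors the `quantize` helper shown in the SDK diff below; it is an illustration, not NNI's public API:

```python
import torch

def dorefa_round(x, k):
    """Uniformly quantize x (assumed to lie in [0, 1]) to k bits."""
    scale = pow(2, k) - 1
    return torch.round(x * scale) / scale

# e.g. dorefa_round(torch.tensor([0.1, 0.5, 0.9]), 2) -> tensor([0.0000, 0.6667, 1.0000])
```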


## BNN Quantizer
In [Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1](https://arxiv.org/abs/1602.02830),

>We introduce a method to train Binarized Neural Networks (BNNs) - neural networks with binary weights and activations at run-time. At training-time the binary weights and activations are used for computing the parameters gradients. During the forward pass, BNNs drastically reduce memory size and accesses, and replace most arithmetic operations with bit-wise operations, which is expected to substantially improve power-efficiency.
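
Conceptually, the quantizer binarizes values with `sign` in the forward pass and passes gradients straight through, clipped to the region where the input magnitude is at most 1. The following self-contained sketch illustrates that idea in plain PyTorch; it is an illustration of the technique, not NNI's internal API:

```python
import torch

class Binarize(torch.autograd.Function):
    """Sign binarization with a clipped straight-through estimator."""
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        out = torch.sign(x)
        out[out == 0] = 1  # map zeros to +1 so outputs stay in {-1, +1}
        return out

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[x.abs() > 1] = 0  # block gradients where |x| > 1
        return grad_input

# usage: binary = Binarize.apply(tensor)
```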

### Usage

PyTorch code
```python
from nni.compression.torch import BNNQuantizer
model = VGG_Cifar10(num_classes=10)

configure_list = [{
'quant_types': ['weight'],
'quant_bits': 1,
'op_types': ['Conv2d', 'Linear'],
'op_names': ['features.0', 'features.3', 'features.7', 'features.10', 'features.14', 'features.17', 'classifier.0', 'classifier.3']
}, {
'quant_types': ['output'],
'quant_bits': 1,
'op_types': ['Hardtanh'],
'op_names': ['features.6', 'features.9', 'features.13', 'features.16', 'features.20', 'classifier.2', 'classifier.5']
}]

quantizer = BNNQuantizer(model, configure_list)
model = quantizer.compress()
```

You can view the example [examples/model_compress/BNN_quantizer_cifar10.py](https://github.com/microsoft/nni/tree/master/examples/model_compress/BNN_quantizer_cifar10.py) for more information.

#### User configuration for BNN Quantizer
* **quant_types:** list of string

type of quantization you want to apply; currently 'weight', 'input' and 'output' are supported.

* **op_types:** list of string

specify the types of modules that will be quantized, e.g. 'Conv2d'.

* **op_names:** list of string

specify the names of modules that will be quantized, e.g. 'conv1'.

* **quant_bits:** int or dict of {str : int}

bit length of quantization; the key is the quantization type and the value is the bit length, e.g. {'weight': 8}. When an int is given, all quantization types share the same bit length.

### Experiment
We implemented one of the experiments in [Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1](https://arxiv.org/abs/1602.02830): we quantized the **VGGNet** for CIFAR-10 as in the paper. Our experimental results are as follows:

| Model | Accuracy |
| ------------- | --------- |
| VGGNet | 86.93% |


The experiment code can be found at [examples/model_compress/BNN_quantizer_cifar10.py](https://github.com/microsoft/nni/tree/master/examples/model_compress/BNN_quantizer_cifar10.py).
155 changes: 155 additions & 0 deletions examples/model_compress/BNN_quantizer_cifar10.py
@@ -0,0 +1,155 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import BNNQuantizer


class VGG_Cifar10(nn.Module):
def __init__(self, num_classes=1000):
super(VGG_Cifar10, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 128, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(128, eps=1e-4, momentum=0.1),
nn.Hardtanh(inplace=True),

nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.BatchNorm2d(128, eps=1e-4, momentum=0.1),
nn.Hardtanh(inplace=True),

nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(256, eps=1e-4, momentum=0.1),
nn.Hardtanh(inplace=True),


nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.BatchNorm2d(256, eps=1e-4, momentum=0.1),
nn.Hardtanh(inplace=True),


nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(512, eps=1e-4, momentum=0.1),
nn.Hardtanh(inplace=True),


nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.BatchNorm2d(512, eps=1e-4, momentum=0.1),
nn.Hardtanh(inplace=True)
)

self.classifier = nn.Sequential(
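            # CIFAR-10 inputs are 32x32; three 2x2 max-pools leave 4x4 feature maps with 512 channels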
nn.Linear(512 * 4 * 4, 1024, bias=False),
nn.BatchNorm1d(1024),
nn.Hardtanh(inplace=True),
nn.Linear(1024, 1024, bias=False),
nn.BatchNorm1d(1024),
nn.Hardtanh(inplace=True),
nn.Linear(1024, num_classes), # do not quantize output
nn.BatchNorm1d(num_classes, affine=False)
)


def forward(self, x):
x = self.features(x)
x = x.view(-1, 512 * 4 * 4)
x = self.classifier(x)
return x


def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
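        # clamp the full-precision copies kept by the quantizer ('old_weight') to [-1, 1], as in the BNN paper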
for name, param in model.named_parameters():
if name.endswith('old_weight'):
                param.data.clamp_(-1, 1)
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
acc = 100 * correct / len(test_loader.dataset)

    print('Loss: {} Accuracy: {}%\n'.format(test_loss, acc))
return acc

def adjust_learning_rate(optimizer, epoch):
update_list = [55, 100, 150, 200, 400, 600]
if epoch in update_list:
for param_group in optimizer.param_groups:
param_group['lr'] = param_group['lr'] * 0.1
return

def main():
torch.manual_seed(0)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=200, shuffle=False)

model = VGG_Cifar10(num_classes=10)
model.to(device)

configure_list = [{
'quant_types': ['weight'],
'quant_bits': 1,
'op_types': ['Conv2d', 'Linear'],
'op_names': ['features.3', 'features.7', 'features.10', 'features.14', 'classifier.0', 'classifier.3']
}, {
'quant_types': ['output'],
'quant_bits': 1,
'op_types': ['Hardtanh'],
'op_names': ['features.6', 'features.9', 'features.13', 'features.16', 'features.20', 'classifier.2', 'classifier.5']
}]

quantizer = BNNQuantizer(model, configure_list)
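    # compress() wraps the configured layers so their weights/outputs are binarized on the fly during training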
model = quantizer.compress()

print('=' * 10 + 'train' + '=' * 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
best_top1 = 0
for epoch in range(400):
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer)
adjust_learning_rate(optimizer, epoch)
top1 = test(model, device, test_loader)
if top1 > best_top1:
best_top1 = top1
print(best_top1)


if __name__ == '__main__':
main()
34 changes: 32 additions & 2 deletions src/sdk/pynni/nni/compression/torch/builtin_quantizers.py
@@ -3,7 +3,7 @@

import logging
import torch
from .compressor import Quantizer, QuantGrad, QuantType

__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer']

@@ -240,4 +240,34 @@ def quantize_weight(self, weight, config, **kwargs):
def quantize(self, input_ri, q_bits):
scale = pow(2, q_bits)-1
output = torch.round(input_ri*scale)/scale
        return output


class ClipGrad(QuantGrad):
@staticmethod
def quant_backward(tensor, grad_output, quant_type):
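        # clipped straight-through estimator: block output gradients where the pre-activation magnitude exceeds 1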
if quant_type == QuantType.QUANT_OUTPUT:
grad_output[torch.abs(tensor) > 1] = 0
return grad_output


class BNNQuantizer(Quantizer):
"""Binarized Neural Networks, as defined in:
Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
(https://arxiv.org/abs/1602.02830)
"""
def __init__(self, model, config_list):
super().__init__(model, config_list)
self.quant_grad = ClipGrad

def quantize_weight(self, weight, config, **kwargs):
out = torch.sign(weight)
# remove zeros
out[out == 0] = 1
return out

def quantize_output(self, output, config, **kwargs):
out = torch.sign(output)
# remove zeros
out[out == 0] = 1
return out
