This repository contains the training code of N2UQ introduced in our CVPR 2022 paper: "Nonuniform-to-Uniform Quantization: Towards Accurate Quantization via Generalized Straight-Through Estimation"
In this study, we propose a quantization method that learns nonuniform input thresholds to maintain the strong representational ability of nonuniform methods, while producing uniform quantized output levels that are as hardware-friendly and efficient as uniform quantization for model inference.
To train the quantized network with learnable input thresholds, we introduce a generalized straight-through estimator (G-STE) for intractable backward derivative calculation w.r.t. threshold parameters.
The formulation of N2UQ is as follows.

Forward pass: the input is compared against learnable nonuniform thresholds and mapped to uniform quantized levels.

Backward pass: G-STE approximates the intractable derivative w.r.t. both the input and the threshold parameters.
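As a minimal sketch of these two passes, assuming b-bit quantization with learnable thresholds T_1 < ... < T_{2^b-1} (notation ours, not the paper's exact equations):

```latex
% Sketch of the N2UQ quantizer (assumed notation; see the paper for the exact equations).
% Forward pass: nonuniform learnable thresholds T_1 < ... < T_{2^b-1},
% uniform output levels 0, 1, ..., 2^b - 1.
x_q =
\begin{cases}
  0,       & x < T_1 \\
  i,       & T_i \le x < T_{i+1}, \quad i = 1, \dots, 2^b - 2 \\
  2^b - 1, & x \ge T_{2^b - 1}
\end{cases}

% Backward pass (G-STE): within each interval the surrogate slope is the
% reciprocal of the interval width, so gradients reach both the input x
% and the learnable thresholds T_i.
\frac{\partial x_q}{\partial x} \approx
\begin{cases}
  \dfrac{1}{T_{i+1} - T_i}, & T_i \le x < T_{i+1} \\[4pt]
  0,                        & \text{otherwise}
\end{cases}
```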
Moreover, we propose an L1-norm-based, entropy-preserving weight regularization for weight quantization.
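As an illustrative sketch only (the penalty form and `target_scale` below are our assumptions, not the paper's exact regularizer), such an L1-norm-based penalty can keep each layer's mean absolute weight near a target scale so that the quantization levels stay evenly populated:

```python
import torch

def l1_entropy_preserving_reg(weights, target_scale=1.0, lam=1e-4):
    """Illustrative sketch of an L1-norm-based weight regularizer.

    NOTE: an assumption for illustration, not the exact regularizer in the
    N2UQ paper. It penalizes the gap between each layer's mean absolute
    weight (L1 norm / element count) and a target scale, discouraging the
    weight distribution from collapsing into a few quantization bins.
    """
    reg = torch.zeros(())
    for w in weights:
        mean_abs = w.abs().mean()
        reg = reg + (mean_abs - target_scale).abs()
    return lam * reg

# Usage: add the penalty to the task loss during training, e.g.
# conv_weights = [m.weight for m in model.modules()
#                 if isinstance(m, torch.nn.Conv2d)]
# loss = criterion(output, target) + l1_entropy_preserving_reg(conv_weights)
```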
If you find our code useful for your research, please consider citing:
@inproceedings{liu2022nonuniform,
  title={Nonuniform-to-Uniform Quantization: Towards Accurate Quantization via Generalized Straight-Through Estimation},
  author={Liu, Zechun and Cheng, Kwang-Ting and Huang, Dong and Xing, Eric and Shen, Zhiqiang},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2022}
}
- python 3.6, pytorch 1.7.1, torchvision 0.8.2
- gdown
- Download the ImageNet dataset
- `pip install gdown` (gdown automatically downloads the pretrained models)
- If gdown doesn't work, you may need to manually download the pretrained models and put them in the corresponding `./models/` folder.
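If you prefer to fetch a single checkpoint programmatically, here is a minimal sketch using gdown's Python API (the file id and output name below are placeholders, not the real release ids):

```python
import gdown

# Placeholder Google Drive file id -- replace with the id behind the
# model link you need from the result tables below.
file_id = "YOUR_FILE_ID"

gdown.download(
    url=f"https://drive.google.com/uc?id={file_id}",
    output="./models/checkpoint.pth.tar",  # place it under the corresponding ./models/ folder
    quiet=False,
)
```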
(1) For ResNet architectures:
- Change directory to `./resnet/`
- Run `bash run.sh architecture n_bits quantize_downsampling`
- E.g., `bash run.sh resnet18 2 0` quantizes ResNet-18 to 2-bit without quantizing the downsampling layers
(2) For MobileNet architectures:
- Change directory to `./mobilenetv2/`
- Run `bash run.sh`
ImageNet top-1 accuracy (%):

| Network | Methods | W2/A2 | W3/A3 | W4/A4 |
| --- | --- | --- | --- | --- |
| ResNet-18 | PACT | 64.4 | 68.1 | 69.2 |
| | DoReFa-Net | 64.7 | 67.5 | 68.1 |
| | LSQ | 67.6 | 70.2 | 71.1 |
| | N2UQ | 69.4 Model-Res18-2bit | 71.9 Model-Res18-3bit | 72.9 Model-Res18-4bit |
| | N2UQ * | 69.7 Model-Res18-2bit | 72.1 Model-Res18-3bit | 73.1 Model-Res18-4bit |
| ResNet-34 | LSQ | 71.6 | 73.4 | 74.1 |
| | N2UQ | 73.3 Model-Res34-2bit | 75.2 Model-Res34-3bit | 76.0 Model-Res34-4bit |
| | N2UQ * | 73.4 Model-Res34-2bit | 75.3 Model-Res34-3bit | 76.1 Model-Res34-4bit |
| ResNet-50 | N2UQ | 75.8 Model-Res50-2bit | 77.5 Model-Res50-3bit | 78.0 Model-Res50-4bit |
| | N2UQ * | 76.4 Model-Res50-2bit | 77.6 Model-Res50-3bit | 78.0 Model-Res50-4bit |
Note that N2UQ without * quantizes all convolutional layers except the first input convolutional layer, while N2UQ with * additionally keeps the three downsampling layers unquantized.
W2/A2, W3/A3, W4/A4 denote the cases where the weights and activations are both quantized to 2 bits, 3 bits, and 4 bits, respectively.
| Network | Methods | W4/A4 |
| --- | --- | --- |
| MobileNet-V2 | N2UQ | 72.1 Model-MBV2-4bit |
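To sanity-check a downloaded checkpoint before training or evaluation, a hedged sketch (the path and the "state_dict" wrapper key are assumptions about the release format):

```python
import torch

# Placeholder path -- wherever you saved one of the released models above.
ckpt_path = "./models/checkpoint.pth.tar"
state = torch.load(ckpt_path, map_location="cpu")

# Released PyTorch checkpoints are usually either a bare state_dict or a
# dict that wraps one; the "state_dict" key here is an assumption.
state_dict = state.get("state_dict", state)
print(f"{len(state_dict)} tensors; first keys: {sorted(state_dict)[:5]}")
```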
Zechun Liu, HKUST (zliubq at connect.ust.hk)