This project is the implementation of the DeepShift: Towards Multiplication-Less Neural Networks paper, which aims to replace multiplications in neural networks with bitwise shifts (and sign changes).
[Paper] - [arXiv] - [Video] - [Presentation]
This research project was done at Huawei Technologies.
If you find this code useful, please cite our paper:
@InProceedings{Elhoushi_2021_CVPR,
author = {Elhoushi, Mostafa and Chen, Zihao and Shafiq, Farhan and Tian, Ye Henry and Li, Joey Yiwei},
title = {DeepShift: Towards Multiplication-Less Neural Networks},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
month = {June},
year = {2021},
pages = {2359-2368}
}
- Overview
- Important Notes
- Getting Started
- Running the Bitwise Shift CUDA & CPU Kernels
- Results
- Code WalkThrough
The main idea of DeepShift is to test the ability to train and infer using bitwise shifts.
We present two approaches:
- DeepShift-Q: the parameters are floating point weights, just like regular networks, but the weights are rounded to powers of 2 during the forward and backward passes (a minimal sketch of this rounding appears below)
- DeepShift-PS: the parameters are signs and shift values
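A minimal sketch of the DeepShift-Q rounding step (illustrative only, assuming the sign/power-of-two decomposition described in the paper; this is not the repo's exact code):

```python
import torch

def round_to_power_of_two(w: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # DeepShift-Q idea: replace each weight with sign(w) * 2^p, where
    # p = round(log2(|w|)); eps guards against log2(0).
    # Multiplying by 2^p is then equivalent to a bitwise shift by p.
    sign = torch.sign(w)
    p = torch.round(torch.log2(torch.abs(w) + eps))
    return sign * torch.pow(2.0, p)
```

In DeepShift-PS, the signs and shift values `p` are themselves the trainable parameters, rather than being derived from floating point weights.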
- To train from scratch, the learning rate `--lr` option should be set to 0.01. To train from a pre-trained model, it should be set to 0.001 and `--lr-step-size` should be set to 5 (see the example command below).
- To use DeepShift-PS, the `--optimizer` option must be set to `radam` in order to obtain good results.
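For example, fine-tuning a DeepShift-PS model on Imagenet from a pre-trained model would combine the options above as follows (an illustrative command; note that the reproduction commands later in this README use the short flag spellings `--opt` and `--lr-step`):
python imagenet.py --arch resnet18 --pretrained True --shift-depth 1000 --shift-type PS --optimizer radam --lr 0.001 --lr-step-size 5 <path_to_imagenet_dataset>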
- Clone the repo:
git clone https://github.com/mostafaelhoushi/DeepShift.git
- Change directory:
cd DeepShift
- Create virtual environment:
virtualenv venv --prompt="(DeepShift) " --python=/usr/bin/python3.6
- (Needs to be done every time you run code) Source the environment:
source venv/bin/activate
- Install the required packages and build the spfpm package (used for fixed point arithmetic):
pip install -r requirements.txt
- cd into the `pytorch` directory:
cd pytorch
- To list all the available options, you may run:
`python <dataset>.py --help`
where `<dataset>` can be either `mnist`, `cifar10`, or `imagenet`.
When you run any training or evaluation script, the model binary file as well as the training log will be saved in `./models/<dataset>/<arch>/<shift-type><shift-depth><weight-bit-width><desc>`, where:
- `<shift-type>` is either `shift_q` if you pass `--shift-type Q`, `shift_ps` if you pass `--shift-type PS`, or `shift_0` if you are running the default FP32 baseline (e.g., the baseline runs below save under `./models/<dataset>/<arch>/shift_0/`)
- `<shift-depth>` is the number of layers, counted from the start of the model, that have been converted to DeepShift. This is determined by the `--shift-depth` argument
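As a quick sanity check, a saved binary can be loaded back with plain PyTorch (an illustrative snippet; the path follows the naming scheme above and assumes the baseline MNIST command shown later has been run):

```python
import torch

# Load the baseline MNIST weights saved by the training script.
# Whether the file holds a bare state_dict or a checkpoint dict
# depends on the script, so inspect what comes back.
state = torch.load("./models/mnist/simple_linear/shift_0/weights.pth",
                   map_location="cpu")
print(type(state))
```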
- Now you can run the different scripts with different options, e.g.,
a) Train a DeepShift simple fully-connected model on the MNIST dataset, using the PS approach: `python mnist.py --shift-depth 3 --shift-type PS --optimizer radam`
b) Train a DeepShift simple convolutional model on the MNIST dataset, using the Q approach: `python mnist.py --type conv --shift-depth 3 --shift-type Q`
c) Train a DeepShift ResNet20 on the CIFAR10 dataset from scratch: `python cifar10.py --arch resnet20 --pretrained False --shift-depth 1000 --shift-type Q`
d) Train a DeepShift ResNet18 model on the Imagenet dataset using converted pretrained weights for 5 epochs with learning rate 0.001: `python imagenet.py <path to imagenet dataset> --arch resnet18 --pretrained True --shift-depth 1000 --shift-type Q --epochs 5 --lr 0.001`
e) Train a DeepShift ResNet18 model on the Imagenet dataset from scratch with an initial learning rate of 0.01: `python imagenet.py <path to imagenet dataset> --arch resnet18 --pretrained False --shift-depth 1000 --shift-type PS --optimizer radam --lr 0.01`
f) Train a DeepShift ResNet18 model on the CIFAR10 dataset from scratch with 8-bit fixed point activations (3 bits for the integer part and 5 bits for the fraction; see the sketch after this list): `python cifar10.py --arch resnet18 --pretrained False --shift-depth 1000 --shift-type PS --optimizer radam --lr 0.01 -ab 3 5`
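The `-ab` option above selects the fixed-point format of the activations. A minimal sketch of what 8-bit fixed-point quantization with 3 integer and 5 fraction bits means, assuming symmetric clamping (this function is illustrative; the repo builds its fixed-point support on the spfpm package):

```python
import torch

def quantize_fixed_point(x: torch.Tensor, int_bits: int = 3, frac_bits: int = 5) -> torch.Tensor:
    # Fixed-point values are integers scaled by 2^-frac_bits, clamped
    # to the range representable with int_bits bits of integer part.
    scale = 2.0 ** frac_bits
    lo = -(2.0 ** int_bits)
    hi = 2.0 ** int_bits - 1.0 / scale
    return torch.clamp(torch.round(x * scale) / scale, lo, hi)
```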
- cd into the `DeepShift/pytorch` directory:
cd DeepShift/pytorch
- Run the installation script to install our CPU and CUDA kernels that perform matrix multiplication and convolution using bit-wise shifts:
sh install_kernels.sh
- Now you can run a model with actual bitwise shift kernels in CUDA using the `--use-kernel True` option. Remember that the kernel only works for inference, not training, so you need to add the `-e` option as well:
python imagenet.py --arch resnet18 -e --shift-depth 1000 --pretrained True --use-kernel True
- To compare the latency with a naive regular convolution kernel that does not include cuDNN's other optimizations:
python imagenet.py --arch resnet18 -e --pretrained True --use-kernel True
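Conceptually, the shift kernels replace each fixed-point multiply with a bitwise shift and a sign flip, along these lines (a minimal Python sketch of the idea, not the CUDA implementation; names are illustrative):

```python
def shift_mul(a_fixed: int, sign: int, p: int) -> int:
    # Multiplying a fixed-point activation by a weight w = sign * 2^p
    # reduces to a left shift (p >= 0) or right shift (p < 0),
    # followed by a sign change.
    shifted = a_fixed << p if p >= 0 else a_fixed >> -p
    return shifted if sign >= 0 else -shifted
```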
Accuracies on MNIST when training from scratch:

Model | Original | DeepShift-Q | DeepShift-PS |
---|---|---|---|
Simple FC Model | 96.92% [1] | 97.03% [2] | 98.26% [3] |
Simple Conv Model | 98.75% [4] | 98.81% [5] | 99.12% [6] |
Commands to reproduce results:
1. `python mnist.py`
2. `python mnist.py --shift-depth 1000 --shift-type Q`
3. `python mnist.py --shift-depth 1000 --shift-type PS --opt radam`
4. `python mnist.py --type conv`
5. `python mnist.py --type conv --shift-depth 1000 --shift-type Q`
6. `python mnist.py --type conv --shift-depth 1000 --shift-type PS --opt radam`
Accuracies on MNIST when training from a pre-trained baseline:

Model | Original | DeepShift-Q | DeepShift-PS |
---|---|---|---|
Simple FC Model | 96.92% [1] | 97.85% [7] | 98.26% [8] |
Simple Conv Model | 98.75% [4] | 99.15% [9] | 99.16% [10] |
Commands to reproduce results (assumes you have run commands [1] and [4] in order to have the baseline pretrained weights):
7. `python mnist.py --weights ./models/mnist/simple_linear/shift_0/weights.pth --shift-depth 1000 --shift-type Q --desc from_pretrained`
8. `python mnist.py --weights ./models/mnist/simple_linear/shift_0/weights.pth --shift-depth 1000 --shift-type PS --opt radam --desc from_pretrained`
9. `python mnist.py --type conv --weights ./models/mnist/simple_conv/shift_0/weights.pth --shift-depth 1000 --shift-type Q --desc from_pretrained`
10. `python mnist.py --type conv --weights ./models/mnist/simple_conv/shift_0/weights.pth --shift-depth 1000 --shift-type PS --opt radam --desc from_pretrained`
Accuracies on CIFAR10 when training from scratch:

Model | Original [11] | DeepShift-Q [12] | DeepShift-PS [13] |
---|---|---|---|
resnet18 | 94.45% | 94.42% | 93.20% |
mobilenetv2 | 93.57% | 93.63% | 92.64% |
resnet20 | 91.79% | 89.85% | 88.84% |
resnet32 | 92.39% | 91.13% | 89.97% |
resnet44 | 92.84% | 91.29% | 90.92% |
resnet56 | 93.46% | 91.52% | 91.11% |
Commands to reproduce results:
11. `python cifar10.py --arch <Model>`
12. `python cifar10.py --arch <Model> --shift-depth 1000 --shift-type Q`
13. `python cifar10.py --arch <Model> --shift-depth 1000 --shift-type PS --opt radam`
Accuracies on CIFAR10 when training from a pre-trained baseline:

Model | Original [11] | DeepShift-Q [14] | DeepShift-PS [15] |
---|---|---|---|
resnet18 | 94.45% | 94.25% | 94.12% |
mobilenetv2 | 93.57% | 93.04% | 92.78% |
Commands to reproduce results (assumes you have run command [11] for the corresponding architecture in order to have the baseline pretrained weights):
14. `python cifar10.py --arch <Model> --weights ./models/cifar10/<Model>/shift_0/checkpoint.pth.tar --shift-depth 1000 --shift-type Q --desc from_pretrained --lr 1e-3 --lr-step 5 --epochs 15`
15. `python cifar10.py --arch <Model> --weights ./models/cifar10/<Model>/shift_0/checkpoint.pth.tar --shift-depth 1000 --shift-type PS --opt radam --desc from_pretrained --lr 1e-3 --lr-step 5 --epochs 15`
Accuracies on CIFAR10 with reduced weight bit-widths:

Model | Type | Weight Bits | Train from Scratch | Train from Pre-Trained |
---|---|---|---|---|
resnet18 | Original | 32 | 94.45% [11] | - |
resnet18 | DeepShift-PS | 5 | 93.20% [13] | 94.12% [15] |
resnet18 | DeepShift-PS | 4 | 94.12% [16] | 94.13% [17] |
resnet18 | DeepShift-PS | 3 | 92.85% [18] | 91.16% [19] |
resnet18 | DeepShift-PS | 2 | 92.80% [20] | 90.68% [21] |
Commands to reproduce results (assumes you have run command [11] for the corresponding architecture in order to have the baseline pretrained weights):
16. `python cifar10.py --arch <Model> --shift-depth 1000 --shift-type PS -wb 4 --opt radam`
17. `python cifar10.py --arch <Model> --weights ./models/cifar10/<Model>/shift_0/checkpoint.pth.tar --shift-depth 1000 --shift-type PS -wb 4 --opt radam --desc from_pretrained --lr 1e-3 --lr-step 5 --epochs 15`
18. `python cifar10.py --arch <Model> --shift-depth 1000 --shift-type PS -wb 3 --opt radam`
19. `python cifar10.py --arch <Model> --weights ./models/cifar10/<Model>/shift_0/checkpoint.pth.tar --shift-depth 1000 --shift-type PS -wb 3 --opt radam --desc from_pretrained --lr 1e-3 --lr-step 5 --epochs 15`
20. `python cifar10.py --arch <Model> --shift-depth 1000 --shift-type PS -wb 2 --opt radam`
21. `python cifar10.py --arch <Model> --weights ./models/cifar10/<Model>/shift_0/checkpoint.pth.tar --shift-depth 1000 --shift-type PS -wb 2 --opt radam --desc from_pretrained --lr 1e-3 --lr-step 5 --epochs 15`
Accuracies shown are Top1 / Top5.
Accuracies on Imagenet when training from scratch:

Model | Original [22] | DeepShift-Q [23] | DeepShift-PS [24] |
---|---|---|---|
resnet18 | 69.76% / 89.08% | 65.32% / 86.29% | 65.34% / 86.05% |
resnet50 | 76.13% / 92.86% | 70.70% / 90.20% | 71.90% / 90.20% |
vgg16 | 71.59% / 90.38% | 70.87% / 90.09% | TBD |
Commands to reproduce results:
22. a) To evaluate PyTorch pretrained models: `python imagenet.py --arch <Model> --pretrained True -e <path_to_imagenet_dataset>`, or b) to train from scratch: `python imagenet.py --arch <Model> --pretrained False <path_to_imagenet_dataset>`
23. `python imagenet.py --arch <Model> --pretrained False --shift-depth 1000 --shift-type Q --desc from_scratch --lr 0.01 <path_to_imagenet_dataset>`
24. `python imagenet.py --arch <Model> --pretrained False --shift-depth 1000 --shift-type PS --desc from_scratch --lr 0.01 --opt radam <path_to_imagenet_dataset>`
Accuracies on Imagenet when training from a pre-trained model:

Model | Original [22] | DeepShift-Q [25] | DeepShift-PS [26] |
---|---|---|---|
resnet18 | 69.76% / 89.08% | 69.56% / 89.17% | 69.27% / 89.00% |
resnet50 | 76.13% / 92.86% | 76.33% / 93.05% | 75.93% / 92.90% |
googlenet | 69.78% / 89.53% | 70.73% / 90.17% | 69.87% / 89.62% |
vgg16 | 71.59% / 90.38% | 71.56% / 90.48% | 71.39% / 90.33% |
alexnet | 56.52% / 79.07% | 55.81% / 78.79% | 55.90% / 78.73% |
densenet121 | 74.43% / 91.97% | 74.52% / 92.06% | TBD |
Commands to reproduce results:
25. `python imagenet.py --arch <Model> --pretrained True --shift-depth 1000 --shift-type Q --desc from_pretrained --lr 1e-3 --lr-step 5 --epochs 15 <path_to_imagenet_dataset>`
26. `python imagenet.py --arch <Model> --pretrained True --shift-depth 1000 --shift-type PS --desc from_pretrained --lr 1e-3 --lr-step 5 --epochs 15 --opt radam <path_to_imagenet_dataset>`
Accuracies on Imagenet with reduced weight bit-widths:

Model | Type | Weight Bits | Train from Scratch | Train from Pre-Trained |
---|---|---|---|---|
resnet18 | Original | 32 | 69.76% / 89.08% | - |
resnet18 | DeepShift-Q | 5 | 65.32% / 86.29% | 69.56% / 89.17% |
resnet18 | DeepShift-PS | 5 | 65.34% / 86.05% | 69.27% / 89.00% |
resnet18 | DeepShift-Q | 4 | TBD | 69.56% / 89.14% |
resnet18 | DeepShift-PS | 4 | 67.07% / 87.36% | 69.02% / 88.73% |
resnet18 | DeepShift-PS | 3 | 63.11% / 84.45% | TBD |
resnet18 | DeepShift-PS | 2 | 60.80% / 83.01% | TBD |
- CIFAR10:
  - DeepShift-PS, 5-bit weights, trained from scratch: checkpoint.pth.tar
- `pytorch`: directory containing the implementation, tests, and saved models using PyTorch
  - `deepshift`: directory containing the PyTorch models as well as the CUDA and CPU kernels of the `LinearShift` and `Conv2dShift` ops
  - `unoptimized`: directory containing the PyTorch models as well as the CUDA and CPU kernels of the naive implementations of the `Linear` and `Conv2d` ops
  - `mnist.py`: example script to train and infer on the MNIST dataset using simple models in both their original form and their DeepShift versions
  - `cifar10.py`: example script to train and infer on the CIFAR10 dataset using various models in both their original form and their DeepShift versions
  - `imagenet.py`: example script to train and infer on the Imagenet dataset using various models in both their original form and their DeepShift versions
  - `optim`: directory containing the definitions of the RAdam and Ranger optimizers. The RAdam optimizer is crucial for DeepShift-PS to obtain the accuracies shown here