PyTorch FakeQuantizer observer with Learned Step Size Quantization (LSQ https://arxiv.org/abs/1902.08153 and LSQ+ https://arxiv.org/abs/2004.09576).
It can be used as a drop-in replacement for the FakeQuantize observer in PyTorch and is fully integrated into the torch.quantization environment.
(I am not the author of any of the mentioned articles; I implemented this for my own purposes.)
The PyTorch quantization scheme is `x_q = round(clamp(x/scale + zero_point, quant_min, quant_max))`, where `zero_point` is an integer of the same type as `x_q`.
The article, however, uses a different parametrization: `x_q = round(clamp((x - shift)/scale, quant_min, quant_max))`, where `shift` is a float. We adapt this scheme via an implicit conversion of `shift` to `zero_point`.
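For illustration, the implicit conversion from the article's `(scale, shift)` parametrization to PyTorch's integer `zero_point` could look like the sketch below (the rounding step and the helper name are assumptions, not the module's exact code):

```python
import torch

def shift_to_zero_point(scale: float, shift: float, dtype=torch.quint8) -> int:
    """Map the article's float `shift` onto PyTorch's integer `zero_point`."""
    type_min, type_max = (0, 255) if dtype == torch.quint8 else (-128, 127)
    # zero_point = clamp(-shift/scale, type_min, type_max), rounded to get an integer
    zero_point = round(-shift / scale)
    return int(max(type_min, min(type_max, zero_point)))
```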
Following the article, LSQ+ emulates quantization and dequantization of the input, treating `scale` and `shift` as learnable parameters (quantize -> dequantize):
1) Emulate the zero-point behaviour: `zero_point = clamp(-shift/scale, type_min, type_max)`, where `type_min` (`type_max`) is the min (max) of the quantized type (i.e. int8/uint8).
2) Quantize: `x_q = round(clamp(x/scale + zero_point, quant_min, quant_max))`
3) Dequantize: `x_r = x_q * scale + shift`
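A minimal sketch of this fake-quantization forward pass for the per-tensor case (the function and argument names are illustrative, not the module's internals):

```python
import torch

def lsq_fake_quant(x, scale, shift, quant_min, quant_max, type_min=-128, type_max=127):
    """Emulate quantize -> dequantize with a learnable scale/shift pair."""
    # 1) emulate the integer zero point from the float shift
    zero_point = torch.clamp(-shift / scale, type_min, type_max)
    # 2) quantize
    x_q = torch.round(torch.clamp(x / scale + zero_point, quant_min, quant_max))
    # 3) dequantize
    return x_q * scale + shift
```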
Given
`zero_point = clamp(-shift/scale, type_min, type_max)` and
`x_q = round(clamp(x/scale + zero_point, quant_min, quant_max))`,
the straight-through gradients of the dequantized output `x_r` are:

w.r.t. input:
- `1`, if `quant_min < x_q < quant_max`
- `0`, otherwise

w.r.t. scale:
- `(x_r - x)/scale`, if `quant_min < x_q < quant_max`
- `quant_min - zero_point`, if `x_q <= quant_min`
- `quant_max - zero_point`, if `x_q >= quant_max`

w.r.t. shift:
- `0`, if `quant_min < x_q < quant_max`
- `1`, otherwise
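These gradients could be implemented with a custom autograd function roughly as sketched below (per-tensor case; the class name and the reduction of the scale/shift gradients are assumptions, not the package's actual code):

```python
import torch

class _LSQQuantDequant(torch.autograd.Function):
    """Illustrative straight-through implementation of the gradients above."""

    @staticmethod
    def forward(ctx, x, scale, shift, quant_min, quant_max, type_min, type_max):
        zero_point = torch.clamp(-shift / scale, type_min, type_max)
        x_q = torch.round(torch.clamp(x / scale + zero_point, quant_min, quant_max))
        x_r = x_q * scale + shift
        ctx.save_for_backward(x, x_q, x_r, scale, shift, zero_point)
        ctx.quant_min, ctx.quant_max = quant_min, quant_max
        return x_r

    @staticmethod
    def backward(ctx, grad_out):
        x, x_q, x_r, scale, shift, zero_point = ctx.saved_tensors
        inside = (x_q > ctx.quant_min) & (x_q < ctx.quant_max)
        below = x_q <= ctx.quant_min

        # d x_r / d x: 1 inside the quantization range, 0 outside
        grad_x = grad_out * inside.to(grad_out.dtype)
        # d x_r / d scale: (x_r - x)/scale inside, quant_min/max - zero_point outside
        grad_scale = grad_out * torch.where(
            inside, (x_r - x) / scale,
            torch.where(below, ctx.quant_min - zero_point, ctx.quant_max - zero_point))
        # d x_r / d shift: 0 inside, 1 outside
        grad_shift = grad_out * (~inside).to(grad_out.dtype)

        # scale and shift are scalar tensors in the per-tensor case: reduce their grads
        return (grad_x,
                grad_scale.sum().reshape_as(scale),
                grad_shift.sum().reshape_as(shift),
                None, None, None, None)
```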
- Following the article, we support scaling of the gradients of `scale` and `shift` by `1/sqrt(n_elements(x)*quant_max)` (see https://arxiv.org/abs/1902.08153 for details; a sketch is given after this list).
- PyTorch quantization assumes `qint8` for weight quantization and `quint8` for activations. Based on that, the `qint8` dtype is treated as an observer for weights and `quint8` as an observer for activations.
- Weight case initialization: always happens statically as `s_init = max(|µ − 3σ|, |µ + 3σ|) / 2^(b−1)`, where µ and σ are the mean and std of the weight tensor, respectively (a sketch is given after this list).
- Activation case initialization:
  a) In the original article the authors propose to initialize `scale` and `shift` for the affine case by minimizing `||x_r - x||_F^2`. Instead of using this as an additional loss, we incorporate it natively inside the code, since this functionality is supposed to work inside a PyTorch Observer class.
  b) Alternatively, `scale` and `shift` can be initialized via an observer. We strongly suggest using the MovingAverage(PerChannel)MinMaxObserver to track running min and max values.
  To choose between a) and b), use the `init_mode` argument during initialization. The number of batches during which initialization happens is controlled by `init_batches`.
- LSQ can work with any bit width; however, for PyTorch compatibility, we limit the bit width to at most 8 bits.
- By default, `quant_min` and `quant_max` are initialized for 7 bits, to avoid overflow during quantized inference in PyTorch:
  - qint8: `quant_min, quant_max = -64, 63`
  - quint8: `quant_min, quant_max = 0, 127`
  To change this behaviour, see the `avoid_torch_overflow` argument.
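The gradient scaling and the weight scale initialization mentioned in the list above could be sketched as follows (the helper names are illustrative assumptions; the actual module code may differ):

```python
import math
import torch

def grad_scaled(param: torch.Tensor, x: torch.Tensor, quant_max: int) -> torch.Tensor:
    """Scale the gradient of `param` by 1/sqrt(n_elements(x) * quant_max)
    while leaving its forward value unchanged (the usual LSQ detach trick)."""
    g = 1.0 / math.sqrt(x.numel() * quant_max)
    return (param * g) + (param - param * g).detach()

def init_weight_scale(w: torch.Tensor, bits: int = 8) -> torch.Tensor:
    """s_init = max(|mu - 3*sigma|, |mu + 3*sigma|) / 2**(bits - 1)."""
    mu, sigma = w.mean(), w.std()
    return torch.max((mu - 3 * sigma).abs(), (mu + 3 * sigma).abs()) / 2 ** (bits - 1)
```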
C++17 must be supported by your compiler (due to `constexpr` usage in the code)!
PyTorch >= 1.7.0;
- Clone this repo and `cd LSQFakeQuantize-PyTorch`.
- (optional) If you compile with CUDA, please set the `CUDA_HOME` env variable to the path of your nvcc!
- Important! There is a bug in PyTorch which can lead to a crash during the build under CUDA. The bug was fixed in PyTorch 1.8, but it is easy to fix in previous versions: run `python torch_patch.py` (it will also run automatically during the install step below). The script changes a few lines of code in a single C++ header file, directly in the python dist-packages folder, so please be sure that you have the rights to change files inside this folder. You only need to do this once per python environment (PyTorch package). (If something goes wrong, please inspect `torch_patch.py` first (it is very simple) and try to reproduce the patch manually.)
- Run `python setup.py install` or `python setup.py bdist_wheel` to install/build the package.
We strongly recommend studying the PyTorch quantization manuals (how to prepare the model and how the observer concept works) before use, because this LSQFakeQuantize implementation is part of that workflow!
After preparing the model for quantization, you MUST make a test forward pass (at least with a random tensor) BEFORE adding the model parameters to the optimizer (see the sketch after the Qconfig example below).
Code example (how to set up the QConfig):

```python
import torch
from torch.quantization import QConfig
from torchlsq import LSQFakeQuantizer

act_ext_obs = torch.quantization.MovingAverageMinMaxObserver
wei_ext_obs = torch.quantization.MovingAveragePerChannelMinMaxObserver

init_mode = 'observer'     # initialize the params via the observers above, based on the tracked statistics
# init_mode = 'learnable'  # initialization as in the LSQ+ article

if init_mode != 'observer':
    act_ext_obs = wei_ext_obs = None

lsq_qconfig = QConfig(
    activation=LSQFakeQuantizer.with_args(observer=act_ext_obs,
                                          otype='activation',
                                          init_mode=init_mode,
                                          avoid_torch_overflow=True),   # 7 bits
    weight=LSQFakeQuantizer.with_args(observer=wei_ext_obs,
                                      otype='weight',
                                      dtype=torch.qint8,
                                      qscheme=torch.per_channel_symmetric,
                                      init_mode=init_mode,
                                      avoid_torch_overflow=False),      # 8 bits
)
```
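A minimal sketch of how the qconfig might then be applied; the model and training details here are illustrative assumptions, not part of the package:

```python
import torch
import torch.nn as nn

# hypothetical model, just for illustration
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
model.train()
model.qconfig = lsq_qconfig
torch.quantization.prepare_qat(model, inplace=True)

# make a test forward pass BEFORE adding the model parameters
# to the optimizer, as required above
model(torch.randn(1, 3, 32, 32))

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
```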
Additional options for the layer:
observer(torch.quantization.ObserverBase) - Module for observing statistics on input tensors and calculating scale and zero-point.
Used only if `init_mode == "observer"`; otherwise it has no effect.
otype(str) - Type specification for which LSQ will be used. Can be one of (`weight`, `activation`).
dtype(torch.dtype) - quint8 or qint8. Default: torch.quint8.
qscheme(torch.qscheme) - per_tensor_affine, per_tensor_symmetric, per_channel_affine or
per_channel_symmetric schemes. Default: torch.per_tensor_affine.
quant_min(int) - lower bound of quantized range. Default: None.
quant_max(int) - upper bound of quantized range. Default: None.
init_scale(float) - init value for scale; has effect for the "activation" otype only. Default: 1.
init_shift(float) - init value for shift; has effect for "affine" qschemes. Default: 0.
ch_axis(int) - dimension for the per-channel quantization scheme. Default: None (1 for activations, 0 for weights).
learn_params(bool) - enable learning via LSQ; otherwise the module works like a plain FakeQuantize. Default: True.
init_batches(int) - Number of batches during which parameter initialization will happen;
has effect only for the "activation" otype. Default: 1000.
init_mode(str) - Controls the initialization mode for the parameters: via backprop ('learnable') or via 'observer'. Default: 'observer'.
use_grad_scaling(bool) - Apply gradient scaling to the scale and zero_point gradients. Default: True.
grad_scaler(float) - additional gradient scaler; helps to compensate for the learning rate. Default: 1.
avoid_torch_overflow(bool) - Reduce 8-bit to 7-bit during the default initialization, to avoid overflows in quantized operations. Default: True.
debug_mode(bool) - If enabled, just pass the input through during forward. Default: False.
1.0 - First release; had a lot of bugs with zero_point and a poor class design.
2.0 - Full refactor of the python part; most of the known bugs fixed.
2.1 - Replaced scale gradient; minor enhancements and bugfixes
- Add unit tests