wage_qtorch.py
import torch
import torch.nn as nn
from torch.nn import Module
import torch.nn.functional as F
from torch.autograd import Function
from qtorch.quant import fixed_point_quantize, quantizer
from qtorch import FixedPoint


def shift(x):
    # Scale x by the smallest power of two bounding its largest magnitude,
    # so all entries fall inside [-1, 1].
    max_entry = x.abs().max()
    return x / 2.0 ** torch.ceil(torch.log2(max_entry))


def C(x, bits):
    # Clamp x to the representable range of a `bits`-bit fixed-point number
    # with (bits - 1) fractional bits; degenerate widths clamp to [-1, 1].
    if bits > 15 or bits == 1:
        delta = 0
    else:
        delta = 1.0 / (2.0 ** (bits - 1))
    upper = 1 - delta
    lower = -1 + delta
    return torch.clamp(x, lower, upper)


def QW(x, bits, scale=1.0, mode="nearest"):
    # Quantize weights to `bits`-bit fixed point with symmetric clamping.
    y = fixed_point_quantize(
        x, wl=bits, fl=bits - 1, clamp=True, symmetric=True, rounding=mode
    )
    # per-layer scaling
    if scale > 1.8:
        y /= scale
    return y


def QG(x, bits_G, bits_R, lr, mode="nearest"):
    # Quantize gradients: shift into [-1, 1], scale by lr / 2**(bits_G - 1),
    # then round to `bits_G`-bit fixed point without clamping.
    # (`bits_R` is unused in this implementation.)
    x = shift(x)
    lr = lr / (2.0 ** (bits_G - 1))
    norm = fixed_point_quantize(
        lr * x, wl=bits_G, fl=bits_G - 1, clamp=False, symmetric=True, rounding=mode
    )
    return norm


class WAGEQuantizer(Module):
    # Quantizes activations on the forward pass (bits_A) and error signals on
    # the backward pass (bits_E); a bit width of -1 disables that direction.
    def __init__(self, bits_A, bits_E, A_mode="nearest", E_mode="nearest"):
        super(WAGEQuantizer, self).__init__()
        self.activate_number = (
            FixedPoint(wl=bits_A, fl=bits_A - 1, clamp=True, symmetric=True)
            if bits_A != -1
            else None
        )
        self.error_number = (
            FixedPoint(wl=bits_E, fl=bits_E - 1, clamp=True, symmetric=True)
            if bits_E != -1
            else None
        )
        self.quantizer = quantizer(
            forward_number=self.activate_number,
            forward_rounding=A_mode,
            backward_number=self.error_number,
            backward_rounding=E_mode,
            clamping_grad_zero=True,
            backward_hooks=[shift],
        )

    def forward(self, x):
        return self.quantizer(x)
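

# --- Usage sketch (illustrative, not part of the upstream file) ---
# A minimal example of how these helpers might be exercised. The bit widths
# below (2-bit weights, 8-bit activations/gradients/errors) are assumptions
# chosen for illustration, not values taken from this file.
if __name__ == "__main__":
    torch.manual_seed(0)

    w = torch.randn(4, 4)
    g = torch.randn(4, 4)

    w_q = QW(w, bits=2)                       # weight quantization
    g_q = QG(g, bits_G=8, bits_R=16, lr=1.0)  # gradient quantization

    quant = WAGEQuantizer(bits_A=8, bits_E=8)
    a = torch.randn(4, 4, requires_grad=True)
    out = quant(a)        # forward pass quantizes activations
    out.sum().backward()  # backward pass quantizes the error signal

    print("quantized weights:\n", w_q)
    print("quantized gradients:\n", g_q)
    print("quantized activations:\n", out)
    print("error gradient:\n", a.grad)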