-
Notifications
You must be signed in to change notification settings - Fork 7
/
gptq_utils.py
310 lines (257 loc) · 10.1 KB
/
gptq_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
import math
import time
import tqdm
import torch
import torch.nn as nn
import logging
from flatquant.utils import cleanup_memory
from flatquant.quant_utils import WeightQuantizer
# Disable TF32 for CUDA matmuls and cuDNN so those ops run in full FP32
# precision. Presumably this keeps the Hessian accumulation and the
# Cholesky-based inversion used by GPTQ numerically accurate -- confirm.
# NOTE: these are process-wide settings applied as a side effect of importing
# this module.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
def find_qlayers(module, layers=None, name=''):
    """Recursively collect submodules whose exact type is in ``layers``.

    Args:
        module: Root ``nn.Module`` to search.
        layers: Collection of module classes to match. Matching uses the exact
            type (``type(m) in layers``), so subclasses are NOT matched.
            Defaults to ``[torch.nn.Linear]``.
        name: Dotted name prefix for ``module``; ``''`` for the root.

    Returns:
        dict mapping dotted submodule names (built from ``named_children``) to
        the matching modules. If ``module`` itself matches, the result is
        ``{name: module}`` and its children are not searched.
    """
    if layers is None:
        # Fresh list per call instead of a shared mutable default argument.
        layers = [torch.nn.Linear]
    if type(module) in layers:
        return {name: module}
    res = {}
    for child_name, child in module.named_children():
        full_name = f'{name}.{child_name}' if name else child_name
        res.update(find_qlayers(child, layers=layers, name=full_name))
    return res
class GPTQ:
    """GPTQ quantization state for one layer whose ``weight`` is (rows, columns).

    Usage: construct with the target layer (e.g. ``nn.Linear``), assign
    ``self.quantizer`` (an object providing ``ready()``, ``find_params(W)`` and
    ``quantize(w)``; it is NOT created here), feed calibration inputs through
    :meth:`add_batch`, then call :meth:`fasterquant` to overwrite
    ``layer.weight`` in place with quantized values.
    """

    def __init__(self, layer):
        self.layer = layer
        self.dev = self.layer.weight.device
        # Only the shape is needed; no need to clone the full weight tensor.
        weight_shape = layer.weight.shape
        self.rows = weight_shape[0]
        self.columns = weight_shape[1]
        # Running Hessian estimate (columns x columns), float32 on the layer's device.
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0

    def add_batch(self, inp, out):
        """Fold layer input ``inp`` into the running Hessian estimate.

        After each call ``self.H == (2 / nsamples) * X @ X.T`` where ``X`` stacks
        (as columns) every input row seen so far. ``nsamples`` counts entries
        along the leading dimension (batch items), not individual rows/tokens.

        Args:
            inp: Layer input. A 2-D tensor is treated as one sample; a 3-D
                tensor is flattened to ``(-1, features)`` rows.
            out: Layer output; unused, accepted for forward-hook compatibility.
        """
        if inp.dim() == 2:
            inp = inp.unsqueeze(0)
        batch = inp.shape[0]
        if inp.dim() == 3:
            inp = inp.reshape((-1, inp.shape[-1]))
        inp = inp.t()
        # Rescale the existing average before adding the new contribution.
        self.H *= self.nsamples / (self.nsamples + batch)
        self.nsamples += batch
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())

    def fasterquant(
        self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False
    ):
        """Quantize ``self.layer.weight`` in place using GPTQ.

        Consumes the Hessian accumulated by :meth:`add_batch` (``self.H`` is
        deleted). Columns are quantized one at a time; each column's scaled
        quantization error is propagated to the not-yet-quantized columns using
        the upper Cholesky factor of the damped inverse Hessian -- immediately
        within the current block, and to later blocks once per block.

        Args:
            blocksize: Number of columns per lazy-update block.
            percdamp: Fraction of the mean Hessian diagonal added to the
                diagonal before inversion.
            groupsize: Columns per quantization group. ``-1`` disables grouping:
                the quantizer's parameters come from ``find_params`` on the whole
                weight, or its existing ones if it is already ``ready()``.
            actorder: Quantize columns in descending order of the Hessian
                diagonal; the result is permuted back to the original order.
            static_groups: With grouping, compute every group's quantizer up
                front (from the dead-column-zeroed weight) instead of refitting
                at the start of each group during the sweep.

        Raises:
            ValueError: If the quantized weight contains NaN.
        """
        W = self.layer.weight.data.clone().float()

        if not self.quantizer.ready():
            self.quantizer.find_params(W)

        H = self.H
        del self.H
        # Inputs that were always zero: pin the Hessian diagonal to 1 so H stays
        # invertible, and zero the corresponding weights.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        if static_groups:
            import copy
            groups = []
            for i in range(0, self.columns, groupsize):
                quantizer = copy.deepcopy(self.quantizer)
                quantizer.find_params(W[:, i:(i + groupsize)])
                groups.append(quantizer)

        if actorder:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]
            invperm = torch.argsort(perm)

        Q = torch.zeros_like(W)

        # Dampen the diagonal, then take the upper Cholesky factor of H^-1.
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        Hinv = torch.linalg.cholesky(H, upper=True)

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1:
                    if not static_groups:
                        if (i1 + i) % groupsize == 0:
                            self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)])
                    else:
                        idx = i1 + i
                        if actorder:
                            # Static groups are indexed by original column position.
                            idx = perm[idx]
                        self.quantizer = groups[idx // groupsize]

                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                Q1[:, i] = q
                err1 = (w - q) / d
                # Compensate the remaining columns of this block for the error.
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            # Lazy batch update: propagate this block's errors to later columns.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        # Synchronizing is only meaningful (and only possible) for CUDA tensors.
        if W.is_cuda:
            torch.cuda.synchronize()

        if actorder:
            Q = Q[:, invperm]

        self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
        if torch.any(torch.isnan(self.layer.weight.data)):
            logging.warning('NaN in weights')
            # Log quantizer state field by field; getattr tolerates quantizer
            # classes that lack some of these attributes.
            for attr in ('bits', 'scale', 'zero', 'zero_point'):
                logging.warning('quantizer.%s = %s', attr, getattr(self.quantizer, attr, None))
            raise ValueError('NaN in weights')

    def free(self):
        """Drop references to large tensors and release cached memory."""
        self.H = None
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()
        cleanup_memory(verbose=False)
@torch.no_grad()
def gptq_fwrd(model, dataloader, dev, args):
    '''
    Quantize the decoder-layer Linear weights of ``model`` in place with GPTQ.

    Adapted from the GPTQ repo. Procedure:
      1. Run the model on every batch of ``dataloader`` with decoder layer 0
         replaced by a catcher that records its inputs (and the attention mask,
         position ids and -- when supplied -- position embeddings) and aborts
         the forward pass.
      2. For each decoder layer in turn, on ``dev``: quantize its Linear
         sublayers group by group (``sequential`` below), then recompute the
         layer outputs with the quantized weights; those outputs are the next
         layer's inputs. The layer is moved back to the CPU afterwards.

    Args read: ``nsamples``, ``w_bits``, ``w_asym``, ``gptq_mse``,
    ``w_groupsize``, ``percdamp``, ``act_order``. Assumes ``dataloader`` yields
    at most ``args.nsamples`` batches whose ``batch[0]`` are input ids of
    length ``model.seqlen`` with batch size 1 -- TODO confirm with callers.

    Returns:
        dict mapping ``'model.layers.<i>.<name>'`` to the quantizer object left
        on each ``GPTQ`` instance after ``fasterquant``.

    TODO: Make this function general to support both OPT and LLaMA models
    '''
    logging.info('-----GPTQ Quantization-----')

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    # Only the modules needed to reach the first decoder layer go to `dev`.
    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.norm = model.model.norm.to(dev)
    has_rotary_emb = hasattr(model.model, "rotary_emb")
    if has_rotary_emb:
        model.model.rotary_emb = model.model.rotary_emb.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {
        'i': 0,
        'attention_mask': None,
        'position_ids': None,
        'position_embeddings': None,
    }

    class Catcher(nn.Module):
        """Stand-in for layer 0 that records its inputs, then aborts the forward."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            # Record exactly what the model passed (None when it passed nothing).
            cache['attention_mask'] = kwargs.get('attention_mask')
            cache['position_ids'] = kwargs.get('position_ids')
            # Only some model implementations pass precomputed rotary embeddings.
            cache['position_embeddings'] = kwargs.get('position_embeddings')
            # Stop the rest of the model forward; caught by the loop below.
            raise ValueError

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].to(dev))
        except ValueError:
            pass
    layers[0] = layers[0].module

    # Return everything moved to `dev` above back to the CPU.
    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    model.model.norm = model.model.norm.cpu()
    if has_rotary_emb:
        model.model.rotary_emb = model.model.rotary_emb.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    layer_kwargs = {
        'attention_mask': cache['attention_mask'],
        'position_ids': cache['position_ids'],
    }
    # Forward position embeddings only when the model supplied them, so layers
    # that do not take this argument are called exactly as before.
    if cache['position_embeddings'] is not None:
        layer_kwargs['position_embeddings'] = cache['position_embeddings']

    quantizers = {}
    # Linear sublayers in dependency order: each group is calibrated on inputs
    # produced after the previous groups were already quantized. The names point
    # at an inner `.linear` submodule of each projection (presumably a wrapper
    # module -- confirm); for plain HF layers drop the `.linear` suffix.
    sequential = [
        ['self_attn.k_proj.linear', 'self_attn.v_proj.linear', 'self_attn.q_proj.linear'],
        ['self_attn.o_proj.linear'],
        ['mlp.up_proj.linear', 'mlp.gate_proj.linear'],
        ['mlp.down_proj.linear']
    ]
    for i in range(len(layers)):
        print(f'\nLayer {i}:', flush=True, end=' ')
        layer = layers[i].to(dev)
        full = find_qlayers(layer, layers=[torch.nn.Linear])
        for names in sequential:
            subset = {n: full[n] for n in names}

            gptq = {}
            for name in subset:
                print(f'{name}', end=' ', flush=True)
                if 'lm_head' in name:
                    # lm_head is skipped (not quantized here).
                    continue
                gptq[name] = GPTQ(subset[name])
                gptq[name].quantizer = WeightQuantizer()
                gptq[name].quantizer.configure(
                    args.w_bits, perchannel=True, sym=not args.w_asym, mse=args.gptq_mse
                )

            def add_batch(name):
                def tmp(_, inp, out):
                    gptq[name].add_batch(inp[0].data, out.data)
                return tmp

            # Hook only the sublayers that actually get a GPTQ instance.
            handles = [
                subset[name].register_forward_hook(add_batch(name)) for name in gptq
            ]
            for j in range(args.nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), **layer_kwargs)[0]
            for h in handles:
                h.remove()

            for name in gptq:
                gptq[name].fasterquant(
                    percdamp=args.percdamp,
                    groupsize=args.w_groupsize,
                    actorder=args.act_order,
                    static_groups=False,
                )
                quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer
                gptq[name].free()

        # Recompute outputs with the now-quantized weights: next layer's inputs.
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), **layer_kwargs)[0]

        layers[i] = layer.cpu()
        del layer
        del gptq
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    model.config.use_cache = use_cache
    cleanup_memory(verbose=True)
    logging.info('-----GPTQ Quantization Done-----\n')
    return quantizers
@torch.no_grad()
def rtn_fwrd(model, dev, args):
    '''
    Round-to-nearest (RTN) quantization of decoder-layer Linear weights, in place.

    Adapted from the GPTQ repo. Every ``torch.nn.Linear`` found (by exact type)
    in ``model.model.layers`` gets a fresh per-channel ``WeightQuantizer``
    configured with ``args.w_bits`` bits, ``sym=not args.w_asym`` and
    ``mse=args.gptq_mse``. Its weight is replaced by the quantized values cast
    back to the original dtype. Each decoder layer is moved to ``dev`` for the
    work and left on the CPU afterwards.

    Returns:
        dict mapping ``'model.layers.<i>.<name>'`` to the quantizer used for that
        Linear, after calling ``.cpu()`` on it.

    Raises:
        ValueError: If ``args.w_groupsize != -1`` (grouping is not supported).

    TODO: Make this function general to support both OPT and LLaMA models
    '''
    # Explicit check rather than `assert`, so it still fires under `python -O`.
    if args.w_groupsize != -1:
        raise ValueError("Groupsize not supported in RTN!")

    layers = model.model.layers
    torch.cuda.empty_cache()

    quantizers = {}
    for i in tqdm.tqdm(range(len(layers)), desc="(RtN Quant.) Layers"):
        layer = layers[i].to(dev)
        subset = find_qlayers(layer, layers=[torch.nn.Linear])
        for name, module in subset.items():
            if 'lm_head' in name:
                # lm_head is skipped (not quantized here).
                continue
            quantizer = WeightQuantizer()
            quantizer.configure(
                args.w_bits, perchannel=True, sym=not args.w_asym, mse=args.gptq_mse
            )
            W = module.weight.data
            w_dtype = W.dtype
            quantizer.find_params(W)
            module.weight.data = quantizer.quantize(W).to(w_dtype)
            quantizers['model.layers.%d.%s' % (i, name)] = quantizer.cpu()
        layers[i] = layer.cpu()
        torch.cuda.empty_cache()
        del layer

    cleanup_memory(verbose=True)
    return quantizers