[Cherry-Pick][Release2.6] Add GroupWiseQuant & AWQ & AutoClip #1821

Merged
8 changes: 7 additions & 1 deletion paddleslim/quant/advanced/__init__.py
@@ -19,6 +19,8 @@
from . import sample
from . import layerwise_quant_error
from . import utils_layers
from . import awq_search
from . import auto_clip

from .gptq import *
from .smooth import *
@@ -27,6 +29,8 @@
from .sample import *
from .layerwise_quant_error import *
from .utils_layers import *
from .awq_search import *
from .auto_clip import *

__all__ = []
__all__ += gptq.__all__
@@ -35,4 +39,6 @@
__all__ += piecewise_search.__all__
__all__ += sample.__all__
__all__ += layerwise_quant_error.__all__
__all__ += utils_layers.__all__
__all__ += utils_layers.__all__
__all__ += awq_search.__all__
__all__ += auto_clip.__all__
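With these exports in place, the two new passes become part of the package's public API. A minimal import sketch, assuming this branch is installed:

# AWQSearch and AutoClip are re-exported through the wildcard imports above.
from paddleslim.quant.advanced import AWQSearch, AutoClip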
155 changes: 155 additions & 0 deletions paddleslim/quant/advanced/auto_clip.py
@@ -0,0 +1,155 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import numpy as np
from .utils import fake_quant
from .metrics import mse_loss
from paddle.distributed.fleet.meta_parallel import (
    ColumnParallelLinear,
    RowParallelLinear,
)
__all__ = ['AutoClip']

class AutoClip(nn.Layer):
    """
    AutoClip from AWQ (https://arxiv.org/abs/2306.00978).
    """
    def __init__(
        self,
        model,
        weight_bits=8,
        weight_quant_method='groupwise',
        loss_function=mse_loss,
        sample_function=None,
        n_grid=20,
        max_shrink=0.5,
        n_sample_token=128,
        group_size=-1,
    ):
        super(AutoClip, self).__init__()
        self.model = model
        self.weight_bits = weight_bits
        self.weight_method = weight_quant_method
        self.loss_function = loss_function
        self.n_grid = n_grid
        self.max_shrink = max_shrink
        self.n_sample_token = n_sample_token
        self.bnt = (1 << (self.weight_bits - 1)) - 1
        self.sampled_inputs = {}
        self.sample_function = sample_function
        self.group_size = group_size

        self._apply_hook()

    def _apply_hook(self):
        self._forward_hook_list = []
        for _, sub_layer in self.model.named_sublayers():
            if type(sub_layer) in [ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear]:
                forward_pre_hook_handle = sub_layer.register_forward_pre_hook(
                    self._forward_pre_hook)
                self._forward_hook_list.append(forward_pre_hook_handle)

    def _forward_pre_hook(self, layer, input):
        self._sample_scale(input, layer.full_name())
        return input

    def _sample_scale(self, input, name):
        input = input[0] if type(input) == tuple else input
        input.stop_gradient = True
        if name not in self.sampled_inputs:
            self.sampled_inputs[name] = input
        else:
            if self.sample_function is not None:
                self.sampled_inputs[name] = self.sample_function.sample(
                    input, self.sampled_inputs[name], name)
            else:
                self.sampled_inputs[name] = input


    def auto_clip(self, group_size=128, oc_batch_size=1024):
        """
        Search the clip scale for each layer and update the layer's weights in place.
        """
        for sub_name, sub_layer in self.model.named_sublayers():
            name = sub_layer.full_name()
            if name not in self.sampled_inputs:
                continue
            print('AutoClipping', sub_name, name)
            weight = sub_layer.weight.cast('float16')
            weight_t = paddle.transpose(weight, perm=[1, 0])
            x = self.sampled_inputs[name].cast('float16')
            x = x.reshape([-1, x.shape[-1]])
            x = x.reshape([1, x.shape[0], -1, group_size])
            x = x[:, 0::x.shape[1] // self.n_sample_token]
            weight_t = weight_t.reshape([weight_t.shape[0], 1, -1, group_size])
            # fast test
            # oc_batch_size = weight_t.shape[0] // 4
            oc_batch_size = oc_batch_size if weight_t.shape[0] % oc_batch_size == 0 else 128  # prevent OOM
            assert weight_t.shape[0] % oc_batch_size == 0

            w_all = weight_t
            best_max_val_all = []

            for i_b in range(weight_t.shape[0] // oc_batch_size):
                w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]

                org_max_val = w.abs().max(axis=-1, keepdim=True)  # co, 1, n_group, 1
                best_max_val = org_max_val.clone()
                min_errs = paddle.ones_like(org_max_val, dtype='float16') * 1e9
                org_out = (x * w).sum(axis=-1)  # co, n_token, n_group
                for i_s in range(int(self.max_shrink * self.n_grid)):
                    max_val = org_max_val * (1 - i_s / self.n_grid)
                    max_val_tmp = max_val
                    cur_w = paddle.where(w > max_val_tmp, max_val_tmp, w)
                    cur_w = paddle.where(cur_w < -max_val_tmp, -max_val_tmp, cur_w)
                    quant_dequant_weight = fake_quant(cur_w, method='abs_max', weight_bits=4)
                    cur_out = (x * quant_dequant_weight).sum(axis=-1)
                    # co, 1, n_group, 1
                    tmp = (cur_out - org_out).detach().clone()
                    err = paddle.pow(tmp, 2).mean(axis=1).reshape(min_errs.shape)
                    print('block {} search s {} err {}'.format(i_b, i_s, err.mean().item()))
                    del cur_w, cur_out, quant_dequant_weight, tmp
                    paddle.device.cuda.empty_cache()

                    cur_best_idx = paddle.where(err < min_errs)
                    if cur_best_idx[0].shape[0] != 0:
                        min_errs[cur_best_idx] = err[cur_best_idx]
                        best_max_val[cur_best_idx] = max_val[cur_best_idx]
                best_max_val_all.append(best_max_val)

                del org_out, org_max_val, min_errs, best_max_val, err, cur_best_idx, max_val_tmp, max_val, w
                paddle.device.cuda.empty_cache()

            best_max_val = paddle.concat(best_max_val_all, axis=0)
            best_max_val = paddle.squeeze(best_max_val, axis=1)
            for param in sub_layer.parameters(include_sublayers=False):
                if 'w_0' in param.name:
                    param_tmp = param.transpose(perm=[1, 0]).cast('float16')
                    tmp_shape = param_tmp.shape
                    param_tmp = param_tmp.reshape([best_max_val.shape[0], best_max_val.shape[1], -1])
                    best_max_val = paddle.tile(best_max_val, repeat_times=(1, 1, param_tmp.shape[-1]))
                    param_tmp = paddle.where(param_tmp > best_max_val, best_max_val, param_tmp)
                    param_tmp = paddle.where(param_tmp < -best_max_val, -best_max_val, param_tmp)
                    param_tmp = param_tmp.reshape(tmp_shape).cast(param.dtype)
                    param_tmp = param_tmp.transpose(perm=[1, 0])
                    paddle.assign(param_tmp, output=param)
                    del param_tmp
                    paddle.device.cuda.empty_cache()
                    break

            del best_max_val, weight_t, x, weight, self.sampled_inputs[name], w_all, best_max_val_all
            paddle.device.cuda.empty_cache()
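For orientation, a hedged usage sketch of the class above. `model` and `calib_loader` are hypothetical placeholders for a Paddle dygraph network and a calibration dataloader; they are not part of this PR. The constructor registers forward pre-hooks on every Linear, ColumnParallelLinear and RowParallelLinear sublayer, the calibration forward passes fill `sampled_inputs`, and `auto_clip()` then searches the per-group clip thresholds and clamps the weights in place.

import paddle
from paddleslim.quant.advanced import AutoClip

# `model` and `calib_loader` are assumed to exist; substitute your own network and data.
auto_clipper = AutoClip(model, weight_bits=4, group_size=128)  # forward pre-hooks are registered here

model.eval()
with paddle.no_grad():
    for batch in calib_loader:  # each forward pass feeds the hooks with sampled activations
        model(batch)

auto_clipper.auto_clip(group_size=128)  # grid-search the clip threshold per group and clamp weights via paddle.assign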

78 changes: 78 additions & 0 deletions paddleslim/quant/advanced/awq_search.py
@@ -0,0 +1,78 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
from .utils import compute_scales
from .metrics import mse_loss
__all__ = ['AWQSearch']

class AWQSearch():
    def __init__(self,
                 n_grid=20,
                 bits_length=4,
                 weight_quant_method='groupwise',
                 group_size=128,
                 loss_function=mse_loss):
        '''
        The implementation of AutoScale from AWQ (https://arxiv.org/pdf/2306.00978.pdf).
        '''
        self.n_grid = n_grid
        self.bits_length = bits_length
        self.weight_quant_method = weight_quant_method
        self.bnt = (1 << (bits_length - 1)) - 1
        self.group_size = group_size
        self.loss_function = loss_function

    def search(self, layer_name, sampled_input, act_abs_max, weight):
        act = sampled_input
        act.stop_gradient = True
        print('[awq search] search input of %s' % layer_name)
        dtype = weight.dtype
        origin_out = paddle.matmul(act, weight)
        best_error = float('inf')
        best_ratio = -1
        best_scales = None

        for ratio in range(self.n_grid):
            ratio = ratio * 1 / self.n_grid
            act_abs_max_tmp = act_abs_max.detach().clone().cast('float32')
            scales = paddle.clip(paddle.pow(act_abs_max_tmp, ratio), min=1e-4)
            scales = scales / (scales.max() * scales.min()).sqrt()
            scales = scales.cast(dtype)
            new_weight = weight * scales.reshape([-1, 1])
            new_act = act / scales
            quant_scale = compute_scales(
                new_weight, method=self.weight_quant_method, group_size=self.group_size)
            if self.weight_quant_method == 'groupwise':
                quant_scale = paddle.repeat_interleave(quant_scale.cast('float32'), self.group_size, 0).cast(dtype)
            quant_weight = paddle.clip(
                paddle.round(new_weight / quant_scale * self.bnt),
                -self.bnt - 1, self.bnt)
            quant_dequant_weight = quant_weight / self.bnt * quant_scale
            new_out = paddle.matmul(new_act,
                                    quant_dequant_weight)
            loss = self.loss_function(origin_out, new_out).numpy()
            is_best = loss < best_error
            if is_best:
                print('find better ratio: {}, loss: {}'.format(ratio, loss))
                best_error = loss
                best_ratio = ratio
                best_scales = scales

        if best_scales is None:
            best_scales = paddle.ones(scales.shape, dtype=dtype)
            print('Cannot find better ratio.')
        else:
            print('Best ratio :{}, minimal loss : {}.'.format(best_ratio, best_error))
        return best_scales
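A hedged sketch of calling the search above on a single linear layer. The shapes are made up (256 sampled tokens, 512 input features, 1024 output features) and only illustrate the expected layout; 'linear_0' is a placeholder layer name.

import paddle
from paddleslim.quant.advanced import AWQSearch

search = AWQSearch(n_grid=20, bits_length=4, group_size=128)

act = paddle.randn([256, 512], dtype='float32')      # sampled activations: [tokens, in_features]
weight = paddle.randn([512, 1024], dtype='float32')  # layer weight: [in_features, out_features]
act_abs_max = act.abs().max(axis=0)                  # per-input-channel activation maxima

# Grid-search the smoothing exponent and return per-input-channel scales.
scales = search.search('linear_0', act, act_abs_max, weight)
smoothed_weight = weight * scales.reshape([-1, 1])   # fold the scales into the weight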
39 changes: 26 additions & 13 deletions paddleslim/quant/advanced/piecewise_search.py
@@ -31,6 +31,8 @@ def __init__(self,
                 search_scale_max=5.,
                 weight_quant_method='abs_max_channel_wise',
                 act_quant_method='abs_max',
                 use_clip=False,
                 search_clip=False,
                 loss_function=mse_loss):
        '''
        PieceWiseSearch searches for k_piece, alpha and scale.
@@ -58,31 +60,36 @@
        self.act_quant_method = act_quant_method
        self.bnt = (1 << (bits_length - 1)) - 1
        self.loss_function = loss_function
        self.use_clip = use_clip
        self.search_clip = search_clip

    def search(self, layer_name, sampled_input, act_abs_max, weight):
        act = sampled_input
        act.stop_gradient = True
        print('[smooth search] search input of %s' % layer_name)

        dtype = weight.dtype
        origin_out = paddle.matmul(act, weight)
        w_abs_max = weight.abs().max(axis=-1, keepdim=True)
        rw_abs_max = w_abs_max.reshape(act_abs_max.shape)
        np_act_abs_max = np.array(act_abs_max)
        np_rw_abs_max = np.array(rw_abs_max)


        smooth_scale_out = None
        global_loss = float('inf')
        best_scale = None

        for k_piece in range(1, self.k_piece + 1):
        if self.search_clip:
            piece_range = [1] + list(range(1, self.k_piece + 1))
        else:
            piece_range = list(range(1, self.k_piece + 1))

        for k_idx, k_piece in enumerate(piece_range):
            if not self.search_piece:
                k_piece = self.k_piece
            print('Search {} Piece'.format(k_piece))
            centroids, labels = k_means(act_abs_max, k_piece)
            piece = ['piece_{}'.format(a) for a in range(len(centroids))]
            for i in range(len(centroids)):
                # print('search for piece {}; centroids value is {}'.format(
                #     piece[i], centroids[centroids.argsort()[i]].numpy()))
                print('search for piece {}; centroids value is {}'.format(
                    piece[i], float(centroids[centroids.argsort()[i: i + 1]].cast('float32'))))
                alpha = self.search_alpha_min
                alpha_max = self.search_scale_max if self.search_scale_max is not None else self.search_alpha_max
                calibration_loss = float('inf')
@@ -104,12 +111,16 @@ def search(self, layer_name, sampled_input, act_abs_max, weight):
                    alpha = round(alpha, 2)

                    if alpha < 1:
                        s = (np.power(np_act_abs_max, alpha) / np.power(
                            np_rw_abs_max, 1. - alpha)).clip(min=1e-5)
                        s = paddle.to_tensor(s, dtype='float32')
                        act_abs_max_tmp = act_abs_max.detach().clone()
                        s = paddle.clip(paddle.pow(act_abs_max_tmp, alpha) / paddle.pow(
                            rw_abs_max, 1 - alpha), min=1e-5)

                        if self.use_clip or (k_piece == 1 and k_idx == 1 and self.search_clip):
                            s = paddle.clip(act_abs_max_tmp / paddle.max(act_abs_max / s), min=1)
                        del act_abs_max_tmp
                        smooth_scale = s * mask_for_search
                    else:
                        smooth_scale = alpha * mask_for_search
                        smooth_scale = paddle.to_tensor(alpha, dtype=dtype) * mask_for_search

                    if smooth_scale_out is not None:
                        mask_for_ones_new = paddle.where(
@@ -145,9 +156,10 @@ def search(self, layer_name, sampled_input, act_abs_max, weight):
                        calibration_loss = cur_loss
                        final_smooth_scale = smooth_scale
                        final_alpha = alpha
                        # print('Better alpha: {} loss: {}'.format(alpha, calibration_loss.cast('float32')))

                # print("Layer {} Piece {}, loss: {}, alpha : {}".format(
                #     layer_name, piece[i], float(calibration_loss), final_alpha))
                print("Layer {} Piece {}, loss: {}, alpha : {}".format(
                    layer_name, piece[i], float(calibration_loss.cast('float32')), final_alpha))
                if smooth_scale_out is None:
                    smooth_scale_out = final_smooth_scale
                else:
@@ -160,4 +172,5 @@ def search(self, layer_name, sampled_input, act_abs_max, weight):
                print('Find Better K-Piece {}'.format(k_piece))
            if not self.search_piece:
                break

        return best_scale
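For the two new flags above, a hedged sketch of how they might be passed; constructor arguments not shown keep their defaults, and `act`, `act_abs_max` and `weight` are the same kind of placeholder tensors as in the AWQSearch sketch earlier. `search_clip=True` runs the single-piece search a second time with the AWQ-style clipped scale and keeps whichever variant yields the lower loss, while `use_clip=True` applies the clipped scale unconditionally.

from paddleslim.quant.advanced import PieceWiseSearch

# search_clip adds an extra k_piece == 1 pass that clips the smooth scale (AWQ style).
search = PieceWiseSearch(search_clip=True, use_clip=False)
best_scale = search.search('linear_0', act, act_abs_max, weight)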