diff --git a/paddleslim/quant/advanced/__init__.py b/paddleslim/quant/advanced/__init__.py index 1f0744ec..2e779a6e 100644 --- a/paddleslim/quant/advanced/__init__.py +++ b/paddleslim/quant/advanced/__init__.py @@ -19,6 +19,8 @@ from . import sample from . import layerwise_quant_error from . import utils_layers +from . import awq_search +from . import auto_clip from .gptq import * from .smooth import * @@ -27,6 +29,8 @@ from .sample import * from .layerwise_quant_error import * from .utils_layers import * +from .awq_search import * +from .auto_clip import * __all__ = [] __all__ += gptq.__all__ @@ -35,4 +39,6 @@ __all__ += piecewise_search.__all__ __all__ += sample.__all__ __all__ += layerwise_quant_error.__all__ -__all__ += utils_layers.__all__ \ No newline at end of file +__all__ += utils_layers.__all__ +__all__ += awq_search.__all__ +__all__ += auto_clip.__all__ \ No newline at end of file diff --git a/paddleslim/quant/advanced/auto_clip.py b/paddleslim/quant/advanced/auto_clip.py new file mode 100644 index 00000000..69690111 --- /dev/null +++ b/paddleslim/quant/advanced/auto_clip.py @@ -0,0 +1,155 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import numpy as np +from .utils import fake_quant +from .metrics import mse_loss +from paddle.distributed.fleet.meta_parallel import ( + ColumnParallelLinear, + RowParallelLinear, +) +__all__ = ['AutoClip'] + +class AutoClip(nn.Layer): + """ + AutoClip from AWQ[https://arxiv.org/abs/2306.00978] + """ + def __init__( + self, + model, + weight_bits=8, + weight_quant_method='groupwise', + loss_function=mse_loss, + sample_function=None, + n_grid=20, + max_shrink=0.5, + n_sample_token=128, + group_size=-1, + ): + super(AutoClip, self).__init__() + self.model = model + self.weight_bits = weight_bits + self.weight_method = weight_quant_method + self.loss_function = loss_function + self.n_grid = n_grid + self.max_shrink = max_shrink + self.n_sample_token = n_sample_token + self.bnt = (1 << (self.weight_bits - 1)) - 1 + self.sampled_inputs = {} + self.sample_function = sample_function + self.group_size = group_size + + self._apply_hook() + + def _apply_hook(self): + self._forward_hook_list = [] + for _, sub_layer in self.model.named_sublayers(): + if type(sub_layer) in [ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear]: + forward_pre_hook_handle = sub_layer.register_forward_pre_hook( + self._forward_pre_hook) + self._forward_hook_list.append(forward_pre_hook_handle) + + def _forward_pre_hook(self, layer, input): + self._sample_scale(input, layer.full_name()) + return input + + def _sample_scale(self, input, name): + input = input[0] if type(input) == tuple else input + input.stop_gradient = True + if name not in self.sampled_inputs: + self.sampled_inputs[name] = input + else: + if self.sample_function is not None: + self.sampled_inputs[name] = self.sample_function.sample( + input, self.sampled_inputs[name], name) + else: + self.sampled_inputs[name] = input + + + def auto_clip(self, group_size=128, oc_batch_size=1024): + """ + search clip scale for each layer and update the layer's weight + """ + for sub_name, sub_layer in self.model.named_sublayers(): + name = sub_layer.full_name() + if name not in self.sampled_inputs: + continue + print('AutoClipping', sub_name, name) + weight = sub_layer.weight.cast('float16') + weight_t = paddle.transpose(weight, perm=[1, 0]) + x = self.sampled_inputs[name].cast('float16') + x = x.reshape([-1, x.shape[-1]]) + x = x.reshape([1, x.shape[0], -1, group_size]) + x = x[:, 0::x.shape[1] // self.n_sample_token] + weight_t = weight_t.reshape([weight_t.shape[0], 1, -1, group_size]) + # fast test + # oc_batch_size = weight_t.shape[0] // 4 + oc_batch_size = oc_batch_size if weight_t.shape[0] % oc_batch_size == 0 else 128 # prevent OOM + assert weight_t.shape[0] % oc_batch_size == 0 + + w_all = weight_t + best_max_val_all = [] + + for i_b in range(weight_t.shape[0] // oc_batch_size): + w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size] + + org_max_val = w.abs().max(axis=-1, keepdim=True) # co, 1, n_group, 1 + best_max_val = org_max_val.clone() + min_errs = paddle.ones_like(org_max_val, dtype='float16') * 1e9 + org_out = (x * w).sum(axis=-1) # co, n_token, n_group + for i_s in range(int(self.max_shrink * self.n_grid)): + max_val = org_max_val * (1 - i_s / self.n_grid) + max_val_tmp = max_val + cur_w = paddle.where(w > max_val_tmp, max_val_tmp, w) + cur_w = paddle.where(cur_w < - max_val_tmp, - max_val_tmp, cur_w) + quant_dequant_weight = fake_quant(cur_w, method='abs_max', weight_bits=4) + cur_out = (x * quant_dequant_weight).sum(axis=-1) + # co, 1, n_group, 1 + tmp = (cur_out - org_out).detach().clone() + err = paddle.pow(tmp, 2).mean(axis=1).reshape(min_errs.shape) + print('block {} search s {} err {}'.format(i_b, i_s, err.mean().item())) + del cur_w, cur_out, quant_dequant_weight, tmp + paddle.device.cuda.empty_cache() + + cur_best_idx = paddle.where(err < min_errs) + if cur_best_idx[0].shape[0] != 0: + min_errs[cur_best_idx] = err[cur_best_idx] + best_max_val[cur_best_idx] = max_val[cur_best_idx] + best_max_val_all.append(best_max_val) + + del org_out, org_max_val, min_errs, best_max_val, err, cur_best_idx, max_val_tmp, max_val, w + paddle.device.cuda.empty_cache() + + best_max_val = paddle.concat(best_max_val_all, axis=0) + best_max_val = paddle.squeeze(best_max_val, axis=1) + for param in sub_layer.parameters(include_sublayers=False): + if 'w_0' in param.name: + param_tmp = param.transpose(perm=[1, 0]).cast('float16') + tmp_shape = param_tmp.shape + param_tmp = param_tmp.reshape([best_max_val.shape[0], best_max_val.shape[1], -1]) + best_max_val = paddle.tile(best_max_val, repeat_times=(1, 1, param_tmp.shape[-1])) + param_tmp = paddle.where(param_tmp > best_max_val, best_max_val, param_tmp) + param_tmp = paddle.where(param_tmp < - best_max_val, - best_max_val, param_tmp) + param_tmp = param_tmp.reshape(tmp_shape).cast(param.dtype) + param_tmp = param_tmp.transpose(perm=[1, 0]) + paddle.assign(param_tmp, output=param) + del param_tmp + paddle.device.cuda.empty_cache() + break + + del best_max_val, weight_t, x, weight, self.sampled_inputs[name], w_all, best_max_val_all + paddle.device.cuda.empty_cache() + diff --git a/paddleslim/quant/advanced/awq_search.py b/paddleslim/quant/advanced/awq_search.py new file mode 100644 index 00000000..55151c4e --- /dev/null +++ b/paddleslim/quant/advanced/awq_search.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +import numpy as np +from .utils import compute_scales +from .metrics import mse_loss +__all__ = ['AWQSearch'] + +class AWQSearch(): + def __init__(self, + n_grid=20, + bits_length=4, + weight_quant_method='groupwise', + group_size=128, + loss_function=mse_loss): + ''' + The implementation of AutoScale from AWQ(https://arxiv.org/pdf/2306.00978.pdf). + ''' + self.n_grid = n_grid + self.bits_length = bits_length + self.weight_quant_method = weight_quant_method + self.bnt = (1 << (bits_length - 1)) - 1 + self.group_size = group_size + self.loss_function = loss_function + + def search(self, layer_name, sampled_input, act_abs_max, weight): + act = sampled_input + act.stop_gradient = True + print('[awq search] search input of %s' % layer_name) + dtype = weight.dtype + origin_out = paddle.matmul(act, weight) + best_error = float('inf') + best_ratio = -1 + best_scales = None + + for ratio in range(self.n_grid): + ratio = ratio * 1 / self.n_grid + act_abs_max_tmp = act_abs_max.detach().clone().cast('float32') + scales = paddle.clip(paddle.pow(act_abs_max_tmp, ratio), min=1e-4) + scales = scales / (scales.max() * scales.min()).sqrt() + scales = scales.cast(dtype) + new_weight = weight * scales.reshape([-1, 1]) + new_act = act / scales + quant_scale = compute_scales( + new_weight, method=self.weight_quant_method, group_size=self.group_size) + if self.weight_quant_method == 'groupwise': + quant_scale = paddle.repeat_interleave(quant_scale.cast('float32'), self.group_size, 0).cast(dtype) + quant_weight = paddle.clip( + paddle.round(new_weight / quant_scale * self.bnt), + -self.bnt - 1, self.bnt) + quant_dequant_weight = quant_weight / self.bnt * quant_scale + new_out = paddle.matmul(new_act, + quant_dequant_weight) + loss = self.loss_function(origin_out, new_out).numpy() + is_best = loss < best_error + if is_best: + print('find better ratio: {}, loss: {}'.format(ratio, loss)) + best_error = loss + best_ratio = ratio + best_scales = scales + + if best_scales is None: + best_scales = paddle.ones(scales.shape, dtype=dtype) + print('Cannot find better ratio.') + else: + print('Best ratio :{}, minimal loss : {}.'.format(best_ratio, best_error)) + return best_scales diff --git a/paddleslim/quant/advanced/piecewise_search.py b/paddleslim/quant/advanced/piecewise_search.py index 55678409..e326f2e5 100644 --- a/paddleslim/quant/advanced/piecewise_search.py +++ b/paddleslim/quant/advanced/piecewise_search.py @@ -31,6 +31,8 @@ def __init__(self, search_scale_max=5., weight_quant_method='abs_max_channel_wise', act_quant_method='abs_max', + use_clip=False, + search_clip=False, loss_function=mse_loss): ''' PieceWiseSearch provides to search k_piece, alpha and scale. @@ -58,31 +60,36 @@ def __init__(self, self.act_quant_method = act_quant_method self.bnt = (1 << (bits_length - 1)) - 1 self.loss_function = loss_function + self.use_clip = use_clip + self.search_clip = search_clip def search(self, layer_name, sampled_input, act_abs_max, weight): act = sampled_input act.stop_gradient = True print('[smooth search] search input of %s' % layer_name) - + dtype = weight.dtype origin_out = paddle.matmul(act, weight) w_abs_max = weight.abs().max(axis=-1, keepdim=True) rw_abs_max = w_abs_max.reshape(act_abs_max.shape) - np_act_abs_max = np.array(act_abs_max) - np_rw_abs_max = np.array(rw_abs_max) - + smooth_scale_out = None global_loss = float('inf') best_scale = None - for k_piece in range(1, self.k_piece + 1): + if self.search_clip: + piece_range = [1] + list(range(1, self.k_piece + 1)) + else: + piece_range = list(range(1, self.k_piece + 1)) + + for k_idx, k_piece in enumerate(piece_range): if not self.search_piece: k_piece = self.k_piece print('Search {} Piece'.format(k_piece)) centroids, labels = k_means(act_abs_max, k_piece) piece = ['piece_{}'.format(a) for a in range(len(centroids))] for i in range(len(centroids)): - # print('search for piece {}; centroids value is {}'.format( - # piece[i], centroids[centroids.argsort()[i]].numpy())) + print('search for piece {}; centroids value is {}'.format( + piece[i], float(centroids[centroids.argsort()[i: i + 1]].cast('float32')))) alpha = self.search_alpha_min alpha_max = self.search_scale_max if self.search_scale_max is not None else self.search_alpha_max calibration_loss = float('inf') @@ -104,12 +111,16 @@ def search(self, layer_name, sampled_input, act_abs_max, weight): alpha = round(alpha, 2) if alpha < 1: - s = (np.power(np_act_abs_max, alpha) / np.power( - np_rw_abs_max, 1. - alpha)).clip(min=1e-5) - s = paddle.to_tensor(s, dtype='float32') + act_abs_max_tmp = act_abs_max.detach().clone() + s = paddle.clip(paddle.pow(act_abs_max_tmp, alpha) / paddle.pow( + rw_abs_max, 1 - alpha), min=1e-5) + + if self.use_clip or (k_piece == 1 and k_idx == 1 and self.search_clip): + s = paddle.clip(act_abs_max_tmp / paddle.max(act_abs_max / s), min=1) + del act_abs_max_tmp smooth_scale = s * mask_for_search else: - smooth_scale = alpha * mask_for_search + smooth_scale = paddle.to_tensor(alpha, dtype=dtype) * mask_for_search if smooth_scale_out is not None: mask_for_ones_new = paddle.where( @@ -145,9 +156,10 @@ def search(self, layer_name, sampled_input, act_abs_max, weight): calibration_loss = cur_loss final_smooth_scale = smooth_scale final_alpha = alpha + # print('Better alpha: {} loss: {}'.format(alpha, calibration_loss.cast('float32'))) - # print("Layer {} Piece {}, loss: {}, alpha : {}".format( - # layer_name, piece[i], float(calibration_loss), final_alpha)) + print("Layer {} Piece {}, loss: {}, alpha : {}".format( + layer_name, piece[i], float(calibration_loss.cast('float32')), final_alpha)) if smooth_scale_out is None: smooth_scale_out = final_smooth_scale else: @@ -160,4 +172,5 @@ def search(self, layer_name, sampled_input, act_abs_max, weight): print('Find Better K-Piece {}'.format(k_piece)) if not self.search_piece: break + return best_scale diff --git a/paddleslim/quant/advanced/smooth.py b/paddleslim/quant/advanced/smooth.py index e715788e..5e32435f 100644 --- a/paddleslim/quant/advanced/smooth.py +++ b/paddleslim/quant/advanced/smooth.py @@ -26,6 +26,8 @@ def __init__( model_config, alpha=0.5, smooth_all_linears=False, + start_sample_step=10000, + smooth_method='smoothquant', sample_function=None, search_function=None, ): ''' @@ -68,6 +70,8 @@ def __init__( self.smooth_all_linears = smooth_all_linears self.sample_function = sample_function self.search_function = search_function + self.start_sample_step = start_sample_step + self.smooth_method = smooth_method self.model.eval() self.step = 0 @@ -98,7 +102,6 @@ def _get_smooth_layers(self): self.ln_linear_dict, self.linear_ln_dict = get_ln_linear_info( self.layer_order, self.norm_flag, self.linear_flag, self.fused_qkv, self.parallel_ffn, self.skip_norm_list) - assert len(self.ln_linear_dict) > 0, 'No LN/Linear pair found' for key in self.ln_linear_dict: print('smooth pair LN {} : Linear {}'.format( @@ -147,29 +150,32 @@ def _forward_pre_hook(self, layer, input): def _sample_scale(self, input, ln_name): x = input[0] if type(input) == tuple else input x.stop_gradient = True - x_abs_max = x.abs().max(axis=1, keepdim=True) - x_abs_max = x_abs_max.max(axis=0) + + if self.smooth_method == 'smoothquant': + x_abs_max = x.abs().max(axis=1, keepdim=True) + x_abs_max = x_abs_max.max(axis=0) + elif self.smooth_method == 'awq': + x_abs_max = x.abs().reshape([-1, x.shape[-1]]) + x_abs_max = x_abs_max.mean(axis=0).reshape([1, -1]) + else: + raise NotImplementedError("To be implemented") if ln_name not in self.scale_dict: self.sampled_inputs[ln_name] = x self.scale_dict[ln_name] = x_abs_max else: - if self.sample_function is not None: + if self.sample_function is not None and self.step >= self.start_sample_step: self.sampled_inputs[ln_name] = self.sample_function.sample( x, self.sampled_inputs[ln_name], ln_name) else: self.sampled_inputs[ln_name] = x - tmp1 = paddle.concat([x_abs_max, self.scale_dict[ln_name]], axis=0) - self.scale_dict[ln_name] = tmp1.max(axis=0, keepdim=True) + if self.smooth_method == 'smoothquant': + tmp1 = paddle.concat([x_abs_max, self.scale_dict[ln_name]], axis=0) + self.scale_dict[ln_name] = tmp1.max(axis=0, keepdim=True) + elif self.smooth_method == 'awq': + tmp1 = paddle.concat([x_abs_max, self.scale_dict[ln_name]], axis=0) + self.scale_dict[ln_name] = tmp1.mean(axis=0, keepdim=True) - # per step print once - if self.print_step == self.step: - print('[Smooth] Step [{}]: {}. abs_min: {}, abs_max: {}'.format( - self.step, ln_name, - float(self.scale_dict[ln_name].cast("float32").min()), - float(self.scale_dict[ln_name].cast("float32").max()))) - if ln_name == list(self.linear_ln_dict.values())[-1]: - self.print_step += 1 def update_weight(self): @@ -181,24 +187,20 @@ def update_weight(self): if type(sub_layer) == ShiftSmoothHelpLayer: ln_name = layer_name if ln_name is not None: - act_abs_max = self.scale_dict[ln_name].cast("float32") - sampled_input = self.sampled_inputs[ln_name].cast("float32") + act_abs_max = self.scale_dict[ln_name].cast("float16") + sampled_input = self.sampled_inputs[ln_name].cast("float16") for param in sub_layer.parameters(include_sublayers=False): if 'w_0' in param.name: - weight = param.cast("float32") + # weight = param.cast("float32") if self.search_function is not None: s = self.search_function.search( - layer_name, sampled_input, act_abs_max, weight) + layer_name, sampled_input, act_abs_max, param.cast("float16")) else: - w_abs_max = weight.abs().max(axis=-1, keepdim=True) + w_abs_max = param.abs().max(axis=-1, keepdim=True) rw_abs_max = w_abs_max.reshape(act_abs_max.shape) - act_abs_max_np = act_abs_max.numpy() - weight_abs_max_np = rw_abs_max.numpy() - s = ( - np.power(act_abs_max_np, self.alpha) / np.power( - weight_abs_max_np, 1 - self.alpha)).clip( - min=1e-5) - s = paddle.to_tensor(s, dtype="float32") + act_abs_max_tmp = act_abs_max.detach().clone() + s = paddle.clip(paddle.pow(act_abs_max_tmp, self.alpha) / paddle.pow( + rw_abs_max, 1 - self.alpha), min=1e-5) self.smooth_scale_dict[ln_name] = s.cast(param.dtype) break @@ -273,4 +275,4 @@ def update_weight(self): def _remove_hook(self): for hook in self._forward_hook_list: hook.remove() - self._forward_hook_list = [] + self._forward_hook_list = [] \ No newline at end of file diff --git a/paddleslim/quant/advanced/utils.py b/paddleslim/quant/advanced/utils.py index 703fc5e1..ff77462b 100644 --- a/paddleslim/quant/advanced/utils.py +++ b/paddleslim/quant/advanced/utils.py @@ -38,7 +38,7 @@ def k_means(weight, n_clusters, init='k-means++', max_iter=300): return paddle.to_tensor(centroids.flatten()), paddle.to_tensor(labels) -def compute_scales(x, method='abs_max'): +def compute_scales(x, method='abs_max', group_size=-1): if method == 'abs_max': quant_scale = float(paddle.max(paddle.abs(x.flatten()))) quant_scale = 1e-8 if quant_scale == 0.0 else quant_scale @@ -52,8 +52,26 @@ def compute_scales(x, method='abs_max'): 0, dtype=x.dtype), paddle.to_tensor(1e-8, dtype=x.dtype), quant_scale) + elif method == 'groupwise': + input_shape = x.shape + input_processed = x.transpose([1, 0]).reshape( + [input_shape[1], input_shape[0] // group_size, group_size]) + quant_scale = paddle.max( + paddle.abs(input_processed), axis=2) + quant_scale = paddle.where(quant_scale == paddle.to_tensor(0, dtype=x.dtype), + paddle.to_tensor(1e-8, dtype=x.dtype), quant_scale) + quant_scale = quant_scale.transpose([1, 0]) + return quant_scale +def fake_quant(x, method='abs_max', weight_bits=8, group_size=-1): + bnt = (1 << (weight_bits - 1)) - 1 + quant_scale = compute_scales(x, method=method, group_size=group_size) + quant_value = paddle.clip( + paddle.round(x / quant_scale * bnt), -bnt - 1, bnt) + quant_dequant_value = quant_value / bnt * quant_scale + return quant_dequant_value + def find_parent_layer_and_sub_name(model, name): last_idx = 0 diff --git a/paddleslim/quant/observers/__init__.py b/paddleslim/quant/observers/__init__.py index 7ab3b723..0b7970ba 100644 --- a/paddleslim/quant/observers/__init__.py +++ b/paddleslim/quant/observers/__init__.py @@ -20,6 +20,7 @@ from .abs_max import AbsmaxObserver from .mse_weight import MSEChannelWiseWeightObserver from .abs_max_weight import AbsMaxChannelWiseWeightObserver +from .groupwise import GroupWiseWeightObserver __all__ = [ "HistObserver", @@ -31,4 +32,5 @@ "AbsmaxObserver", "MSEChannelWiseWeightObserver", "AbsMaxChannelWiseWeightObserver", + "GroupWiseWeightObserver" ] diff --git a/paddleslim/quant/observers/groupwise.py b/paddleslim/quant/observers/groupwise.py new file mode 100644 index 00000000..1db2067c --- /dev/null +++ b/paddleslim/quant/observers/groupwise.py @@ -0,0 +1,112 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +from .channel_wise import ChannelWiseObserver +from paddle.quantization.factory import ObserverFactory + + +class GroupWiseWeightObserver(ObserverFactory): + r""" + It collects channel-wise maximum absolute values of target weights. + Args: + bit_length(int, optional): Number of bits to represent an quantized integer in binary. + dtype(str, optional): The data type of input tensor. + name (str, optional): This parameter is used by developers to print debugging information. \ + For details, please refer to :ref:`api_guide_Name`. Default is None. + Examples: + .. code-block:: python + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import AbsMaxChannelWiseWeightObserver + quanter = AbsMaxChannelWiseWeightObserver() + q_config = QuantConfig(activation=None, weight=quanter) + """ + + def __init__(self, quant_bits=8, group_size=128): + super(GroupWiseWeightObserver, self).__init__( + quant_bits=quant_bits, + group_size=group_size) + + def _get_class(self): + return GroupWiseWeightObserverLayer + + +class GroupWiseWeightObserverLayer(ChannelWiseObserver): + def __init__(self, layer, quant_bits=8, group_size=128): + super(GroupWiseWeightObserverLayer, self).__init__( + layer, + quant_bits=quant_bits, + sign=True, + symmetric=True, ) + self.quant_bits = quant_bits + self.group_size = group_size + self.qmin, self.qmax = self.qmin_qmax + self._layer = layer + self._max = None + self._scale = None + self._zero_point = None + + def forward(self, inputs): + self._max = self._cal_abs_max(inputs) + return inputs + + def _cal_abs_max(self, inputs): + """ Use group_size to group the input, then use the + absmax method to calculate the scale + """ + input_shape = inputs.shape + assert self.group_size == 64 or self.group_size == 128, \ + "group_size only support 64 or 128" + assert inputs.shape[0] % self.group_size == 0, \ + "group_size must be a factor of input channels" + assert len(inputs.shape) == 2, \ + "Currently only support 2D tensor" + input_processed = inputs.transpose([1, 0]).reshape( + [input_shape[1], input_shape[0] // self.group_size, self.group_size]) + + abs_max_values = paddle.max( + paddle.abs(input_processed), axis=2).cast("float32") + # "abs_max_values < 1e-8" in bfloat16 type? + abs_max_values = paddle.where(abs_max_values == np.float32(0), + np.float32(1e-8), abs_max_values) + abs_max_values = abs_max_values.transpose([1, 0]) + return abs_max_values + + def min_value(self) -> float: + return 0. + + def max_value(self) -> float: + return self._max + + def cal_thresholds(self): + """ Compute thresholds for MAX function. + """ + if self._scale is None: + self._scale = self._max + self._zero_point = paddle.zeros_like(self._scale) + + def scales(self): + """ Return output scales. + """ + if self._scale is None: + self.cal_thresholds() + return self._scale + + def zero_points(self): + """ Return output zero points. + """ + if self._zero_point is None: + self.cal_thresholds() + return self._zero_point