From 528d7de6fdc42185d98d147c9e10d3f56daf8090 Mon Sep 17 00:00:00 2001
From: "Wang, Chang"
Date: Wed, 20 Mar 2024 08:54:17 +0800
Subject: [PATCH] Fix gptq desc_act and static_group (#1395)

---
 .../llm/quantization/nn/modules.py          | 41 ++++++++++---------
 .../transformers/modeling/modeling_auto.py  |  2 +-
 2 files changed, 22 insertions(+), 21 deletions(-)

diff --git a/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py b/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py
index 7b81f87d5d9..0d85a3b9f86 100644
--- a/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py
+++ b/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py
@@ -169,22 +169,27 @@ def set_weights_bias(
         q_config,
         bias=None,
     ):
-        if q_config.quant_method.value == "gptq" and (
-            q_config.desc_act and not q_config.static_groups
-        ):
-            int_weight2 = int_weight.clone()
-            group_size = q_config.group_size
-            group_dict = {}
-            for i in range(len(g_idx)):
-                group_idx = g_idx[i].item()
-                if group_idx not in group_dict:
-                    target_idx = group_idx * group_size
-                    group_dict[group_idx] = 0
+
+        if q_config.quant_method.value == "gptq":
+            if q_config.desc_act:
+                if not q_config.static_groups:
+                    int_weight2 = int_weight.clone()
+                    group_size = q_config.group_size
+                    group_dict = {}
+                    for i in range(len(g_idx)):
+                        group_idx = g_idx[i].item()
+                        if group_idx not in group_dict:
+                            target_idx = group_idx * group_size
+                            group_dict[group_idx] = 0
+                        else:
+                            group_dict[group_idx] = group_dict[group_idx] + 1
+                            target_idx = group_idx * group_size + group_dict[group_idx]
+                        int_weight2[target_idx] = int_weight[i]
+                    int_weight = int_weight2
                 else:
-                    group_dict[group_idx] = group_dict[group_idx] + 1
-                    target_idx = group_idx * group_size + group_dict[group_idx]
-                int_weight2[target_idx] = int_weight[i]
-            int_weight = int_weight2
+                    g_idx = torch.empty(0, dtype=torch.int32)
+            else:
+                g_idx = torch.empty(0, dtype=torch.int32)
 
         if q_config.bits == 4:
             int_weight = (int_weight - 8) * 16
@@ -194,11 +199,7 @@ def set_weights_bias(
         if q_config.sym:
             gptq_zeros = torch.empty(0, dtype=torch.int8)
 
-        if (
-            q_config.quant_method.value != "gptq"
-            or q_config.static_groups
-            or (not q_config.desc_act)
-        ):
+        if q_config.quant_method.value != "gptq":
             g_idx = torch.empty(0, dtype=torch.int32)
 
         packw = torch.ops.bestlaop.woq_packq(
diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
index f50d7e8e4ef..ec7d66c35f7 100644
--- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
+++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
@@ -67,7 +67,6 @@
     convert_to_quantized_model,
     replace_linear,
 )
-from ..llm.quantization.nn.modules import QuantizedLinearQBits
 from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear
 from transformers.configuration_utils import PretrainedConfig
 from transformers import AutoConfig
@@ -83,6 +82,7 @@ def recover_export_model(model, current_key_name=None):
 
     Return optimum format model.
     """
+    from ..llm.quantization.nn.modules import QuantizedLinearQBits
     for name, module in model.named_children():
         if current_key_name is None:
             current_key_name = []
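
Note (not part of the patch): a minimal standalone sketch of the row reordering that the new desc_act branch (without static_groups) performs in set_weights_bias. The helper name reorder_rows and the toy tensors are hypothetical, introduced only to illustrate the loop above; they are not part of the patch or the library API.

    import torch

    def reorder_rows(int_weight, g_idx, group_size):
        # Scatter rows so that rows sharing a GPTQ group index become contiguous.
        # This mirrors the loop in the patched desc_act / non-static_groups branch.
        out = int_weight.clone()
        seen = {}  # group index -> count of rows of that group already placed
        for i in range(len(g_idx)):
            group = g_idx[i].item()
            offset = seen.get(group, -1) + 1
            seen[group] = offset
            out[group * group_size + offset] = int_weight[i]
        return out

    # Toy usage: 4 rows, group_size=2, act-order group indices [1, 0, 1, 0].
    w = torch.arange(8, dtype=torch.int8).reshape(4, 2)
    g = torch.tensor([1, 0, 1, 0], dtype=torch.int32)
    print(reorder_rows(w, g, group_size=2))  # rows of group 0 first, then group 1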