From edede40f7fb984f53b04ffd401cbe9f7ec3521d6 Mon Sep 17 00:00:00 2001
From: "Wang, Chang"
Date: Wed, 20 Mar 2024 08:53:59 +0800
Subject: [PATCH] Fix WOQ int8 unpack weight (#1393)

---
 .../llm/quantization/nn/modules.py          |  2 ++
 .../transformers/llm/quantization/utils.py  | 33 +++++++++++++++----
 .../transformers/modeling/modeling_auto.py  |  6 ++--
 tests/CI/test_quantization.py               |  2 +-
 tests/CI/test_weight_only.py                | 10 ++++--
 5 files changed, 41 insertions(+), 12 deletions(-)

diff --git a/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py b/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py
index 375e5c374cc..7b81f87d5d9 100644
--- a/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py
+++ b/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py
@@ -301,6 +301,8 @@ def recover_qparms(self):
             qzeros = torch.ops.bestlaop.acquire_woq_packw_info(self.weight, 10)
             if bits == 4:
                 qzeros = qzeros // 16 + 8
+            else:
+                qzeros = (qzeros.to(torch.int32) + 128).to(torch.uint8)
         else:
             qzeros = None
diff --git a/intel_extension_for_transformers/transformers/llm/quantization/utils.py b/intel_extension_for_transformers/transformers/llm/quantization/utils.py
index d8972c0c264..4e5bad8d3a1 100644
--- a/intel_extension_for_transformers/transformers/llm/quantization/utils.py
+++ b/intel_extension_for_transformers/transformers/llm/quantization/utils.py
@@ -54,24 +54,45 @@ def unpack_weight(qweight, scales, qzeros, q_config):
+    sym = q_config.sym
     bits = q_config.bits
     wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0)
+
     zeros = torch.bitwise_right_shift(
         torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)
     ).to(torch.int16 if bits == 8 else torch.int8)
     torch.bitwise_and(zeros, (2**bits) - 1, out=zeros)
     if bits == 8:
-        zeros = zeros.to(torch.int8)
+        zeros = zeros.to(torch.int8 if sym else torch.uint8)
+    # INC stores zero points as (zp - 1); add the one back
     zeros = zeros + 1
-    zeros = zeros.reshape(scales.shape)
+    try:
+        zeros = zeros.reshape(scales.shape)
+    except:
+        # zeros and scales have different numbers of items;
+        # drop padded entries, which are 1 after the +1 at line 68
+        zeros = zeros[zeros != 1]
+        zeros = zeros.reshape(scales.shape)
+
+    # INC asym returns torch.uint8 but the backend requires int8,
+    # so change it to int8 with an offset of 128
+    if not sym and bits == 8:
+        zeros = (zeros.to(torch.int32) - 128).to(torch.int8)

     weight = torch.bitwise_right_shift(
         torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1)
     ).to(torch.int16 if bits == 8 else torch.int8)
     torch.bitwise_and(weight, (2**bits) - 1, out=weight)
-    if bits == 8:
-        weight = weight.to(torch.int8)
+    # INC adds a shift bias for sym, so remove it
+    if sym:
+        shift_bias = 2 ** (bits - 1)
+        weight -= shift_bias
+    weight = weight.to(torch.int8 if sym else torch.uint8)
+    # INC asym returns torch.uint8 but the backend requires int8,
+    # so change it to int8 with an offset of 128
+    if not sym:
+        weight = (weight.to(torch.int32) - 128).to(torch.int8)
     return weight, scales, zeros
@@ -238,7 +259,7 @@ def _replace_linear(
                 model._modules[name].requires_grad_(False)
                 if device == "cpu" or device == torch.device("cpu") or device == "auto":
                     if quantization_config.weight_dtype in \
-                        ["fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int8", "int4_fullrange"]:
+                        ["fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int4_fullrange"]:
                         model._modules[name].set_fp_weights_bias(
                             module.weight.data,
                             None if module.bias is None else module.bias.data,
@@ -506,7 +527,7 @@ def default_calib_func(model):
             q_model = replace_linear(model, None, None, config, device=device)
         else:
-            if config.weight_dtype not in ["nf4", "fp4", "int8", "int4_fullrange"]:
+            if config.weight_dtype not in ["nf4", "fp4", "int4_fullrange"]:
                 inc_model = inc_model.export_compressed_model(use_optimum_format=True)
             inc_model.eval()
             q_model = replace_linear(inc_model, None, None, config, device=device)
diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
index 65bc92097f7..f50d7e8e4ef 100644
--- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
+++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
@@ -195,7 +195,7 @@ def save_low_bit(
             return
         if self.quantization_config.weight_dtype not in \
-            ["fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int8", "int4_fullrange"]:
+            ["fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int4_fullrange"]:
             convert_model_to_public(self)
         os.makedirs(save_directory, exist_ok=True)
         # use transformers original `save_pretrained` function
@@ -1171,7 +1171,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         else:
             model = model_class(config, *model_args, **kwargs)
         if config.quantization_config["weight_dtype"] not in \
-            ["fp8_e5m2", "fp8_e4m3", "fp4", "nf4", "int8", "int4_fullrange"]:
+            ["fp8_e5m2", "fp8_e4m3", "fp4", "nf4", "int4_fullrange"]:
             model = build_woq_model(model, quantization_config)
         else:
             model = replace_linear(
@@ -1222,7 +1222,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()
         if config.quantization_config["weight_dtype"] not in \
-            ["fp8_e5m2", "fp8_e4m3", "int8", "nf4", "fp4" "int4_fullrange"]:
+            ["fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int4_fullrange"]:
             model = replace_linear(
                 model,
                 quantization_config=quantization_config,
diff --git a/tests/CI/test_quantization.py b/tests/CI/test_quantization.py
index 8f295365c59..cc50da859b0 100644
--- a/tests/CI/test_quantization.py
+++ b/tests/CI/test_quantization.py
@@ -433,7 +433,7 @@ def test_quantization_for_llm(self):
         )
         bit8_model.eval()
         output = bit8_model(dummy_input)
-        self.assertTrue(isclose(float(output[0][0][0][0]), 0.1675747185945511, rel_tol=1e-04))
+        self.assertTrue(isclose(float(output[0][0][0][0]), 0.16759155690670013, rel_tol=1e-04))
         # GPTQ
         woq_config = GPTQConfig(bits=4,
diff --git a/tests/CI/test_weight_only.py b/tests/CI/test_weight_only.py
index 3a65bd0b7ab..128ef3f70f8 100644
--- a/tests/CI/test_weight_only.py
+++ b/tests/CI/test_weight_only.py
@@ -29,8 +29,14 @@
     Trainer
 )
 from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
-from intel_extension_for_transformers.transformers.llm.quantization.nn.modules import QuantizedLinearQBits, QuantizedLoraLinearQBits
-from intel_extension_for_transformers.transformers.llm.quantization.utils import convert_to_quantized_model, replace_linear
+from intel_extension_for_transformers.transformers.llm.quantization.nn.modules import (
+    QuantizedLinearQBits,
+    QuantizedLoraLinearQBits
+)
+from intel_extension_for_transformers.transformers.llm.quantization.utils import (
+    convert_to_quantized_model,
+    replace_linear
+)
 from intel_extension_for_transformers.transformers.llm.utils.generation import _beam_search, _greedy_search
 from intel_extension_for_transformers.transformers import RtnConfig
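
Note (not part of the commit): as a quick, self-contained sketch of the conversion this
patch fixes, the Python below unpacks 8-bit values stored four-per-int32 and maps them
back to signed int8 for both the symmetric and asymmetric cases. It assumes only
PyTorch; the helper name unpack_int8 and the tensor layout are illustrative, not the
library's API.

import torch

def unpack_int8(qweight: torch.Tensor, sym: bool, bits: int = 8) -> torch.Tensor:
    """Unpack `bits`-wide fields from int32 storage and return signed int8 values."""
    shifts = torch.arange(0, 32, bits, dtype=torch.int32)  # [0, 8, 16, 24] for bits=8
    # Expand each packed int32 along a new axis into its 32 // bits fields.
    fields = torch.bitwise_right_shift(
        qweight.unsqueeze(1).expand(-1, 32 // bits, -1), shifts.view(1, -1, 1)
    )
    fields = torch.bitwise_and(fields, (1 << bits) - 1)  # keep only the low `bits` bits
    if sym:
        # Symmetric storage carries a shift bias of 2**(bits - 1); subtract it.
        out = fields - (1 << (bits - 1))
    else:
        # Asymmetric storage is unsigned; recentre the uint8-style values by -128
        # so the int8 kernels see the range [-128, 127].
        out = fields.to(torch.int32) - 128
    # Collapse the unpacked axis back into the row dimension (GPTQ-style packing).
    return out.to(torch.int8).reshape(-1, qweight.shape[-1])

# Example: two packed int32 rows of three columns unpack to eight int8 rows.
packed = torch.randint(-2**31, 2**31 - 1, (2, 3), dtype=torch.int32)
print(unpack_int8(packed, sym=False).shape)  # torch.Size([8, 3])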