From e1f4666d1ad6fb73ab9ec2235d80f035f808edf4 Mon Sep 17 00:00:00 2001 From: "Wang, Chang" Date: Tue, 23 Apr 2024 15:12:49 +0800 Subject: [PATCH] Add StaticQuantConfig (#1501) --- .../transformers/__init__.py | 1 + .../transformers/modeling/modeling_auto.py | 177 +++++++++++++++++- .../transformers/utils/__init__.py | 1 + .../transformers/utils/config.py | 141 +++++++++----- tests/CI/test_quantization.py | 21 ++- 5 files changed, 293 insertions(+), 48 deletions(-) diff --git a/intel_extension_for_transformers/transformers/__init__.py b/intel_extension_for_transformers/transformers/__init__.py index 9bb247f7dd0..2d1626be35d 100644 --- a/intel_extension_for_transformers/transformers/__init__.py +++ b/intel_extension_for_transformers/transformers/__init__.py @@ -44,6 +44,7 @@ MixedPrecisionConfig, BitsAndBytesConfig, SmoothQuantConfig, + StaticQuantConfig, RtnConfig, AwqConfig, TeqConfig, diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py index 1d07d945a42..043739d3309 100644 --- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py +++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py @@ -42,6 +42,7 @@ BitsAndBytesConfig, MixedPrecisionConfig, SmoothQuantConfig, + StaticQuantConfig, RtnConfig, AwqConfig, TeqConfig, @@ -71,6 +72,7 @@ from accelerate import init_empty_weights from huggingface_hub import hf_hub_download from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear +from neural_compressor.model.torch_model import PyTorchFXModel from threading import Thread from transformers.configuration_utils import PretrainedConfig from transformers import AutoConfig @@ -211,7 +213,14 @@ def save_low_bit( f"Provided path ({save_directory}) should be a directory, not a file" ) return - + if isinstance(self, PyTorchFXModel): + self.quantization_config.save_pretrained(save_directory, **kwargs) + self.model.config.quantization_config = self.quantization_config + self.model.config.save_pretrained(save_directory) + weights_file = os.path.join( + os.path.abspath(os.path.expanduser(save_directory)), WEIGHTS_NAME) + torch.save(self.quantized_state_dict(), weights_file) + return convert_model_to_public(self) os.makedirs(save_directory, exist_ok=True) # use transformers original `save_pretrained` function @@ -403,7 +412,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): "Quantization_config loading failed. If you want to load saved " "low bit model, please check your quantizate_config.json." ) - elif use_neural_speed: + elif use_neural_speed and not config.quantization_config["quant_method"] == "static": if not os.path.exists(pretrained_model_name_or_path): from huggingface_hub import snapshot_download pretrained_model_name_or_path = snapshot_download(repo_id=pretrained_model_name_or_path, @@ -963,6 +972,157 @@ def calib_func(model): ), ) logger.info("SmoothQuant done.") + elif isinstance(quantization_config, StaticQuantConfig): + if quantization_config.backend == "ipex": + try: + import intel_extension_for_pytorch as ipex + except ImportError: + logger.warning( + "Please install Intel Extension for PyTorch to accelerate the model inference." + ) + config.torchscript = True + assert quantization_config.example_inputs is not None, \ + "Please provide example_inputs for IPEX static quantization." 
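Aside on the assertion above: when backend="ipex" is selected, example_inputs must be set on StaticQuantConfig before calling from_pretrained. A minimal sketch, assuming a placeholder checkpoint and prompt and a single input_ids tensor as the example input; none of these names come from the patch itself, and the exact format expected for example_inputs (tensor, tuple, or dict) follows whatever the Neural Compressor backend accepts:

    from transformers import AutoTokenizer
    from intel_extension_for_transformers.transformers import StaticQuantConfig

    # Placeholder checkpoint and prompt, for illustration only.
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    example_inputs = tokenizer("sample calibration text", return_tensors="pt")["input_ids"]

    quantization_config = StaticQuantConfig(
        backend="ipex",
        tokenizer=tokenizer,
        example_inputs=example_inputs,  # required for the IPEX backend, per the assert above
    )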
+
+            model = cls.ORIG_MODEL.from_pretrained(
+                pretrained_model_name_or_path,
+                *model_args,
+                config=config,
+                low_cpu_mem_usage=True,
+                torch_dtype=torch.float,
+                **kwargs,
+            )
+
+            if (
+                not torch.cuda.is_available()
+                or device_map == "cpu"
+                or device_map == torch.device("cpu")
+            ) and model.config.model_type == "chatglm":
+                model = model.float()
+            model.eval()
+            logger.info("Applying StaticQuant.")
+            # calibration function
+            calib_func = quantization_config.calib_func
+            tokenizer = quantization_config.tokenizer
+            if calib_func is None:
+                if quantization_config.tokenizer is None:
+                    logger.error(
+                        "Please provide a tokenizer or a calib_func directly;"
+                        + " the following shows how to get a tokenizer. \n"
+                        + " from transformers import AutoTokenizer \n"
+                        + " tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) \n"
+                    )
+                    exit(0)
+
+                from datasets import load_dataset
+                from torch.utils.data import DataLoader
+
+                calib_dataset = quantization_config.calib_dataset
+                calib_shuffle = quantization_config.calib_shuffle
+                calib_iters = quantization_config.calib_iters
+                calib_padding = quantization_config.calib_padding
+                calib_len = quantization_config.calib_len
+                calib_pad_val = quantization_config.calib_pad_val
+                from torch.nn.functional import pad
+
+                calib_dataset = load_dataset(
+                    calib_dataset,
+                    split=(
+                        "test"
+                        if calib_dataset in ["mbpp", "openai_humaneval"]
+                        else "train"
+                    ),
+                )
+                if calib_shuffle:
+                    calib_dataset = calib_dataset.shuffle(seed=42)
+
+                def tokenize_function(examples):
+                    if "code" in examples:
+                        example = tokenizer(examples["code"])
+                    elif "prompt" in examples:
+                        example = tokenizer(examples["prompt"])
+                    elif "text" in examples:
+                        example = tokenizer(examples["text"])
+                    else:
+                        logger.error(
+                            "Please check the dataset prompt identifier;"
+                            + " NeelNanda/pile-10k is the default calibration dataset."
+                        )
+                        exit(0)
+                    return example
+
+                def collate_batch(batch):
+                    input_ids_padded = []
+                    last_ind = []
+                    for text in batch:
+                        input_ids = text["input_ids"]
+                        if not calib_padding:
+                            input_ids = (
+                                input_ids[: int(calib_len)]
+                                if len(input_ids) > int(calib_len)
+                                else input_ids
+                            )  # no_padding
+                        else:
+                            pad_len = calib_len - input_ids.shape[0]
+                            input_ids = pad(
+                                input_ids, (0, pad_len), value=calib_pad_val
+                            )
+
+                        last_ind.append(input_ids.shape[0] - 1)
+                        input_ids_padded.append(input_ids)
+
+                    return (
+                        {
+                            "input_ids": torch.vstack(input_ids_padded),
+                        },
+                        torch.tensor(last_ind),
+                    )
+
+
+                tokenized_dataset = calib_dataset.map(tokenize_function, batched=True)
+                tokenized_dataset.set_format(type="torch", columns=["input_ids"])
+                calib_dataloader = DataLoader(
+                    tokenized_dataset,
+                    batch_size=1,
+                    shuffle=False,
+                    collate_fn=collate_batch,
+                )
+
+                def calib_func(model):
+                    with torch.no_grad():
+                        for i, (inputs, last_ind) in enumerate(calib_dataloader):
+                            if i >= calib_iters:
+                                break
+                            model(**inputs)
+
+                logger.info(
+                    "The default calibration function is used; "
+                    + "the calibration dataset is NeelNanda/pile-10k, "
+                    + "the batch size is 1, and 100 calibration iterations are run."
+ ) + calib_func = calib_func + + + # call inc static quant + from neural_compressor import PostTrainingQuantConfig, quantization + + conf = PostTrainingQuantConfig( + backend=quantization_config.backend, # default is ipex + excluded_precisions=quantization_config.excluded_precisions, + op_type_dict=quantization_config.op_type_dict, + op_name_dict=quantization_config.op_name_dict, + example_inputs=quantization_config.example_inputs, + ) + model = quantization.fit( + model, + conf, + calib_func=calib_func, + ) + model.save_pretrained = types.MethodType(save_low_bit, model) + quantization_config.remove_redundant_parameters() + model.quantization_config = quantization_config + logger.info("StaticQuant done.") + return model else: if use_neural_speed: logger.info("Using Neural Speed with FP32 model dtype.") @@ -1093,6 +1253,8 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): quantization_config = GPTQConfig.from_dict(quantization_config) elif quantization_config["quant_method"] == "autoround": quantization_config = AutoRoundConfig.from_dict(quantization_config) + elif quantization_config["quant_method"] == "static": + quantization_config = StaticQuantConfig.from_dict(quantization_config) assert ( quantization_config is not None ), "Detect this model is not a low-bit model." @@ -1336,6 +1498,16 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): # by checking its first weights entry that is of a floating type # - we assume all floating dtype weights are of the same dtype # we also may have config.torch_dtype available, but we won't rely on it till v5 + # Pretrained Model + if quantization_config.quant_method == "static": + model = model_class(config, *model_args, **kwargs) + from neural_compressor.utils.pytorch import load + weights_file = os.path.join( + os.path.abspath(os.path.expanduser(pretrained_model_name_or_path)), WEIGHTS_NAME) + q_model = load(weights_file, model, dataloader=None) + del model + return q_model + dtype_orig = None if torch_dtype is not None: if isinstance(torch_dtype, str): @@ -1378,7 +1550,6 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): if quantization_config.weight_dtype is None: quantization_config.weight_dtype = "int4_clip" - # Pretrained Model init_contexts = [no_init_weights(_enable=_fast_init)] init_contexts.append(init_empty_weights()) diff --git a/intel_extension_for_transformers/transformers/utils/__init__.py b/intel_extension_for_transformers/transformers/utils/__init__.py index c27ec590ba7..e779cec8f29 100644 --- a/intel_extension_for_transformers/transformers/utils/__init__.py +++ b/intel_extension_for_transformers/transformers/utils/__init__.py @@ -20,6 +20,7 @@ MixedPrecisionConfig, BitsAndBytesConfig, SmoothQuantConfig, + StaticQuantConfig, SparsityConfig, RtnConfig, AwqConfig, diff --git a/intel_extension_for_transformers/transformers/utils/config.py b/intel_extension_for_transformers/transformers/utils/config.py index 682102950a8..851c4c44bba 100644 --- a/intel_extension_for_transformers/transformers/utils/config.py +++ b/intel_extension_for_transformers/transformers/utils/config.py @@ -32,31 +32,24 @@ class MixedPrecisionConfig: dtype: str = "bfloat16" +if transformers.__version__ >= "4.32.0": + from transformers.utils.quantization_config import QuantizationConfigMixin + QuantizationConfig = QuantizationConfigMixin +else: + QuantizationConfig = PretrainedConfig +from enum import Enum -@dataclass -class SmoothQuantConfig: - backend: str = "ipex" - ipex_opt_llm: 
bool = None - tokenizer: Any = None - calib_func: Any = None - calib_dataset: str = "NeelNanda/pile-10k" - calib_shuffle: bool = True - calib_iters: int = 100 - calib_padding: bool = False - calib_len: int = 512 - calib_pad_val: int = 1 - alpha: float = 0.5 - op_type_dict: dict = None - op_name_dict: dict = None - excluded_precisions: list = field(default_factory=list) - example_inputs: Any = None - num_beams: int = 1 - recipes: dict = field( - default_factory=lambda: { - "smooth_quant": True, - "smooth_quant_args": {"alpha": 0.5}, - } - ) + +class QuantizationMethod(str, Enum): + BITS_AND_BYTES = "bitsandbytes" + GPTQ = "gptq" + AWQ = "awq" + AQLM = "aqlm" + RTN = "rtn" + AUTOROUND = "autoround" + TEQ = "teq" + STATIC = "static" + SmoothQuant = "sq" class SparsityConfig(PretrainedConfig): @@ -241,24 +234,6 @@ def get_config_dict( pretrained_model_name_or_path, _configuration_file=SPARSITY_CONFIG, **kwargs ) -if transformers.__version__ >= "4.32.0": - from transformers.utils.quantization_config import QuantizationConfigMixin - QuantizationConfig = QuantizationConfigMixin -else: - QuantizationConfig = PretrainedConfig -from enum import Enum - - -class QuantizationMethod(str, Enum): - BITS_AND_BYTES = "bitsandbytes" - GPTQ = "gptq" - AWQ = "awq" - AQLM = "aqlm" - RTN = "rtn" - AUTOROUND = "autoround" - TEQ = "teq" - - class ITREXQuantizationConfigMixin(QuantizationConfig): """Mixin class for quantization config.""" @@ -561,7 +536,8 @@ def remove_redundant_parameters(self): remove_parameters = ["calib_dataloader", "dataset", "calib_func", "calib_iters", "calib_len", "double_quant_scale_dtype", "use_double_quant", "mse_range", "scheme", "tokenizer", "use_ggml", "use_neural_speed", "use_quant", "layer_wise", "blocksize", "nsamples", "max_input_length", "static_groups", - "lr", "minmax_lr", "iters", "use_quant_input", "device"] + "lr", "minmax_lr", "iters", "use_quant_input", "device", "calib_dataset", "calib_pad_val", "calib_shuffle", + "calib_padding", "example_inputs", "excluded_precisions", "op_name_dict", "op_type_dict"] for parameter in remove_parameters: if hasattr(self, parameter): delattr(self, parameter) @@ -624,6 +600,85 @@ def get_config_dict( pretrained_model_name_or_path, _configuration_file=cf, **kwargs ) +class StaticQuantConfig(ITREXQuantizationConfigMixin): + def __init__( + self, + backend="default", + tokenizer=None, + calib_dataset="NeelNanda/pile-10k", + calib_dataloader=None, + calib_func=None, + calib_shuffle=True, + calib_iters=100, + calib_padding=False, + calib_len=512, + calib_pad_val=1, + op_name_dict=None, + op_type_dict=None, + excluded_precisions=[], + example_inputs=None, + **kwargs, + ): + self.quant_method = QuantizationMethod.STATIC + self.backend = backend + self.tokenizer = tokenizer + self.calib_dataset = calib_dataset + self.calib_dataloader = calib_dataloader + self.calib_func = calib_func + self.calib_shuffle = calib_shuffle + self.calib_iters = calib_iters + self.calib_padding = calib_padding + self.calib_len = calib_len + self.calib_pad_val = calib_pad_val + self.op_name_dict = op_name_dict + self.op_type_dict = op_type_dict + self.excluded_precisions = excluded_precisions + self.example_inputs = example_inputs + +class SmoothQuantConfig(StaticQuantConfig): + def __init__( + self, + backend="ipex", + tokenizer=None, + calib_dataset="NeelNanda/pile-10k", + calib_dataloader=None, + calib_func=None, + calib_shuffle=True, + calib_iters=100, + calib_padding=False, + calib_len=512, + calib_pad_val=1, + op_name_dict=None, + op_type_dict=None, + 
excluded_precisions=[], + example_inputs=None, + ipex_opt_llm=None, + alpha=0.5, + num_beams=1, + recipes={"smooth_quant": True, "smooth_quant_args":{"alpha":0.5}}, + **kwargs, + ): + super().__init__( + backend=backend, + tokenizer=tokenizer, + calib_dataset=calib_dataset, + calib_dataloader=calib_dataloader, + calib_func=calib_func, + calib_shuffle=calib_shuffle, + calib_iters=calib_iters, + calib_padding=calib_padding, + calib_len=calib_len, + calib_pad_val=calib_pad_val, + op_name_dict=op_name_dict, + op_type_dict=op_type_dict, + excluded_precisions=excluded_precisions, + example_inputs=example_inputs, + ) + self.quant_method = QuantizationMethod.SmoothQuant + self.ipex_opt_llm = ipex_opt_llm + self.alpha = alpha + self.num_beams = num_beams + self.recipes = recipes class RtnConfig(ITREXQuantizationConfigMixin): def __init__( diff --git a/tests/CI/test_quantization.py b/tests/CI/test_quantization.py index cc50da859b0..c9d40746799 100644 --- a/tests/CI/test_quantization.py +++ b/tests/CI/test_quantization.py @@ -316,6 +316,7 @@ def test_quantization_for_llm(self): from intel_extension_for_transformers.transformers import ( MixedPrecisionConfig, SmoothQuantConfig, + StaticQuantConfig, RtnConfig, AwqConfig, TeqConfig, @@ -326,7 +327,23 @@ def test_quantization_for_llm(self): from intel_extension_for_transformers.transformers import AutoModelForCausalLM fp32_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, use_neural_speed=False) dummy_input = fp32_model.dummy_inputs["input_ids"] - # SQ + # Static quant + sq_config = StaticQuantConfig( + tokenizer=tokenizer, # either two of one, tokenizer or calib_func + calib_iters=2, + ) + q_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, + quantization_config=sq_config, + ) + q_model.eval() + output = q_model(dummy_input) + self.assertTrue(isclose(float(output[0][0][0][0]), 0.17378684878349304, rel_tol=1e-04)) + q_model.save_pretrained("./saved_results") + loading_model = AutoModelForCausalLM.from_pretrained("./saved_results") + loading_model.eval() + output = loading_model(dummy_input) + self.assertTrue(isclose(float(output[0][0][0][0]), 0.17378684878349304, rel_tol=1e-04)) + # Smoothquant sq_config = SmoothQuantConfig( tokenizer=tokenizer, # either two of one, tokenizer or calib_func calib_iters=2, @@ -338,7 +355,7 @@ def test_quantization_for_llm(self): ) self.assertTrue(isinstance(q_model.model, torch.jit.ScriptModule)) - # SQ auto + # Smoothquant auto recipes = { "smooth_quant": True, "smooth_quant_args": { "alpha": "auto", "auto_alpha_args":{"alpha_max": 0.6,
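End-to-end usage sketch of the new static quantization path, based on the test added above; the checkpoint name and calibration settings are placeholders, not values taken from the patch:

    from transformers import AutoTokenizer
    from intel_extension_for_transformers.transformers import (
        AutoModelForCausalLM,
        StaticQuantConfig,
    )

    model_name_or_path = "facebook/opt-125m"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    # Provide either a tokenizer (for the default NeelNanda/pile-10k calibration
    # loop) or a custom calib_func.
    quantization_config = StaticQuantConfig(
        tokenizer=tokenizer,
        calib_iters=2,  # kept small for a quick run; the default is 100
    )

    q_model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        quantization_config=quantization_config,
    )
    q_model.eval()

    # save_pretrained is routed to save_low_bit, which stores the quantized
    # state dict plus a config.json whose quantization_config records
    # quant_method "static"; a plain from_pretrained on that directory then
    # reloads the quantized model.
    q_model.save_pretrained("./saved_results")
    loaded_model = AutoModelForCausalLM.from_pretrained("./saved_results")

On reload, load_low_bit rebuilds the FP32 graph and applies neural_compressor.utils.pytorch.load on the saved weights file, as added in modeling_auto.py above.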