diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_causal.py b/intel_extension_for_transformers/transformers/modeling/modeling_causal.py
index e090d092ad2..92686191a28 100644
--- a/intel_extension_for_transformers/transformers/modeling/modeling_causal.py
+++ b/intel_extension_for_transformers/transformers/modeling/modeling_causal.py
@@ -176,11 +176,23 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             if calib_func is None:
                 from datasets import load_dataset
                 from torch.utils.data import DataLoader
-                calib_dataset = load_dataset("NeelNanda/pile-10k", split="train")
+                calib_dataset = quantization_config.calib_dataset
+                calib_iters = quantization_config.calib_iters
+                calib_dataset = load_dataset(calib_dataset, split="train")
                 calib_dataset = calib_dataset.shuffle(seed=42)

                 def tokenize_function(examples):
-                    return quantization_config.tokenizer(examples["text"])
+                    if 'prompt' in examples:
+                        example = quantization_config.tokenizer(examples["prompt"])
+                    elif 'text' in examples:
+                        example = quantization_config.tokenizer(examples["text"])
+                    elif 'code' in examples:
+                        example = quantization_config.tokenizer(examples["code"])
+                    else:
+                        logger.error("Please check the dataset prompt identifier;" +
+                                     " NeelNanda/pile-10k is the default calibration dataset.")
+                        exit(0)
+                    return example

                 tokenized_dataset = calib_dataset.map(tokenize_function, batched=True)
                 tokenized_dataset.set_format(type="torch", columns=["input_ids"])
@@ -213,7 +225,7 @@ def default_calib_func(model):
                     past_key_values = generate_dummy_past_key_values(input_bs, model)
                     attention_mask = torch.ones(input_bs, input_len + 1)
                     attention_mask[:,0] = 0
-                    if i >= 100:
+                    if i >= calib_iters:
                         break
                     model(
                         input_ids=input_ids,
diff --git a/intel_extension_for_transformers/transformers/utils/quantization_config.py b/intel_extension_for_transformers/transformers/utils/quantization_config.py
index 512322825dc..1c590609055 100644
--- a/intel_extension_for_transformers/transformers/utils/quantization_config.py
+++ b/intel_extension_for_transformers/transformers/utils/quantization_config.py
@@ -16,28 +16,31 @@
 # limitations under the License.

 """Configs for intel extension for transformers."""
-from dataclasses import dataclass
-from typing import Optional, Any
+from dataclasses import dataclass, field
+from typing import Any, Optional
+
 from transformers import BitsAndBytesConfig

 @dataclass
 class WeightOnlyQuantizationConfig:
-    algorithm: str = 'RTN'
+    algorithm: str = "RTN"
     bits: int = 8
     group_size: int = -1
-    scheme: str = 'sym'
+    scheme: str = "sym"
     enable_full_range: bool = True

+
 @dataclass
 class AMPConfig:
-    dtype: str = 'bfloat16'
+    dtype: str = "bfloat16"

 @dataclass
 class SmoothQuantConfig:
     tokenizer: Any = None
     calib_func: Any = None
+    calib_dataset: str = "NeelNanda/pile-10k"
+    calib_iters: int = 100
     alpha: float = 0.5
-    op_type_dict: dict = None
-    excluded_precisions: dict = None
-
+    op_type_dict: dict = None
+    excluded_precisions: list = field(default_factory=list)
diff --git a/tests/test_quantization.py b/tests/test_quantization.py
index 08aa2e504bc..49031214936 100644
--- a/tests/test_quantization.py
+++ b/tests/test_quantization.py
@@ -286,6 +286,52 @@ def test_bf16_onnx(self):
                 self.assertEqual(tensor.data_type, TensorProto.BFLOAT16)
                 break

+    def test_quantization_for_llm(self):
+        model_name_or_path = "facebook/opt-125m"
+        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        from intel_extension_for_transformers.transformers import (
+            AMPConfig,
+            WeightOnlyQuantizationConfig,
+            SmoothQuantConfig,
+            BitsAndBytesConfig
+        )
+        from intel_extension_for_transformers.transformers import AutoModelForCausalLM
+        fp32_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
+        dummy_input = fp32_model.dummy_inputs["input_ids"]
+
+        # smooth-quant
+        sq_config = SmoothQuantConfig(
+            tokenizer=tokenizer,  # either one of the two is required: tokenizer or calib_func
+            calib_iters=5
+        )
+        q_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+                                                       quantization_config=sq_config
+                                                       )
+        self.assertTrue(isinstance(q_model, torch.jit.ScriptModule))
+        # weight-only
+        woq_config = WeightOnlyQuantizationConfig()
+        woq_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+                                                         quantization_config=woq_config
+                                                         )
+        output = woq_model(dummy_input)
+        self.assertTrue(float(output[0][0][0][0]), -7.139640808105469)
+        # amp
+        amp_config = AMPConfig()
+        amp_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+                                                         quantization_config=amp_config
+                                                         )
+        output = amp_model(dummy_input)
+        self.assertTrue(float(output[0][0][0][0]), -7.347761154174805)
+
+        # bitsandbytes
+        bab_config = BitsAndBytesConfig()
+        bab_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+                                                         quantization_config=bab_config
+                                                         )
+        output = bab_model(dummy_input)
+        self.assertTrue(float(output[0][0][0][0]), -7.347761154174805)

 if __name__ == "__main__":
     unittest.main()