This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

improve sqconfig and add ut
Signed-off-by: changwangss <chang1.wang@intel.com>
changwangss committed Sep 18, 2023
1 parent 7ba2aed commit 205f8ec
Showing 3 changed files with 72 additions and 11 deletions.
@@ -176,11 +176,23 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        if calib_func is None:
            from datasets import load_dataset
            from torch.utils.data import DataLoader
-            calib_dataset = load_dataset("NeelNanda/pile-10k", split="train")
+            calib_dataset = quantization_config.calib_dataset
+            calib_iters = quantization_config.calib_iters
+            calib_dataset = load_dataset(calib_dataset, split="train")
            calib_dataset = calib_dataset.shuffle(seed=42)

            def tokenize_function(examples):
-                return quantization_config.tokenizer(examples["text"])
+                if 'prompt' in examples:
+                    example = quantization_config.tokenizer(examples["prompt"])
+                elif 'text' in examples:
+                    example = quantization_config.tokenizer(examples["text"])
+                elif 'code' in examples:
+                    example = quantization_config.tokenizer(examples["code"])
+                else:
+                    logger.error("Please check the dataset's column names; "
+                                 "NeelNanda/pile-10k is the default calibration dataset.")
+                    exit(0)
+                return example

            tokenized_dataset = calib_dataset.map(tokenize_function, batched=True)
            tokenized_dataset.set_format(type="torch", columns=["input_ids"])
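
For reference, a minimal standalone sketch of the column-detection logic introduced above, rewritten as a loop; the tokenizer choice and the print call are illustrative and not part of this commit:

# Standalone sketch (not from this commit): same column detection as the
# tokenize_function above, written as a loop instead of an if/elif chain.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # illustrative tokenizer

def tokenize_function(examples):
    # Prefer "prompt", then "text", then "code", mirroring the order above.
    for column in ("prompt", "text", "code"):
        if column in examples:
            return tokenizer(examples[column])
    raise ValueError("Calibration dataset needs a 'prompt', 'text', or 'code' column.")

print(tokenize_function({"text": ["hello world"]})["input_ids"])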
@@ -213,7 +225,7 @@ def default_calib_func(model):
            past_key_values = generate_dummy_past_key_values(input_bs, model)
            attention_mask = torch.ones(input_bs, input_len + 1)
            attention_mask[:,0] = 0
-            if i >= 100:
+            if i >= calib_iters:
                break
            model(
                input_ids=input_ids,
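Taken together, the two hunks above let callers choose the calibration dataset and the number of calibration iterations through SmoothQuantConfig instead of the previously hard-coded values. A minimal usage sketch, mirroring the unit test added below (the model name and the values are illustrative):

# Usage sketch for the new calib_dataset / calib_iters knobs.
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    SmoothQuantConfig,
)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
sq_config = SmoothQuantConfig(
    tokenizer=tokenizer,                  # tokenizes the calibration split
    calib_dataset="NeelNanda/pile-10k",   # default; any dataset with a text/prompt/code column
    calib_iters=5,                        # stop calibration after this many batches
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", quantization_config=sq_config
)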
@@ -16,28 +16,31 @@
# limitations under the License.
"""Configs for intel extension for transformers."""

-from dataclasses import dataclass
-from typing import Optional, Any
+from dataclasses import dataclass, field
+from typing import Any, Optional

+from transformers import BitsAndBytesConfig


@dataclass
class WeightOnlyQuantizationConfig:
-    algorithm: str = 'RTN'
+    algorithm: str = "RTN"
    bits: int = 8
    group_size: int = -1
-    scheme: str = 'sym'
+    scheme: str = "sym"
    enable_full_range: bool = True


@dataclass
class AMPConfig:
-    dtype: str = 'bfloat16'
+    dtype: str = "bfloat16"

@dataclass
class SmoothQuantConfig:
    tokenizer: Any = None
    calib_func: Any = None
+    calib_dataset: str = "NeelNanda/pile-10k"
+    calib_iters: int = 100
    alpha: float = 0.5
-    op_type_dict: dict = None
-    excluded_precisions: dict = None
+    op_type_dict: dict = None
+    excluded_precisions: list = field(default_factory=list)
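
The excluded_precisions change avoids the dataclass mutable-default pitfall: a bare list cannot be used as a field default, and a shared mutable default would leak state across config instances. A short illustration, using a made-up class name:

# Why field(default_factory=list): each instance gets its own fresh list.
from dataclasses import dataclass, field

@dataclass
class ExampleConfig:  # hypothetical class, for illustration only
    excluded_precisions: list = field(default_factory=list)

a, b = ExampleConfig(), ExampleConfig()
a.excluded_precisions.append("bf16")
print(b.excluded_precisions)  # [] -- instances do not share the list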
tests/test_quantization.py: 46 additions & 0 deletions
@@ -286,6 +286,52 @@ def test_bf16_onnx(self):
                self.assertEqual(tensor.data_type, TensorProto.BFLOAT16)
                break

+    def test_quantization_for_llm(self):
+        model_name_or_path = "facebook/opt-125m"
+        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        from intel_extension_for_transformers.transformers import (
+            AMPConfig,
+            WeightOnlyQuantizationConfig,
+            SmoothQuantConfig,
+            BitsAndBytesConfig
+
+        )
+        from intel_extension_for_transformers.transformers import AutoModelForCausalLM
+        fp32_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
+        dummy_input = fp32_model.dummy_inputs["input_ids"]
+
+        # smooth-quant
+        sq_config = SmoothQuantConfig(
+            tokenizer=tokenizer,  # provide either tokenizer or calib_func
+            calib_iters=5
+        )
+        q_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+                                                       quantization_config=sq_config
+                                                       )
+        self.assertTrue(isinstance(q_model, torch.jit.ScriptModule))
+        # weight-only
+        woq_config = WeightOnlyQuantizationConfig()
+        woq_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+                                                         quantization_config=woq_config
+                                                         )
+        output = woq_model(dummy_input)
+        self.assertTrue(float(output[0][0][0][0]), -7.139640808105469)
+        # amp
+        amp_config = AMPConfig()
+        amp_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+                                                         quantization_config=amp_config
+                                                         )
+        output = amp_model(dummy_input)
+        self.assertTrue(float(output[0][0][0][0]), -7.347761154174805)
+
+
+        # bitsandbytes
+        bab_config = BitsAndBytesConfig()
+        bab_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+                                                         quantization_config=bab_config
+                                                         )
+        output = bab_model(dummy_input)
+        self.assertTrue(float(output[0][0][0][0]), -7.347761154174805)

if __name__ == "__main__":
    unittest.main()
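
One caveat about the numeric checks in the new test: unittest's assertTrue(x, msg) treats its second argument as a failure message, so those lines only verify that the output is truthy rather than comparing it to the recorded value. A closeness assertion inside the test would read as follows (the tolerance is an arbitrary choice, not taken from the commit):

# Inside test_quantization_for_llm, a real value comparison would be:
self.assertAlmostEqual(float(output[0][0][0][0]), -7.139640808105469, delta=1e-3)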
