This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Enable text-generation with new API #318

Merged
merged 21 commits on Sep 18, 2023
Changes from 7 commits
@@ -12,11 +12,11 @@
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer, PretrainedConfig
from transformers.utils import check_min_version
from intel_extension_for_transformers.transformers import AutoModelForCausalLM
import transformers
import numpy as np
from itertools import chain
from optimum.utils import NormalizedConfigManager
from optimum.intel.generation.modeling import TSModelForCausalLM


parser = argparse.ArgumentParser()
@@ -34,13 +34,10 @@
)
parser.add_argument("--output_dir", nargs="?", default="./saved_results")
parser.add_argument("--quantize", action="store_true")
parser.add_argument("--ipex", action="store_true")
parser.add_argument("--sq", action="store_true")
parser.add_argument("--alpha", default="auto", help="Smooth quant parameter.")
parser.add_argument(
"--pad_max_length", default=512, type=int, help="Pad input ids to max length."
)
parser.add_argument("--calib_iters", default=512, type=int, help="calibration iters.")
parser.add_argument("--int8", action="store_true")
parser.add_argument(
"--int8_bf16_mixed",
@@ -58,198 +55,93 @@
parser.add_argument("--tasks", nargs='+', default=["winogrande", "copa", "piqa", "rte", "hellaswag", \
"openbookqa", "lambada_openai", "lambada_standard", "wikitext"], type=str, \
help="tasks list for accuracy validation")

args = parser.parse_args()

calib_size = 1

# model
config = AutoConfig.from_pretrained(
args.model,
torchscript=True
if args.ipex
else False, # torchscript will force `return_dict=False` to avoid jit errors
use_cache=True, # to use kv cache.
trust_remote_code=args.trust_remote_code,
revision=args.revision
)
args.model,
torchscript=True
if args.quantize
else False, # torchscript will force `return_dict=False` to avoid jit errors
use_cache=True, # to use kv cache.
trust_remote_code=args.trust_remote_code,
revision=args.revision,
)

# transformers version >= 4.32.0 contained the mpt modeling definition.
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py
if config.model_type == "mpt":
check_min_version("4.32.0")

user_model = AutoModelForCausalLM.from_pretrained(
args.model,
config=config
)

# tokenizer
if config.model_type == "llama":
from transformers import LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained(args.model)
from transformers import LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained(args.model)
else:
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)

# to channels last
user_model = user_model.to(memory_format=torch.channels_last)
user_model.eval()

if args.ipex:
import intel_extension_for_pytorch as ipex
from optimum.intel.generation.modeling import TSModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)

# quantize
if args.quantize:
def generate_dummy_past_key_values(input_bs, user_model):
normalized_config = NormalizedConfigManager.get_normalized_config_class(
user_model.config.model_type
)(user_model.config)
nb_pkv = 2
num_layers = normalized_config.num_layers
num_attention_heads = normalized_config.num_attention_heads
hidden_size = normalized_config.hidden_size
d_k = hidden_size // num_attention_heads

if user_model.config.model_type == "bloom":
pkv = ()
for nb_pkv in range(nb_pkv):
if nb_pkv % 2 == 0:
new_shape = [input_bs * num_attention_heads, d_k, 1]
else:
new_shape = [input_bs * num_attention_heads, 1, d_k]
pkv = pkv + (torch.ones(size=new_shape),)
else:
new_shape = [input_bs, num_attention_heads, 1, d_k]
dummy_tensor = torch.ones(size=new_shape)
pkv = tuple(dummy_tensor for _ in range(nb_pkv))
past_key_values = tuple(tuple(pkv) for _ in range(num_layers))
return past_key_values

class Evaluator:
def __init__(
self,
dataset,
tokenizer,
batch_size=8,
pad_val=1,
pad_max=512,
is_calib=False,
):
self.dataset = dataset
self.tokenizer = tokenizer
self.batch_size = batch_size
self.pad_val = pad_val
self.pad_max = pad_max
self.is_calib = is_calib

# tokenize the dataset
self.dataset = self.dataset.map(self.tokenize_function, batched=True)
self.dataset.set_format(type="torch", columns=["input_ids"])

@torch.no_grad()
def tokenize_function(self, examples):
example = self.tokenizer(examples["text"])
return example

@torch.no_grad()
def collate_batch(self, batch):
input_ids_padded = []
last_ind = []
for text in batch:
input_ids = text["input_ids"]
pad_len = self.pad_max - input_ids.shape[0]
last_ind.append(input_ids.shape[0] - 1)
if self.is_calib:
input_ids = (
input_ids[: self.pad_max]
if len(input_ids) > self.pad_max
else input_ids
)
else:
input_ids = pad(input_ids, (0, pad_len), value=self.pad_val)
input_ids_padded.append(input_ids)
return (
torch.vstack(input_ids_padded),
torch.tensor(last_ind),
)

calib_dataset = load_dataset(args.dataset, split="train")
calib_dataset = calib_dataset.shuffle(seed=42)
calib_evaluator = Evaluator(
calib_dataset,
tokenizer,
args.batch_size,
pad_max=args.pad_max_length,
is_calib=True,
)
calib_dataloader = DataLoader(
calib_evaluator.dataset,
batch_size=calib_size,
shuffle=False,
collate_fn=calib_evaluator.collate_batch,
)
input_ids = user_model.dummy_inputs["input_ids"]
input_bs, input_len = input_ids.shape
past_key_values = generate_dummy_past_key_values(input_bs, user_model)
attention_mask = torch.ones(input_bs, input_len + 1)
attention_mask[:,0] = 0
example_inputs = (input_ids, tuple(past_key_values), attention_mask)
# do inference to check example_inputs formats
user_model(*example_inputs)

def calib_func(prepared_model):
for i, (input_ids, last_ind) in enumerate(calib_dataloader):
input_bs, input_len = input_ids.shape
past_key_values = generate_dummy_past_key_values(input_bs, user_model)
attention_mask = torch.ones(input_bs, input_len + 1)
attention_mask[:,0] = 0
if i >= args.calib_iters:
break
prepared_model(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
)

from neural_compressor import PostTrainingQuantConfig, quantization

if re.search("gptj", user_model.config.model_type) or re.search(
"gpt_neox", user_model.config.model_type
from intel_extension_for_transformers.transformers import (
AMPConfig,
WeightOnlyQuantizationConfig,
SmoothQuantConfig,
BitsAndBytesConfig

)
from intel_extension_for_transformers.transformers import AutoModelForCausalLM
if re.search("gptj", config.model_type) or re.search(
"gpt_neox", config.model_type
):
op_type_dict = {
"add": {"weight": {"dtype": ["fp32"]}, "activation": {"dtype": ["fp32"]}},
}
elif re.search("mpt", user_model.config.model_type):
elif re.search("mpt", config.model_type):
op_type_dict = {
"add": {"weight": {"dtype": ["fp32"]}, "activation": {"dtype": ["fp32"]}},
"<built-in function linear>":{"weight": {"dtype": ["fp32"]}, "activation": {"dtype": ["fp32"]}},
}
else:
op_type_dict = {}
excluded_precisions = [] if args.int8_bf16_mixed else ["bf16"]
if args.sq:
args.alpha = args.alpha if args.alpha == "auto" else float(args.alpha)
recipes = {"smooth_quant": True, "smooth_quant_args": {"alpha": args.alpha}}
conf = PostTrainingQuantConfig(
backend="ipex" if args.ipex else "default",
excluded_precisions=excluded_precisions,
op_type_dict=op_type_dict,
recipes=recipes,
example_inputs=example_inputs,
)
else:
conf = PostTrainingQuantConfig(
backend="ipex" if args.ipex else "default",
excluded_precisions=excluded_precisions,
op_type_dict=op_type_dict,
example_inputs=example_inputs,
)
# save config
user_model.config.save_pretrained(args.output_dir)
q_model = quantization.fit(
user_model,
conf,
calib_func=calib_func,
)
q_model.save(args.output_dir)
sq_config = SmoothQuantConfig(
tokenizer=tokenizer, # either two of one, tokenizer or calib_func
alpha=float(args.alpha), # default is 0.5
op_type_dict=op_type_dict, # default is {}
excluded_precisions=excluded_precisions, # default is []
)
# smooth-quant
q_model = AutoModelForCausalLM.from_pretrained(args.model,
quantization_config=sq_config
)
print("sq done.")
# weight-only
woq_config = WeightOnlyQuantizationConfig(algorithm="RTN", # default is "RTN"
bits=8, # default is 8
group_size=-1, # default is -1
scheme="sym", # default is sym
enable_full_range=True # default is True
)
woq_model = AutoModelForCausalLM.from_pretrained(args.model,
quantization_config=woq_config
)
print("woq done.")
# amp
amp_config = AMPConfig(dtype="bfloat16") # default is bfloat16
amp_model = AutoModelForCausalLM.from_pretrained(args.model,
quantization_config=amp_config
)
print("amp done.")
# bitsandbytes
bab_config = BitsAndBytesConfig()
bab_model = AutoModelForCausalLM.from_pretrained(args.model,
quantization_config=bab_config
)
print("bitsandbytes done.")


# Generation
generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4)
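For orientation, the lines below are a minimal usage sketch of the new quantization API that the script above exercises; they are not part of this PR's diff. Assumptions beyond what the diff shows: the model name is a placeholder checkpoint, the prompt and max_new_tokens are arbitrary, and the returned model is assumed to expose the standard transformers generate() interface.

# Illustrative sketch only; the config arguments mirror those in the diff above.
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    WeightOnlyQuantizationConfig,
)

model_name = "EleutherAI/gpt-j-6b"  # placeholder checkpoint, swap in your own
tokenizer = AutoTokenizer.from_pretrained(model_name)

woq_config = WeightOnlyQuantizationConfig(
    algorithm="RTN", bits=8, group_size=-1, scheme="sym", enable_full_range=True
)
woq_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config)

# Same generation kwargs the example script defines.
generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4)
inputs = tokenizer("Once upon a time,", return_tensors="pt")
outputs = woq_model.generate(inputs.input_ids, max_new_tokens=32, **generate_kwargs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
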
36 changes: 11 additions & 25 deletions intel_extension_for_transformers/transformers/__init__.py
@@ -16,33 +16,19 @@
# limitations under the License.


from .config import (
AutoDistillationConfig,
DistillationConfig,
FlashDistillationConfig,
TFDistillationConfig,
NASConfig,
Provider,
PruningConfig,
QuantizationConfig,
WEIGHTS_NAME,
DynamicLengthConfig,
BenchmarkConfig,
PrunerV2,

)
from .distillation import (
DistillationCriterionMode,
SUPPORTED_DISTILLATION_CRITERION_MODE,
)
from .modeling import OptimizedModel, AutoModelForCausalLM
from .config import (WEIGHTS_NAME, AutoDistillationConfig, BenchmarkConfig,
DistillationConfig, DynamicLengthConfig,
FlashDistillationConfig, NASConfig, Provider, PrunerV2,
PruningConfig, QuantizationConfig, TFDistillationConfig)
from .distillation import (SUPPORTED_DISTILLATION_CRITERION_MODE,
DistillationCriterionMode)
from .mixture.auto_distillation import AutoDistillation
from .nas import NAS
from .optimizer import NoTrainerOptimizer, Orchestrate_optimizer
from .optimizer_tf import TFOptimization
from .pruning import PrunerConfig, PruningMode, SUPPORTED_PRUNING_MODE
from .quantization import QuantizationMode, SUPPORTED_QUANT_MODE
from .utils import metrics
from .utils import objectives
from .pruning import SUPPORTED_PRUNING_MODE, PrunerConfig, PruningMode
from .quantization import SUPPORTED_QUANT_MODE, QuantizationMode
from .utils import (AMPConfig, BitsAndBytesConfig, SmoothQuantConfig,
WeightOnlyQuantizationConfig, metrics, objectives)
from .utils.utility import LazyImport

from .modeling import AutoModelForCausalLM, OptimizedModel
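
As a quick illustration (not part of the diff), the reorganized __init__.py makes the new configuration classes and the causal-LM entry point importable directly from the package root; the lines below simply restate the names exported above.

# Illustrative import check against the exports listed above.
from intel_extension_for_transformers.transformers import (
    AMPConfig,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    SmoothQuantConfig,
    WeightOnlyQuantizationConfig,
)
print(AMPConfig, BitsAndBytesConfig, SmoothQuantConfig, WeightOnlyQuantizationConfig, AutoModelForCausalLM)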