From c1f6f4643981e3e0f31884ad7d8cf6a3a7e545b9 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 3 Apr 2024 01:02:08 +0000 Subject: [PATCH 01/17] adding the files for ipex int8 serving of llms --- examples/large_models/ipex_llm_int8/README.md | 35 ++ .../large_models/ipex_llm_int8/llm_handler.py | 572 ++++++++++++++++++ .../model-config-llama2-7b-bf16.yaml | 22 + .../model-config-llama2-7b-int8-sq.yaml | 29 + .../model-config-llama2-7b-int8-woq.yaml | 29 + .../ipex_llm_int8/sample_text_0.txt | 1 + 6 files changed, 688 insertions(+) create mode 100644 examples/large_models/ipex_llm_int8/README.md create mode 100644 examples/large_models/ipex_llm_int8/llm_handler.py create mode 100644 examples/large_models/ipex_llm_int8/model-config-llama2-7b-bf16.yaml create mode 100644 examples/large_models/ipex_llm_int8/model-config-llama2-7b-int8-sq.yaml create mode 100644 examples/large_models/ipex_llm_int8/model-config-llama2-7b-int8-woq.yaml create mode 100644 examples/large_models/ipex_llm_int8/sample_text_0.txt diff --git a/examples/large_models/ipex_llm_int8/README.md b/examples/large_models/ipex_llm_int8/README.md new file mode 100644 index 0000000000..cfc645cb15 --- /dev/null +++ b/examples/large_models/ipex_llm_int8/README.md @@ -0,0 +1,35 @@ +This example provides an example of serving IPEX optimized LLMs e.g. meta-llama/llama2-7b-hf on huggingface. For setting up the python environment for this example, please refer here: https://github.com/intel/intel-extension-for-pytorch/blob/main/examples/cpu/inference/python/llm/README.md#3-environment-setup + +You can choose either Weight only Quantization or Smoothquant path for quantizing the model to INT8.If quant_with_amp flag is set to true, it'll use mix of INT8 and bfloat16 precisions, otherwise it'll use INT8 and FP32 combination. If neither approaches are enabled, the model runs on bfloat16 precision by default as long as quant_with_amp is set to true. +There are 3 different example config files; model-config-llama2-7b-int8-sq.yaml for quantizing with smoothquant, model-config-llama2-7b-int8-woq.yaml for quantizing with weight only quantization, and model-config-llama2-7b-bf16.yaml for running the text generation on bfloat16 precision. + +1. Zip everything using model archiver +``` +torch-model-archiver --model-name llama2-7b --version 1.0 --handler llm_handler.py --config-file model-config-llama2-7b-int8-woq.yaml +``` + +2. Move archive to model_store +``` +mkdir model_store +mv llama2-7b.mar ./model_store +``` + +3. Start the torch server +``` +torchserve --ncs --start --model-store model_store +``` + +4. From the client, set up batching parameters. I couldn't make it work by putting the max_batch_size and max_batch_delay in config.properties. +``` +curl -X POST "localhost:8081/models?url=llama2-7b.mar&batch_size=4&max_batch_delay=100" +``` + +5. Test the model status +``` +curl http://localhost:8081/models/llama2-7b +``` + +6. 
Send the request +``` +curl http://localhost:8080/predictions/llama2-7b -T ./sample_text_0.txt +``` diff --git a/examples/large_models/ipex_llm_int8/llm_handler.py b/examples/large_models/ipex_llm_int8/llm_handler.py new file mode 100644 index 0000000000..f56514cf8b --- /dev/null +++ b/examples/large_models/ipex_llm_int8/llm_handler.py @@ -0,0 +1,572 @@ +import os +import logging +from abc import ABC +from pathlib import Path +import subprocess + +import torch +import transformers +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig + +from datasets import load_dataset +from torch.utils.data import DataLoader + +from ts.context import Context +from ts.torch_handler.base_handler import BaseHandler +import intel_extension_for_pytorch as ipex + + +EXAMPLE_INPUTS_MODE = { + "MASK_KV": 1, + "KV_MASK": 2, + "MASK_POS_KV": 3, + "MASK_KV_POS": 4, + "MASK_KV_ENC": 5, +} + + +logger = logging.getLogger(__name__) +logger.info("PyTorch version %s", torch.__version__) +logger.info("IPEX version %s", ipex.__version__) +logger.info("Transformers version %s", transformers.__version__) + +class CodeGenHandler(BaseHandler, ABC): + + def __init__(self): + super(CodeGenHandler, self).__init__() + + # for streaming the generated texts back to client + self.output_streamer = None + + + def initialize(self, ctx: Context): + model_name = ctx.model_yaml_config["handler"]["model_name"] + # path to quantized model, if we are quantizing on the fly, we'll use this path to save the model + self.quantized_model_path = ctx.model_yaml_config["handler"]["quantized_model_path"] + self.example_inputs_mode = ctx.model_yaml_config["handler"]["example_inputs_mode"] + self.to_channels_last = ctx.model_yaml_config["handler"]["to_channels_last"] + + # generation params + self.batch_size = int(ctx.model_yaml_config["handler"]["batch_size"]) + self.max_context_length = int(ctx.model_yaml_config["handler"]["max_context_length"]) + self.input_tokens = int(ctx.model_yaml_config["handler"]["input_tokens"]) + self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"]) + + # use int8 bf16 mix + self.quant_with_amp = ctx.model_yaml_config["handler"]["quant_with_amp"] + + # WoQ related optimization params + if "ipex_weight_only_quantization" in ctx.model_yaml_config["handler"]: + self.ipex_weight_only_quantization = ctx.model_yaml_config["handler"]["ipex_weight_only_quantization"] + self.woq_dtype = ctx.model_yaml_config["handler"]["woq_dtype"] + self.lowp_mode = ctx.model_yaml_config["handler"]["lowp_mode"] + self.act_quant_mode = ctx.model_yaml_config["handler"]["act_quant_mode"] # This is only relevant for INT4x2 quantization + self.group_size = ctx.model_yaml_config["handler"]["group_size"] + else: + self.ipex_weight_only_quantization = False + + # SQ related optimization params + if "ipex_smooth_quantization" in ctx.model_yaml_config["handler"]: + self.ipex_smooth_quantization = ctx.model_yaml_config["handler"]["ipex_smooth_quantization"] + self.calib_dataset = ctx.model_yaml_config["handler"]["calibration_dataset"] + self.calib_split = ctx.model_yaml_config["handler"]["calibration_split"] + self.num_calib_iters = int(ctx.model_yaml_config["handler"]["num_calibration_iters"]) + self.alpha = float(ctx.model_yaml_config["handler"]["alpha"]) + else: + self.ipex_smooth_quantization = False + + + # decoding parameters + self.greedy = ctx.model_yaml_config["handler"]["greedy"] + logger.info(f"Max length of the sequence context is {self.max_context_length}") + + try: + ipex._C.disable_jit_linear_repack() 
+ torch._C._jit_set_texpr_fuser_enabled(False) + except Exception: + pass + + # amp datatype + if self.quant_with_amp: + self.amp_enabled = True + self.amp_dtype = torch.bfloat16 + else: + self.amp_enabled = False + self.amp_dtype = torch.float32 + + # generate args: using greedy for now + self.num_beams = 1 if self.greedy else 4 + # donot use min number of tokens on demo mode, only use it on benchmark mode + self.generate_kwargs = dict( + do_sample=False, + temperature=0.9, + num_beams=self.num_beams, + max_new_tokens=self.max_new_tokens, + min_new_tokens=self.max_new_tokens, + ) + + # device + device = torch.device("cpu") + + # model config + config = AutoConfig.from_pretrained(model_name, torchscript=True, trust_remote_code=True) + + # set up max context + if not hasattr(config, "text_max_length"): + config.text_max_length = int(self.max_context_length) + + # load model and tokenizer + self.user_model = AutoModelForCausalLM.from_pretrained(model_name, config=config, low_cpu_mem_usage=True, torch_dtype=torch.float) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code = True) + + logger.info("Data type of the model: %s", self.user_model.dtype) + + if self.to_channels_last: + self.user_model = self.user_model.to(memory_format=torch.channels_last) + self.user_model.eval() + + + # dummy past key value + beam_idx_tmp = torch.zeros((2048, int(self.batch_size * self.num_beams)), dtype=torch.long).contiguous() + def _get_target_nums(names): + for n in names: + if hasattr(self.user_model.config, n): + return getattr(self.user_model.config, n) + logger.error(f"Not found target {names[0]}") + exit(0) + + num_heads_names = ["num_attention_heads", "n_head", "num_heads", "n_heads"] + num_layers_names = ["num_hidden_layers", "n_layer", "num_layers", "n_layers"] + hidden_size_names = ["hidden_size", "n_embd"] + n_heads = _get_target_nums(num_heads_names) + n_layers = _get_target_nums(num_layers_names) + hidden_size = _get_target_nums(hidden_size_names) + head_dim = int(hidden_size / n_heads) + self.global_past_key_value = [ + ( + torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), + torch.zeros([1, n_heads, 1, head_dim]).contiguous(), + torch.zeros([1, n_heads, 1, head_dim]).contiguous(), + beam_idx_tmp, + ) + for i in range(n_layers) + ] + + logger.info(f"num_attention_heads: {n_heads}, num_hidden_layers: {n_layers}, hidden size: {hidden_size}, head_dim: {head_dim}") + + if self.ipex_smooth_quantization and self.ipex_weight_only_quantization: + logger.error("Can't enable both SQ and WoQ, enable only one of them") + exit(1) + + # lets implement the WOQ + if self.ipex_weight_only_quantization: + weight_dtype = torch.quint4x2 if self.woq_dtype == "INT4" else torch.qint8 + + if self.lowp_mode == "INT8": + lowp_mode = ipex.quantization.WoqLowpMode.INT8 + elif self.lowp_mode == "FP32": + lowp_mode = ipex.quantization.WoqLowpMode.NONE + elif self.lowp_mode == "FP16": + lowp_mode = ipex.quantization.WoqLowpMode.FP16 + elif self.lowp_mode == "BF16": + lowp_mode = ipex.quantization.WoqLowpMode.BF16 + else: + lowp_mode = ipex.quantization.WoqLowpMode.BF16 + + act_quant_mode_dict = { + "PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR, + "PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, + "PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH, + "PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK, + } + + qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=weight_dtype, + lowp_mode=lowp_mode, + 
act_quant_mode=act_quant_mode_dict[self.act_quant_mode], + group_size=self.group_size, + ) + + # low precision checkpoint can be loaded, but we're considering there isn't any + low_precision_checkpoint = None + self.user_model = ipex.llm.optimize( + self.user_model.eval(), + dtype=self.amp_dtype, + quantization_config=qconfig, + inplace=True, + low_precision_checkpoint=low_precision_checkpoint, + deployment_mode=False, + ) + logger.info("The model conversion completed, now tracing the quantized model") + + example_inputs = self.get_example_inputs() + + with torch.no_grad(), torch.cpu.amp.autocast( + enabled=self.amp_enabled, + dtype=self.amp_dtype + ): + self_jit = torch.jit.trace(self.user_model.eval(), example_inputs, strict=False, check_trace=False) + self_jit = torch.jit.freeze(self_jit.eval()) + + self_jit.save(self.quantized_model_path) + + logger.info("The IPEX Weight only quantization has been completed successfully") + + elif self.ipex_smooth_quantization: + class Evaluator: + def __init__(self, example_inputs_mode, global_past_key_value, dataset, tokenizer, batch_size=1, num_beams=1, pad_val=1, pad_max=512): + self.example_inputs_mode = example_inputs_mode + self.global_past_key_value = global_past_key_value + self.dataset = dataset + self.tokenizer = tokenizer + self.batch_size = batch_size + self.num_beams = num_beams + + + self.pad_val = pad_val + self.pad_max = pad_max + self.dataset = self.dataset.map(self.tokenize_function, batched = True, num_proc=2) + self.dataset.set_format(type="torch", columns=["input_ids"]) + + @torch.no_grad() + def tokenize_function(self, examples): + if "prompt" in examples: + example = self.tokenizer(examples["prompt"]) + elif "text" in examples: + example = self.tokenizer(examples["text"]) + elif "code" in examples: + example = self.tokenizer(examples["code"]) + return example + + + @torch.no_grad() + def collate_batch(self, batch): + position_ids_padded = [] + input_ids_padded = [] + last_ind = [] + attention_mask_padded = [] + + for text in batch: + input_ids = text["input_ids"] + last_ind.append(input_ids.shape[0] - 1) + attention_mask = torch.ones(len(input_ids)) + position_ids = torch.arange(len(input_ids)) + + input_ids_padded.append(input_ids) + attention_mask_padded.append(attention_mask) + position_ids_padded.append(position_ids) + + if self.example_inputs_mode == "MASK_POS_KV": + model_inputs = ( + torch.vstack(input_ids_padded), + torch.vstack(attention_mask_padded), + torch.vstack(position_ids_padded), + tuple(self.global_past_key_value), + ) + elif self.example_inputs_mode == "MASK_KV_POS": + model_inputs = ( + torch.vstack(input_ids_padded), + torch.vstack(attention_mask_padded), + tuple(self.global_past_key_value), + torch.vstack(position_ids_padded), + ) + elif self.example_inputs_mode == "KV_MASK": + model_inputs = ( + torch.vstack(input_ids_padded), + tuple(self.global_past_key_value), + torch.vstack(attention_mask_padded), + ) + elif self.example_inputs_mode == "MASK_KV": + model_inputs = ( + torch.vstack(input_ids_padded), + torch.vstack(attention_mask_padded), + tuple(self.global_past_key_value), + ) + elif self.example_inputs_mode == "MASK_KV_ENC": + model_kwargs = { + "attention_mask": torch.vstack(attention_mask_padded), + } + model_kwargs = user_model._prepare_encoder_decoder_kwargs_for_generation( + torch.vstack(input_ids_padded), model_kwargs, "input_ids" + ) + input_ids, example_inputs = user_model._expand_inputs_for_generation( + input_ids=torch.vstack(input_ids_padded), + expand_size=self.num_beams, + 
is_encoder_decoder=True, + **model_kwargs, + ) + + # need to recompute these + def _get_target_nums(names): + for n in names: + if hasattr(self.user_model.config, n): + return getattr(self.user_model.config, n) + logger.error(f"Not found target {names[0]}") + exit(0) + + num_heads_names = ["num_attention_heads", "n_head", "num_heads", "n_heads"] + num_layers_names = ["num_hidden_layers", "n_layer", "num_layers", "n_layers"] + hidden_size_names = ["hidden_size", "n_embd"] + n_heads = _get_target_nums(num_heads_names) + n_layers = _get_target_nums(num_layers_names) + hidden_size = _get_target_nums(hidden_size_names) + head_dim = int(hidden_size / n_heads) + + # lets get the inputs + input_bs = int(self.batch_size * self.num_beams) + last_hidden_state = example_inputs["encoder_outputs"]["last_hidden_state"] + global_past_key_value = tuple( + [ + ( + torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), + torch.zeros([1, n_heads, 1, head_dim]).contiguous(), + torch.zeros([1, n_heads, 1, head_dim]).contiguous(), + beam_idx_tmp, + torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), + user_model.decoder.block[i].layer[1].EncDecAttention.k(last_hidden_state).view(input_bs, -1, n_heads, head_dim).transpose(0, 1), + user_model.decoder.block[i].layer[1].EncDecAttention.v(last_hidden_state).view(input_bs, -1, n_heads, head_dim).transpose(0, 1), + beam_idx_tmp, + ) + for i in range(n_layers) + ] + ) + + decoder_input_ids = (torch.zeros(input_bs).to(torch.long).unsqueeze(1)) + model_inputs = ( + decoder_input_ids, + torch.vstack(attention_mask_padded), + tuple(global_past_key_value), + (last_hidden_state,), + ) + else: + raise RuntimeError("Your model does not match existing example inputs used in ipex smooth quant, exiting...") + + return (model_inputs, last_ind) + + + + calib_dataset = load_dataset(self.calib_dataset, split=self.calib_split) + logger.info(f"Dataset loaded: {calib_dataset}") + calib_evaluator = Evaluator( + self.example_inputs_mode, + self.global_past_key_value, + calib_dataset, + self.tokenizer, + batch_size=self.batch_size, + num_beams = self.num_beams, + pad_max = 512 + ) + logger.info(f"Evaluator built: {calib_evaluator}") + + calib_dataloader = DataLoader( + calib_evaluator.dataset, + batch_size=1, + shuffle=False, + collate_fn=calib_evaluator.collate_batch, + ) + logger.info("Dataloader ready") + + + from intel_extension_for_pytorch.quantization import prepare, convert + example_inputs = self.get_example_inputs() + qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=self.alpha) + user_model = ipex.llm.optimize( + self.user_model.eval(), + dtype=self.amp_dtype, + quantization_config=qconfig, + inplace=True, + deployment_mode=False, + ) + + prepared_model = prepare(user_model.eval(), qconfig, example_inputs=example_inputs, inplace=True) + logger.info("Model prepared for quantization, observers inserted") + + + for i, (model_inputs, last_ind) in enumerate(calib_dataloader): + if i == self.num_calib_iters: + break + prepared_model(*model_inputs) + logger.info("Model calibration completed") + + converted_model = convert(prepared_model.eval(), inplace=True).eval() + logger.info("Model converted successfully, exporting the trace") + + with torch.no_grad(), torch.cpu.amp.autocast( + enabled=self.amp_enabled, + dtype=self.amp_dtype + ): + self_jit = torch.jit.trace(converted_model.eval(), example_inputs, strict=False, check_trace=False) + self_jit = torch.jit.freeze(self_jit.eval()) + + self_jit.save(self.quantized_model_path) + + logger.info("IPEX Smooth Quantization 
has completed successfully") + + else: + # run bf16 model + example_inputs = self.get_example_inputs() + self.user_model = ipex.llm.optimize( + self.user_model.eval(), + dtype = self.amp_dtype, + inplace=True, + deployment_mode=False, + ) + + with torch.no_grad(), torch.cpu.amp.autocast( + enabled=self.amp_enabled, + dtype=self.amp_dtype + ): + self_jit = torch.jit.trace(self.user_model.eval(), example_inputs, strict=False, check_trace=False) + self_jit = torch.jit.freeze(self_jit.eval()) + + self_jit.save(self.quantized_model_path) + + logger.info("IPEX bf16 optimization is applied successfully") + + # set PAD token + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token=self.tokenizer.eos_token + + logger.info("Loading the IPEX quantized model") + try: + self_jit = torch.jit.load(self.quantized_model_path) + self_jit = torch.jit.freeze(self_jit.eval()) + except Exception as e: + logger.error("Error: loading the quantized model failed.", e) + exit(0) + + setattr(self.user_model, "trace_graph", self_jit) + logger.info("Successfully loaded the Model %s with Intel® Extension for PyTorch*", ctx.model_name) + + # Different model need to have their inputs supplied in different order unless we pass dict + # For torchserve sending dict is not always possible + # This function reorders the input ids, masks, and kv cache based on models + def get_example_inputs(self): + example_inputs = None + input_ids = torch.ones(32).to(torch.long) + attention_mask = torch.ones(len(input_ids)) + if self.example_inputs_mode == "MASK_POS_KV": + position_ids = torch.arange(len(input_ids)) + example_inputs = ( + input_ids.unsqueeze(0), + attention_mask.unsqueeze(0), + position_ids.unsqueeze(0), + tuple(self.global_past_key_value), + ) + elif self.example_inputs_mode == "MASK_KV_POS": + position_ids = torch.arange(len(input_ids)) + example_inputs = ( + input_ids.unsqueeze(0), + attention_mask.unsqueeze(0), + tuple(self.global_past_key_value), + position_ids.unsqueeze(0), + ) + elif self.example_inputs_mode == "KV_MASK": + example_inputs = ( + input_ids.unsqueeze(0), + tuple(self.global_past_key_value), + attention_mask.unsqueeze(0), + ) + elif self.example_inputs_mode == "MASK_KV": + example_inputs = ( + input_ids.unsqueeze(0), + attention_mask.unsqueeze(0), + tuple(self.global_past_key_value), + ) + elif self.example_inputs_mode == "MASK_KV_ENC": + last_hidden_state = torch.rand([1, 32, 2048]) + + #need to recompute these + def _get_target_nums(names): + for n in names: + if hasattr(self.user_model.config, n): + return getattr(self.user_model.config, n) + logger.error(f"Not found target {names[0]}") + exit(0) + + num_heads_names = ["num_attention_heads", "n_head", "num_heads", "n_heads"] + num_layers_names = ["num_hidden_layers", "n_layer", "num_layers", "n_layers"] + hidden_size_names = ["hidden_size", "n_embd"] + n_heads = _get_target_nums(num_heads_names) + n_layers = _get_target_nums(num_layers_names) + hidden_size = _get_target_nums(hidden_size_names) + head_dim = int(hidden_size / n_heads) + + global_past_key_value = [ + ( + torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), + torch.zeros([1, n_heads, 1, head_dim]).contiguous(), + torch.zeros([1, n_heads, 1, head_dim]).contiguous(), + beam_idx_tmp, + torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), + torch.zeros([32, 1, n_heads, head_dim]).contiguous(), + torch.zeros([32, 1, n_heads, head_dim]).contiguous(), + beam_idx_tmp, + ) + for i in range(n_layers) + ] + example_inputs = ( + torch.ones(1).to(torch.long).unsqueeze(0), + 
attention_mask.unsqueeze(0), + tuple(global_past_key_value), + (last_hidden_state,), + ) + else: + raise RuntimeError("Your model does not match existing example inputs used in ipex quantization, exiting...") + #if hasattr(model, "extra_inputs"): + # example_inputs = example_inputs + model.extra_inputs + return example_inputs + + def preprocess(self, requests): + input_ids_batch = None + attention_mask_batch = None + for idx, data in enumerate(requests): + input_text = data.get("data") + if input_text is None: + input_text = data.get("body") + if isinstance(input_text, (bytes, bytearray)): + input_text = input_text.decode("utf-8") + + with torch.inference_mode(), torch.no_grad(), torch.autocast( + device_type="cpu", + enabled=self.amp_enabled, + dtype=self.amp_dtype + ): + inputs = self.tokenizer( + input_text, + pad_to_max_length=True, + add_special_tokens=True, + return_tensors="pt", + #max_length=int(self.max_length), + ) + + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + # making a batch out of the recieved requests + if input_ids.shape is not None: + if input_ids_batch is None: + input_ids_batch = input_ids + attention_mask_batch = attention_mask + else: + input_ids_batch = torch.cat((input_ids_batch, input_ids), 0) + attention_mask_batch = torch.cat((attention_mask_batch, attention_mask), 0) + return (input_ids_batch, attention_mask_batch) + + def inference(self, input_batch): + input_ids_batch, attention_mask_batch = input_batch + inferences = [] + # total_list = [] + + with torch.inference_mode(), torch.no_grad(), torch.autocast( + device_type="cpu", + enabled=self.amp_enabled, + dtype=self.amp_dtype + ): + outputs = self.user_model.generate(input_ids_batch, attention_mask=attention_mask_batch, **self.generate_kwargs) + for i, x in enumerate(outputs): + inferences.append(self.tokenizer.decode(outputs[i], skip_special_tokens=True)) + + return inferences + + def postprocess(self, inference_output): + return inference_output diff --git a/examples/large_models/ipex_llm_int8/model-config-llama2-7b-bf16.yaml b/examples/large_models/ipex_llm_int8/model-config-llama2-7b-bf16.yaml new file mode 100644 index 0000000000..92dc955b96 --- /dev/null +++ b/examples/large_models/ipex_llm_int8/model-config-llama2-7b-bf16.yaml @@ -0,0 +1,22 @@ +minWorkers: 1 +maxWorkers: 1 +responseTimeout: 1500 + +handler: + model_name: "meta-llama/Llama-2-7b-hf" + quantized_model_path: "best_model.pt" + example_inputs_mode: "MASK_KV_POS" + to_channels_last: false + + # generation params + batch_size: 1 + max_context_length: 2048 + input_tokens: 1024 + max_new_tokens: 128 + + # Use INT8 bf16 mix + quant_with_amp: true + + # decoding technique + greedy: true + diff --git a/examples/large_models/ipex_llm_int8/model-config-llama2-7b-int8-sq.yaml b/examples/large_models/ipex_llm_int8/model-config-llama2-7b-int8-sq.yaml new file mode 100644 index 0000000000..a82016abda --- /dev/null +++ b/examples/large_models/ipex_llm_int8/model-config-llama2-7b-int8-sq.yaml @@ -0,0 +1,29 @@ +minWorkers: 1 +maxWorkers: 1 +responseTimeout: 1500 + +handler: + model_name: "meta-llama/Llama-2-7b-hf" + quantized_model_path: "best_model.pt" + example_inputs_mode: "MASK_KV_POS" + to_channels_last: false + + # generation params + batch_size: 1 + max_context_length: 2048 + input_tokens: 1024 + max_new_tokens: 128 + + # use bf16-int8 mix + quant_with_amp: true + + # SQ quantization params + ipex_smooth_quantization: true + calibration_dataset: "NeelNanda/pile-10k" + calibration_split: "train" + 
num_calibration_iters: 32 + alpha: 0.9 + + # decoding technique + greedy: true + diff --git a/examples/large_models/ipex_llm_int8/model-config-llama2-7b-int8-woq.yaml b/examples/large_models/ipex_llm_int8/model-config-llama2-7b-int8-woq.yaml new file mode 100644 index 0000000000..29cbda64f7 --- /dev/null +++ b/examples/large_models/ipex_llm_int8/model-config-llama2-7b-int8-woq.yaml @@ -0,0 +1,29 @@ +minWorkers: 1 +maxWorkers: 1 +responseTimeout: 1500 + +handler: + model_name: "meta-llama/Llama-2-7b-hf" + quantized_model_path: "best_model.pt" + example_inputs_mode: "MASK_KV_POS" + to_channels_last: false + + # generation params + batch_size: 1 + max_context_length: 2048 + input_tokens: 1024 + max_new_tokens: 128 + + # Use INT8 bf16 mix + quant_with_amp: true + + # Woq params + ipex_weight_only_quantization: true + woq_dtype: "INT8" + lowp_mode: "BF16" + act_quant_mode: "PER_IC_BLOCK" + group_size: -1 + + # decoding technique + greedy: true + diff --git a/examples/large_models/ipex_llm_int8/sample_text_0.txt b/examples/large_models/ipex_llm_int8/sample_text_0.txt new file mode 100644 index 0000000000..12faa9ebf1 --- /dev/null +++ b/examples/large_models/ipex_llm_int8/sample_text_0.txt @@ -0,0 +1 @@ +Why did Voldemort hate Harry Potter? From 055de91ccbc9928dac8f40af588a888e0e343f8e Mon Sep 17 00:00:00 2001 From: bbhattar <113475728+bbhattar@users.noreply.github.com> Date: Tue, 2 Apr 2024 18:13:20 -0700 Subject: [PATCH 02/17] Update README.md Fixed some markdowns --- examples/large_models/ipex_llm_int8/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/large_models/ipex_llm_int8/README.md b/examples/large_models/ipex_llm_int8/README.md index cfc645cb15..7f4b5202ca 100644 --- a/examples/large_models/ipex_llm_int8/README.md +++ b/examples/large_models/ipex_llm_int8/README.md @@ -1,9 +1,9 @@ -This example provides an example of serving IPEX optimized LLMs e.g. meta-llama/llama2-7b-hf on huggingface. For setting up the python environment for this example, please refer here: https://github.com/intel/intel-extension-for-pytorch/blob/main/examples/cpu/inference/python/llm/README.md#3-environment-setup +This example provides an example of serving IPEX-optimized LLMs e.g. ```meta-llama/llama2-7b-hf``` on huggingface. For setting up the Python environment for this example, please refer here: https://github.com/intel/intel-extension-for-pytorch/blob/main/examples/cpu/inference/python/llm/README.md#3-environment-setup -You can choose either Weight only Quantization or Smoothquant path for quantizing the model to INT8.If quant_with_amp flag is set to true, it'll use mix of INT8 and bfloat16 precisions, otherwise it'll use INT8 and FP32 combination. If neither approaches are enabled, the model runs on bfloat16 precision by default as long as quant_with_amp is set to true. -There are 3 different example config files; model-config-llama2-7b-int8-sq.yaml for quantizing with smoothquant, model-config-llama2-7b-int8-woq.yaml for quantizing with weight only quantization, and model-config-llama2-7b-bf16.yaml for running the text generation on bfloat16 precision. +You can choose either Weight-only Quantization or Smoothquant path for quantizing the model to ```INT8```. If the ```quant_with_amp``` flag is set to ```true```, it'll use a mix of ```INT8``` and ```bfloat16``` precisions, otherwise, it'll use ```INT8``` and ```FP32``` combination. 
If neither approaches are enabled, the model runs on ```bfloat16``` precision by default as long as ```quant_with_amp``` is set to ```true```. +There are 3 different example config files; ```model-config-llama2-7b-int8-sq.yaml``` for quantizing with smooth-quant, ```model-config-llama2-7b-int8-woq.yaml``` for quantizing with weight only quantization, and ```model-config-llama2-7b-bf16.yaml``` for running the text generation on bfloat16 precision. -1. Zip everything using model archiver +1. Zip everything using the model archiver ``` torch-model-archiver --model-name llama2-7b --version 1.0 --handler llm_handler.py --config-file model-config-llama2-7b-int8-woq.yaml ``` From 5fc1d677677e9c288c51946342e3b572833895ed Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 4 Apr 2024 16:32:41 +0000 Subject: [PATCH 03/17] Fix handler name --- examples/large_models/ipex_llm_int8/llm_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/large_models/ipex_llm_int8/llm_handler.py b/examples/large_models/ipex_llm_int8/llm_handler.py index f56514cf8b..d2e13d2e0f 100644 --- a/examples/large_models/ipex_llm_int8/llm_handler.py +++ b/examples/large_models/ipex_llm_int8/llm_handler.py @@ -30,10 +30,10 @@ logger.info("IPEX version %s", ipex.__version__) logger.info("Transformers version %s", transformers.__version__) -class CodeGenHandler(BaseHandler, ABC): +class IpexLLMHandler(BaseHandler, ABC): def __init__(self): - super(CodeGenHandler, self).__init__() + super(IpexLLMHandler, self).__init__() # for streaming the generated texts back to client self.output_streamer = None From 4c617e0111d618aa85a2bb3d9a0b6af5d08a7e76 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 7 May 2024 21:50:38 +0000 Subject: [PATCH 04/17] Adding default PyTorch support --- .../large_models/ipex_llm_int8/llm_handler.py | 651 ++++++++++-------- test/pytest/test_ipex_serving.py | 155 +++++ 2 files changed, 504 insertions(+), 302 deletions(-) create mode 100644 test/pytest/test_ipex_serving.py diff --git a/examples/large_models/ipex_llm_int8/llm_handler.py b/examples/large_models/ipex_llm_int8/llm_handler.py index d2e13d2e0f..a32fa4870f 100644 --- a/examples/large_models/ipex_llm_int8/llm_handler.py +++ b/examples/large_models/ipex_llm_int8/llm_handler.py @@ -2,34 +2,42 @@ import logging from abc import ABC from pathlib import Path -import subprocess +import re import torch import transformers from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig +from transformers import T5ForConditionalGeneration from datasets import load_dataset from torch.utils.data import DataLoader from ts.context import Context from ts.torch_handler.base_handler import BaseHandler -import intel_extension_for_pytorch as ipex - - -EXAMPLE_INPUTS_MODE = { - "MASK_KV": 1, - "KV_MASK": 2, - "MASK_POS_KV": 3, - "MASK_KV_POS": 4, - "MASK_KV_ENC": 5, -} - logger = logging.getLogger(__name__) logger.info("PyTorch version %s", torch.__version__) -logger.info("IPEX version %s", ipex.__version__) logger.info("Transformers version %s", transformers.__version__) +IPEX_ENABLE = False +if os.environ.get("TS_IPEX_ENABLE", "false") == "true": + try: + import intel_extension_for_pytorch as ipex + try: + ipex._C.disable_jit_linear_repack() + torch._C._jit_set_texpr_fuser_enabled(False) + except Exception: + pass + IPEX_ENABLE = True + logger.info("IPEX optimization is enabled") + logger.info("IPEX version %s", ipex.__version__) + + except ImportError as error: + logger.warning("IPEX is enabled but intel-extension-for-pytorch cannot 
be imported. Proceeding without IPEX") + IPEX_ENABLE = False +else: + logger.warning("IPEX is not enabled, consider enabling it for best performance on Intel hardware") + class IpexLLMHandler(BaseHandler, ABC): def __init__(self): @@ -42,17 +50,17 @@ def __init__(self): def initialize(self, ctx: Context): model_name = ctx.model_yaml_config["handler"]["model_name"] # path to quantized model, if we are quantizing on the fly, we'll use this path to save the model + self.clear_cache_dir = ctx.model_yaml_config["handler"]["clear_cache_dir"] self.quantized_model_path = ctx.model_yaml_config["handler"]["quantized_model_path"] self.example_inputs_mode = ctx.model_yaml_config["handler"]["example_inputs_mode"] self.to_channels_last = ctx.model_yaml_config["handler"]["to_channels_last"] # generation params self.batch_size = int(ctx.model_yaml_config["handler"]["batch_size"]) - self.max_context_length = int(ctx.model_yaml_config["handler"]["max_context_length"]) self.input_tokens = int(ctx.model_yaml_config["handler"]["input_tokens"]) self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"]) - # use int8 bf16 mix + # use int8 bf16 mix self.quant_with_amp = ctx.model_yaml_config["handler"]["quant_with_amp"] # WoQ related optimization params @@ -78,13 +86,6 @@ def initialize(self, ctx: Context): # decoding parameters self.greedy = ctx.model_yaml_config["handler"]["greedy"] - logger.info(f"Max length of the sequence context is {self.max_context_length}") - - try: - ipex._C.disable_jit_linear_repack() - torch._C._jit_set_texpr_fuser_enabled(False) - except Exception: - pass # amp datatype if self.quant_with_amp: @@ -113,11 +114,28 @@ def initialize(self, ctx: Context): # set up max context if not hasattr(config, "text_max_length"): - config.text_max_length = int(self.max_context_length) - - # load model and tokenizer - self.user_model = AutoModelForCausalLM.from_pretrained(model_name, config=config, low_cpu_mem_usage=True, torch_dtype=torch.float) - self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code = True) + config.text_max_length = int(self.input_tokens) + int(self.max_new_tokens) + if "mpt" in model_name and not hasattr(config, "max_seq_len"): + config.max_seq_len = int(self.input_tokens) + int(self.max_new_tokens) + + # load model and tokenizer, + # We need special provision for t5 because it's seq2seq model, and can not be loaded with AutoModelForCausalLM + if re.search("t5", config.architectures[0], re.IGNORECASE): + self.user_model = T5ForConditionalGeneration.from_pretrained(model_name, config=config, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.float) + input_ids = torch.ones(32).to(torch.long).unsqueeze(0) + attention_mask = torch.ones_like(input_ids) + dummy_inputs = self.user_model.prepare_inputs_for_generation(input_ids, attention_mask=attention_mask) + if dummy_inputs.get("position_ids", None) is not None: + self.example_inputs_mode = "MASK_KV_POS" + + # we also need to update generation kwargs + self.generate_kwargs["max_length"] = self.generate_kwargs["max_new_tokens"] + self.generate_kwargs.pop("max_new_tokens") + + else: + self.user_model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.float) + + self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) logger.info("Data type of the model: %s", self.user_model.dtype) @@ -127,7 +145,7 @@ def initialize(self, ctx: Context): # dummy past key value - beam_idx_tmp = 
torch.zeros((2048, int(self.batch_size * self.num_beams)), dtype=torch.long).contiguous() + self.beam_idx_tmp = torch.zeros((2048, int(self.batch_size * self.num_beams)), dtype=torch.long).contiguous() def _get_target_nums(names): for n in names: if hasattr(self.user_model.config, n): @@ -137,7 +155,7 @@ def _get_target_nums(names): num_heads_names = ["num_attention_heads", "n_head", "num_heads", "n_heads"] num_layers_names = ["num_hidden_layers", "n_layer", "num_layers", "n_layers"] - hidden_size_names = ["hidden_size", "n_embd"] + hidden_size_names = ["hidden_size", "n_embd", "d_model"] n_heads = _get_target_nums(num_heads_names) n_layers = _get_target_nums(num_layers_names) hidden_size = _get_target_nums(hidden_size_names) @@ -147,296 +165,322 @@ def _get_target_nums(names): torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), torch.zeros([1, n_heads, 1, head_dim]).contiguous(), torch.zeros([1, n_heads, 1, head_dim]).contiguous(), - beam_idx_tmp, + self.beam_idx_tmp, ) for i in range(n_layers) ] logger.info(f"num_attention_heads: {n_heads}, num_hidden_layers: {n_layers}, hidden size: {hidden_size}, head_dim: {head_dim}") - if self.ipex_smooth_quantization and self.ipex_weight_only_quantization: - logger.error("Can't enable both SQ and WoQ, enable only one of them") - exit(1) - - # lets implement the WOQ - if self.ipex_weight_only_quantization: - weight_dtype = torch.quint4x2 if self.woq_dtype == "INT4" else torch.qint8 - - if self.lowp_mode == "INT8": - lowp_mode = ipex.quantization.WoqLowpMode.INT8 - elif self.lowp_mode == "FP32": - lowp_mode = ipex.quantization.WoqLowpMode.NONE - elif self.lowp_mode == "FP16": - lowp_mode = ipex.quantization.WoqLowpMode.FP16 - elif self.lowp_mode == "BF16": - lowp_mode = ipex.quantization.WoqLowpMode.BF16 - else: - lowp_mode = ipex.quantization.WoqLowpMode.BF16 - - act_quant_mode_dict = { - "PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR, - "PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, - "PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH, - "PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK, - } - - qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( - weight_dtype=weight_dtype, - lowp_mode=lowp_mode, - act_quant_mode=act_quant_mode_dict[self.act_quant_mode], - group_size=self.group_size, - ) - - # low precision checkpoint can be loaded, but we're considering there isn't any - low_precision_checkpoint = None - self.user_model = ipex.llm.optimize( - self.user_model.eval(), - dtype=self.amp_dtype, - quantization_config=qconfig, - inplace=True, - low_precision_checkpoint=low_precision_checkpoint, - deployment_mode=False, - ) - logger.info("The model conversion completed, now tracing the quantized model") + if IPEX_ENABLE: + """ + Ipex is enabled, we'll use + (1) weight only quantization if ipex_weight_only_quantization is enabled + (2) ipex smooth quantization if ipex_smooth_quantization is enabled + (3) ipex bfloat16 optimization if neither is quantization is enabled + (4) throws error if both 1 and 2 are enabled + """ + if self.ipex_smooth_quantization and self.ipex_weight_only_quantization: + logger.error("Can't enable both SQ and WoQ, enable only one of them") + exit(1) + + # Clear the cache dir if needed + if self.clear_cache_dir and os.path.exists(self.quantized_model_path): + os.remove(self.quantized_model_path) + + if os.path.exists(self.quantized_model_path): + # this skips all the optimizations and goes to end where we load the model + logger.info("A previously 
quantized model is loaded, if you want to re-quantize the model, enable clear_cache_dir on model config file") + + # lets implement the WOQ + elif self.ipex_weight_only_quantization: + weight_dtype = torch.quint4x2 if self.woq_dtype == "INT4" else torch.qint8 + + if self.lowp_mode == "INT8": + lowp_mode = ipex.quantization.WoqLowpMode.INT8 + elif self.lowp_mode == "FP32": + lowp_mode = ipex.quantization.WoqLowpMode.NONE + elif self.lowp_mode == "FP16": + lowp_mode = ipex.quantization.WoqLowpMode.FP16 + elif self.lowp_mode == "BF16": + lowp_mode = ipex.quantization.WoqLowpMode.BF16 + else: + lowp_mode = ipex.quantization.WoqLowpMode.BF16 + + act_quant_mode_dict = { + "PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR, + "PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, + "PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH, + "PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK, + } + + qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=weight_dtype, + lowp_mode=lowp_mode, + act_quant_mode=act_quant_mode_dict[self.act_quant_mode], + group_size=self.group_size, + ) + + # low precision checkpoint can be loaded, but we're considering there isn't any + low_precision_checkpoint = None + self.user_model = ipex.llm.optimize( + self.user_model.eval(), + dtype=self.amp_dtype, + quantization_config=qconfig, + inplace=True, + low_precision_checkpoint=low_precision_checkpoint, + deployment_mode=False, + ) + logger.info("The model conversion completed, now tracing the quantized model") + + example_inputs = self.get_example_inputs() + + with torch.no_grad(), torch.cpu.amp.autocast( + enabled=self.amp_enabled, + dtype=self.amp_dtype + ): + self_jit = torch.jit.trace(self.user_model.eval(), example_inputs, strict=False, check_trace=False) + self_jit = torch.jit.freeze(self_jit.eval()) + + self_jit.save(self.quantized_model_path) + + logger.info("The IPEX Weight only quantization has been completed successfully") + + elif self.ipex_smooth_quantization: + class Evaluator: + def __init__(self, example_inputs_mode, global_past_key_value, dataset, tokenizer, batch_size=1, num_beams=1, pad_val=1, pad_max=512): + self.example_inputs_mode = example_inputs_mode + self.global_past_key_value = global_past_key_value + self.dataset = dataset + self.tokenizer = tokenizer + self.batch_size = batch_size + self.num_beams = num_beams + + + self.pad_val = pad_val + self.pad_max = pad_max + self.dataset = self.dataset.map(self.tokenize_function, batched = True, num_proc=2) + self.dataset.set_format(type="torch", columns=["input_ids"]) + + @torch.no_grad() + def tokenize_function(self, examples): + if "prompt" in examples: + example = self.tokenizer(examples["prompt"]) + elif "text" in examples: + example = self.tokenizer(examples["text"]) + elif "code" in examples: + example = self.tokenizer(examples["code"]) + return example + + + @torch.no_grad() + def collate_batch(self, batch): + position_ids_padded = [] + input_ids_padded = [] + last_ind = [] + attention_mask_padded = [] + + for text in batch: + input_ids = text["input_ids"] + last_ind.append(input_ids.shape[0] - 1) + attention_mask = torch.ones(len(input_ids)) + position_ids = torch.arange(len(input_ids)) + + input_ids_padded.append(input_ids) + attention_mask_padded.append(attention_mask) + position_ids_padded.append(position_ids) + + if self.example_inputs_mode == "MASK_POS_KV": + model_inputs = ( + torch.vstack(input_ids_padded), + torch.vstack(attention_mask_padded), + 
torch.vstack(position_ids_padded), + tuple(self.global_past_key_value), + ) + elif self.example_inputs_mode == "MASK_KV_POS": + model_inputs = ( + torch.vstack(input_ids_padded), + torch.vstack(attention_mask_padded), + tuple(self.global_past_key_value), + torch.vstack(position_ids_padded), + ) + elif self.example_inputs_mode == "KV_MASK": + model_inputs = ( + torch.vstack(input_ids_padded), + tuple(self.global_past_key_value), + torch.vstack(attention_mask_padded), + ) + elif self.example_inputs_mode == "MASK_KV": + model_inputs = ( + torch.vstack(input_ids_padded), + torch.vstack(attention_mask_padded), + tuple(self.global_past_key_value), + ) + elif self.example_inputs_mode == "MASK_KV_ENC": + model_kwargs = { + "attention_mask": torch.vstack(attention_mask_padded), + } + model_kwargs = user_model._prepare_encoder_decoder_kwargs_for_generation( + torch.vstack(input_ids_padded), model_kwargs, "input_ids" + ) + input_ids, example_inputs = user_model._expand_inputs_for_generation( + input_ids=torch.vstack(input_ids_padded), + expand_size=self.num_beams, + is_encoder_decoder=True, + **model_kwargs, + ) + + # need to recompute these + def _get_target_nums(names): + for n in names: + if hasattr(self.user_model.config, n): + return getattr(self.user_model.config, n) + logger.error(f"Not found target {names[0]}") + exit(0) + + num_heads_names = ["num_attention_heads", "n_head", "num_heads", "n_heads"] + num_layers_names = ["num_hidden_layers", "n_layer", "num_layers", "n_layers"] + hidden_size_names = ["hidden_size", "n_embd"] + n_heads = _get_target_nums(num_heads_names) + n_layers = _get_target_nums(num_layers_names) + hidden_size = _get_target_nums(hidden_size_names) + head_dim = int(hidden_size / n_heads) + + # lets get the inputs + input_bs = int(self.batch_size * self.num_beams) + last_hidden_state = example_inputs["encoder_outputs"]["last_hidden_state"] + global_past_key_value = tuple( + [ + ( + torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), + torch.zeros([1, n_heads, 1, head_dim]).contiguous(), + torch.zeros([1, n_heads, 1, head_dim]).contiguous(), + self.beam_idx_tmp, + torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), + user_model.decoder.block[i].layer[1].EncDecAttention.k(last_hidden_state).view(input_bs, -1, n_heads, head_dim).transpose(0, 1), + user_model.decoder.block[i].layer[1].EncDecAttention.v(last_hidden_state).view(input_bs, -1, n_heads, head_dim).transpose(0, 1), + self.beam_idx_tmp, + ) + for i in range(n_layers) + ] + ) + + decoder_input_ids = (torch.zeros(input_bs).to(torch.long).unsqueeze(1)) + model_inputs = ( + decoder_input_ids, + torch.vstack(attention_mask_padded), + tuple(global_past_key_value), + (last_hidden_state,), + ) + else: + raise RuntimeError("Your model does not match existing example inputs used in ipex smooth quant, exiting...") + + return (model_inputs, last_ind) + + + + calib_dataset = load_dataset(self.calib_dataset, split=self.calib_split) + logger.info(f"Dataset loaded: {calib_dataset}") + calib_evaluator = Evaluator( + self.example_inputs_mode, + self.global_past_key_value, + calib_dataset, + self.tokenizer, + batch_size=self.batch_size, + num_beams = self.num_beams, + pad_max = int(self.input_tokens) if re.search("t5", config.architectures[0], re.IGNORECASE) else 512 + ) + logger.info(f"Evaluator built: {calib_evaluator}") - example_inputs = self.get_example_inputs() + calib_dataloader = DataLoader( + calib_evaluator.dataset, + batch_size=1, + shuffle=False, + collate_fn=calib_evaluator.collate_batch, + ) + 
logger.info("Dataloader ready") + + + from intel_extension_for_pytorch.quantization import prepare, convert + example_inputs = self.get_example_inputs() + qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=self.alpha) + user_model = ipex.llm.optimize( + self.user_model.eval(), + dtype=self.amp_dtype, + quantization_config=qconfig, + inplace=True, + deployment_mode=False, + ) - with torch.no_grad(), torch.cpu.amp.autocast( - enabled=self.amp_enabled, - dtype=self.amp_dtype - ): - self_jit = torch.jit.trace(self.user_model.eval(), example_inputs, strict=False, check_trace=False) - self_jit = torch.jit.freeze(self_jit.eval()) + prepared_model = prepare(user_model.eval(), qconfig, example_inputs=example_inputs, inplace=True) + logger.info("Model prepared for quantization, observers inserted") - self_jit.save(self.quantized_model_path) - - logger.info("The IPEX Weight only quantization has been completed successfully") - - elif self.ipex_smooth_quantization: - class Evaluator: - def __init__(self, example_inputs_mode, global_past_key_value, dataset, tokenizer, batch_size=1, num_beams=1, pad_val=1, pad_max=512): - self.example_inputs_mode = example_inputs_mode - self.global_past_key_value = global_past_key_value - self.dataset = dataset - self.tokenizer = tokenizer - self.batch_size = batch_size - self.num_beams = num_beams - - - self.pad_val = pad_val - self.pad_max = pad_max - self.dataset = self.dataset.map(self.tokenize_function, batched = True, num_proc=2) - self.dataset.set_format(type="torch", columns=["input_ids"]) - - @torch.no_grad() - def tokenize_function(self, examples): - if "prompt" in examples: - example = self.tokenizer(examples["prompt"]) - elif "text" in examples: - example = self.tokenizer(examples["text"]) - elif "code" in examples: - example = self.tokenizer(examples["code"]) - return example - - - @torch.no_grad() - def collate_batch(self, batch): - position_ids_padded = [] - input_ids_padded = [] - last_ind = [] - attention_mask_padded = [] - - for text in batch: - input_ids = text["input_ids"] - last_ind.append(input_ids.shape[0] - 1) - attention_mask = torch.ones(len(input_ids)) - position_ids = torch.arange(len(input_ids)) - - input_ids_padded.append(input_ids) - attention_mask_padded.append(attention_mask) - position_ids_padded.append(position_ids) - - if self.example_inputs_mode == "MASK_POS_KV": - model_inputs = ( - torch.vstack(input_ids_padded), - torch.vstack(attention_mask_padded), - torch.vstack(position_ids_padded), - tuple(self.global_past_key_value), - ) - elif self.example_inputs_mode == "MASK_KV_POS": - model_inputs = ( - torch.vstack(input_ids_padded), - torch.vstack(attention_mask_padded), - tuple(self.global_past_key_value), - torch.vstack(position_ids_padded), - ) - elif self.example_inputs_mode == "KV_MASK": - model_inputs = ( - torch.vstack(input_ids_padded), - tuple(self.global_past_key_value), - torch.vstack(attention_mask_padded), - ) - elif self.example_inputs_mode == "MASK_KV": - model_inputs = ( - torch.vstack(input_ids_padded), - torch.vstack(attention_mask_padded), - tuple(self.global_past_key_value), - ) - elif self.example_inputs_mode == "MASK_KV_ENC": - model_kwargs = { - "attention_mask": torch.vstack(attention_mask_padded), - } - model_kwargs = user_model._prepare_encoder_decoder_kwargs_for_generation( - torch.vstack(input_ids_padded), model_kwargs, "input_ids" - ) - input_ids, example_inputs = user_model._expand_inputs_for_generation( - input_ids=torch.vstack(input_ids_padded), - expand_size=self.num_beams, - 
is_encoder_decoder=True, - **model_kwargs, - ) - - # need to recompute these - def _get_target_nums(names): - for n in names: - if hasattr(self.user_model.config, n): - return getattr(self.user_model.config, n) - logger.error(f"Not found target {names[0]}") - exit(0) - - num_heads_names = ["num_attention_heads", "n_head", "num_heads", "n_heads"] - num_layers_names = ["num_hidden_layers", "n_layer", "num_layers", "n_layers"] - hidden_size_names = ["hidden_size", "n_embd"] - n_heads = _get_target_nums(num_heads_names) - n_layers = _get_target_nums(num_layers_names) - hidden_size = _get_target_nums(hidden_size_names) - head_dim = int(hidden_size / n_heads) - - # lets get the inputs - input_bs = int(self.batch_size * self.num_beams) - last_hidden_state = example_inputs["encoder_outputs"]["last_hidden_state"] - global_past_key_value = tuple( - [ - ( - torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), - torch.zeros([1, n_heads, 1, head_dim]).contiguous(), - torch.zeros([1, n_heads, 1, head_dim]).contiguous(), - beam_idx_tmp, - torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), - user_model.decoder.block[i].layer[1].EncDecAttention.k(last_hidden_state).view(input_bs, -1, n_heads, head_dim).transpose(0, 1), - user_model.decoder.block[i].layer[1].EncDecAttention.v(last_hidden_state).view(input_bs, -1, n_heads, head_dim).transpose(0, 1), - beam_idx_tmp, - ) - for i in range(n_layers) - ] - ) - - decoder_input_ids = (torch.zeros(input_bs).to(torch.long).unsqueeze(1)) - model_inputs = ( - decoder_input_ids, - torch.vstack(attention_mask_padded), - tuple(global_past_key_value), - (last_hidden_state,), - ) - else: - raise RuntimeError("Your model does not match existing example inputs used in ipex smooth quant, exiting...") - - return (model_inputs, last_ind) - - - - calib_dataset = load_dataset(self.calib_dataset, split=self.calib_split) - logger.info(f"Dataset loaded: {calib_dataset}") - calib_evaluator = Evaluator( - self.example_inputs_mode, - self.global_past_key_value, - calib_dataset, - self.tokenizer, - batch_size=self.batch_size, - num_beams = self.num_beams, - pad_max = 512 - ) - logger.info(f"Evaluator built: {calib_evaluator}") - calib_dataloader = DataLoader( - calib_evaluator.dataset, - batch_size=1, - shuffle=False, - collate_fn=calib_evaluator.collate_batch, - ) - logger.info("Dataloader ready") - - - from intel_extension_for_pytorch.quantization import prepare, convert - example_inputs = self.get_example_inputs() - qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=self.alpha) - user_model = ipex.llm.optimize( - self.user_model.eval(), - dtype=self.amp_dtype, - quantization_config=qconfig, - inplace=True, - deployment_mode=False, - ) + for i, (model_inputs, last_ind) in enumerate(calib_dataloader): + if i == self.num_calib_iters: + break + prepared_model(*model_inputs) + logger.info("Model calibration completed") - prepared_model = prepare(user_model.eval(), qconfig, example_inputs=example_inputs, inplace=True) - logger.info("Model prepared for quantization, observers inserted") + converted_model = convert(prepared_model.eval(), inplace=True).eval() + logger.info("Model converted successfully, exporting the trace") + + with torch.no_grad(), torch.cpu.amp.autocast( + enabled=self.amp_enabled, + dtype=self.amp_dtype + ): + self_jit = torch.jit.trace(converted_model.eval(), example_inputs, strict=False, check_trace=False) + self_jit = torch.jit.freeze(self_jit.eval()) + self_jit.save(self.quantized_model_path) - for i, (model_inputs, last_ind) in 
enumerate(calib_dataloader): - if i == self.num_calib_iters: - break - prepared_model(*model_inputs) - logger.info("Model calibration completed") + logger.info("IPEX Smooth Quantization has completed successfully") - converted_model = convert(prepared_model.eval(), inplace=True).eval() - logger.info("Model converted successfully, exporting the trace") - - with torch.no_grad(), torch.cpu.amp.autocast( - enabled=self.amp_enabled, - dtype=self.amp_dtype - ): - self_jit = torch.jit.trace(converted_model.eval(), example_inputs, strict=False, check_trace=False) - self_jit = torch.jit.freeze(self_jit.eval()) + else: + # run bf16 model + example_inputs = self.get_example_inputs() + self.user_model = ipex.llm.optimize( + self.user_model.eval(), + dtype = self.amp_dtype, + inplace=True, + deployment_mode=False, + ) + + with torch.no_grad(), torch.cpu.amp.autocast( + enabled=self.amp_enabled, + dtype=self.amp_dtype + ): + self_jit = torch.jit.trace(self.user_model.eval(), example_inputs, strict=False, check_trace=False) + self_jit = torch.jit.freeze(self_jit.eval()) - self_jit.save(self.quantized_model_path) + self_jit.save(self.quantized_model_path) - logger.info("IPEX Smooth Quantization has completed successfully") + logger.info("IPEX bf16 optimization is applied successfully") + + logger.info("Loading the IPEX quantized model") + try: + self_jit = torch.jit.load(self.quantized_model_path) + self_jit = torch.jit.freeze(self_jit.eval()) + except Exception as e: + logger.error("Error: loading the quantized model failed.", e) + exit(0) + + setattr(self.user_model, "trace_graph", self_jit) + logger.info(f"Successfully loaded the Model {model_name} with Intel® Extension for PyTorch*") else: - # run bf16 model - example_inputs = self.get_example_inputs() - self.user_model = ipex.llm.optimize( - self.user_model.eval(), - dtype = self.amp_dtype, - inplace=True, - deployment_mode=False, - ) + # No optimization is applied, but if amx is enabled, it'll be applied during generation routine + logger.warning("No IPEX optimization is applied, Pytorch default autocast will be applied if enabled") - with torch.no_grad(), torch.cpu.amp.autocast( - enabled=self.amp_enabled, - dtype=self.amp_dtype - ): - self_jit = torch.jit.trace(self.user_model.eval(), example_inputs, strict=False, check_trace=False) - self_jit = torch.jit.freeze(self_jit.eval()) - - self_jit.save(self.quantized_model_path) - - logger.info("IPEX bf16 optimization is applied successfully") # set PAD token if self.tokenizer.pad_token is None: - self.tokenizer.pad_token=self.tokenizer.eos_token + if re.search("qwen", self.user_model.config.architectures[0], re.IGNORECASE): + self.tokenizer.pad_token = '<|endoftext|>' + else: + self.tokenizer.pad_token=self.tokenizer.eos_token + - logger.info("Loading the IPEX quantized model") - try: - self_jit = torch.jit.load(self.quantized_model_path) - self_jit = torch.jit.freeze(self_jit.eval()) - except Exception as e: - logger.error("Error: loading the quantized model failed.", e) - exit(0) - - setattr(self.user_model, "trace_graph", self_jit) - logger.info("Successfully loaded the Model %s with Intel® Extension for PyTorch*", ctx.model_name) # Different model need to have their inputs supplied in different order unless we pass dict # For torchserve sending dict is not always possible @@ -497,11 +541,11 @@ def _get_target_nums(names): torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), torch.zeros([1, n_heads, 1, head_dim]).contiguous(), torch.zeros([1, n_heads, 1, head_dim]).contiguous(), - beam_idx_tmp, + 
self.beam_idx_tmp, torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), torch.zeros([32, 1, n_heads, head_dim]).contiguous(), torch.zeros([32, 1, n_heads, head_dim]).contiguous(), - beam_idx_tmp, + self.beam_idx_tmp, ) for i in range(n_layers) ] @@ -513,8 +557,10 @@ def _get_target_nums(names): ) else: raise RuntimeError("Your model does not match existing example inputs used in ipex quantization, exiting...") - #if hasattr(model, "extra_inputs"): - # example_inputs = example_inputs + model.extra_inputs + # TODO: Figure out how to provide the extra inputs from config + if re.search("chatglm", self.user_model.config.architectures[0], re.IGNORECASE): + extra_inputs = (torch.tensor(True),) + example_inputs = example_inputs + extra_inputs return example_inputs def preprocess(self, requests): @@ -563,8 +609,9 @@ def inference(self, input_batch): dtype=self.amp_dtype ): outputs = self.user_model.generate(input_ids_batch, attention_mask=attention_mask_batch, **self.generate_kwargs) - for i, x in enumerate(outputs): - inferences.append(self.tokenizer.decode(outputs[i], skip_special_tokens=True)) + inferences = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) + #for i, x in enumerate(outputs): + # inferences.append(self.tokenizer.decode(outputs[i], skip_special_tokens=True)) return inferences diff --git a/test/pytest/test_ipex_serving.py b/test/pytest/test_ipex_serving.py new file mode 100644 index 0000000000..fd00b1c611 --- /dev/null +++ b/test/pytest/test_ipex_serving.py @@ -0,0 +1,155 @@ +import os +import sys +import json +from pathlib import Path +import subprocess +import yaml + +import pytest +import requests +import test_utils +import torch +#from test_handler import run_inference_using_url_with_data + +from unittest.mock import patch +from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext + +from string import Template +import logging + +from model_archiver.model_archiver_config import ModelArchiverConfig +from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext + + +REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") +snapshot_file_ipex = os.path.join(REPO_ROOT, "test/config_ipex.properties") +prompt_file = os.path.join(REPO_ROOT, "examples/large_models/ipex_llm_int8/sample_text_0.txt") + +#CURR_FILE_PATH = Path(__file__).parent +HANDLER_PATH = os.path.join(REPO_ROOT, "examples/large_models/ipex_llm_int8/") +sys.path.append(HANDLER_PATH) + + +logger = logging.Logger(__name__) + +PROMPTS = ["The capital of France is ",] + +MANAGEMENT_API = "http://localhost:8081" +INFERENCE_API = "http://localhost:8080" + + +xeon_run_cpu_available = False + +cmd = ["python", "-m", "torch.backends.xeon.run_cpu", "--no_python", "pwd"] +r = subprocess.run(cmd) +if r.returncode == 0: + xeon_run_cpu_available = True + +ipex_available = False +cmd = ["python", "-c", "import intel_extension_for_pytorch as ipex"] +r = subprocess.run(cmd) +if r.returncode == 0: + ipex_available = True + +ipex_xeon_run_available = xeon_run_cpu_available and ipex_available + + + +# TODO: download each model we want to serve inside this folder (Change with model name) + +# MODEL_FILE_PATH=HANDLER_PATH/"llama_2" + + + +LLAMA_DEFAULT_CONFIG = f""" + minWorkers: 1 + maxWorkers: 1 + responseTimeout: 1500 + batchSize: 4 + maxBatchDelay: 100 + + handler: + model_name: "meta-llama/Llama-2-7b-hf" + clear_cache_dir: true + quantized_model_path: "best_model.pt" + example_inputs_mode: "MASK_KV_POS" + to_channels_last: false + + # generation params + batch_size: 1 
# this batch size is mostly used for calibration, you can leave it as 1 + input_tokens: 1024 + max_new_tokens: 128 + + # Use INT8 bf16 mix + quant_with_amp: true + + # decoding technique + greedy: true + + """ + +def test_handler_no_ipex(tmp_path, mocker): + try: + from llm_handler import IpexLLMHandler + + handler = IpexLLMHandler() + ctx = MockContext() + + model_config_yaml = tmp_path/"model-config.yaml" + #config = LLAMA_DEFAULT_CONFIG.substitute( + # {"nproc": "1", "stream": "true", "compile": compile, "ipex_enable":"false"} + #) + model_config_yaml.write_text(LLAMA_DEFAULT_CONFIG) + os.environ["TS_IPEX_ENABLE"] = "false" + + with open(model_config_yaml, "r") as f: + config = yaml.safe_load(f) + + ctx.model_yaml_config = config + + torch.manual_seed(42) + handler.initialize(ctx) + + # The model with default ipex routine won't have "trace_graph" attribute + assert hasattr(handler.user_model, "trace_graph") == False, "The default Pytorch module must not have 'trace_graph' attribute" + + x = handler.preprocess([{"data": json.dumps(PROMPTS[0])}]) + x = handler.inference(x) + x = handler.postprocess(x) + assert "Paris" in x[0], f"The Answer doesn't seem to be correct!" + + finally: + del handler.user_model + del handler + +def test_handler_ipex_bf16(tmp_path, mocker): + try: + os.environ["TS_IPEX_ENABLE"] = "true" + from llm_handler import IpexLLMHandler + + handler = IpexLLMHandler() + ctx = MockContext() + + model_config_yaml = tmp_path/"model-config.yaml" + #config = LLAMA_DEFAULT_CONFIG.substitute( + # {"nproc": "1", "stream": "true", "compile": compile, "ipex_enable":"false"} + #) + model_config_yaml.write_text(LLAMA_DEFAULT_CONFIG) + + with open(model_config_yaml, "r") as f: + config = yaml.safe_load(f) + + ctx.model_yaml_config = config + + torch.manual_seed(42) + handler.initialize(ctx) + assert hasattr(handler.user_model, "trace_graph") == True, "IPEX optimized bf16 module must have 'trace_graph' attribute" + + x = handler.preprocess([{"data": json.dumps(PROMPTS[0])}]) + x = handler.inference(x) + x = handler.postprocess(x) + assert "Paris" in x[0], f"The Answer doesn't seem to be correct!" + + finally: + del handler.user_model + del handler From 23282951141a6fdd5d76b36a966341f5eb261f84 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 15 May 2024 05:32:36 +0000 Subject: [PATCH 05/17] Fixing some issues with handler, added test to verify smooth-quant --- examples/large_models/ipex_llm_int8/README.md | 48 +- .../ipex_llm_int8/config.properties | 3 + .../large_models/ipex_llm_int8/llm_handler.py | 510 ++++++++---------- .../model-config-llama2-7b-bf16.yaml | 6 +- .../model-config-llama2-7b-int8-sq.yaml | 4 +- .../model-config-llama2-7b-int8-woq.yaml | 4 +- test/pytest/test_ipex_serving.py | 306 +++++++++-- 7 files changed, 519 insertions(+), 362 deletions(-) create mode 100644 examples/large_models/ipex_llm_int8/config.properties diff --git a/examples/large_models/ipex_llm_int8/README.md b/examples/large_models/ipex_llm_int8/README.md index 7f4b5202ca..bb9e4bc4c0 100644 --- a/examples/large_models/ipex_llm_int8/README.md +++ b/examples/large_models/ipex_llm_int8/README.md @@ -1,30 +1,24 @@ -This example provides an example of serving IPEX-optimized LLMs e.g. ```meta-llama/llama2-7b-hf``` on huggingface. 
For setting up the Python environment for this example, please refer here: https://github.com/intel/intel-extension-for-pytorch/blob/main/examples/cpu/inference/python/llm/README.md#3-environment-setup +# Serving IPEX Optimized Models +This example provides an example of serving IPEX-optimized LLMs e.g. ```meta-llama/llama2-7b-hf``` on huggingface. For setting up the Python environment for this example, please refer here: https://github.com/intel/intel-extension-for-pytorch/blob/main/examples/cpu/inference/python/llm/README.md#3-environment-setup -You can choose either Weight-only Quantization or Smoothquant path for quantizing the model to ```INT8```. If the ```quant_with_amp``` flag is set to ```true```, it'll use a mix of ```INT8``` and ```bfloat16``` precisions, otherwise, it'll use ```INT8``` and ```FP32``` combination. If neither approaches are enabled, the model runs on ```bfloat16``` precision by default as long as ```quant_with_amp``` is set to ```true```. -There are 3 different example config files; ```model-config-llama2-7b-int8-sq.yaml``` for quantizing with smooth-quant, ```model-config-llama2-7b-int8-woq.yaml``` for quantizing with weight only quantization, and ```model-config-llama2-7b-bf16.yaml``` for running the text generation on bfloat16 precision. -1. Zip everything using the model archiver +1. Run the model archiver ``` -torch-model-archiver --model-name llama2-7b --version 1.0 --handler llm_handler.py --config-file model-config-llama2-7b-int8-woq.yaml +torch-model-archiver --model-name llama2-7b --version 1.0 --handler llm_handler.py --config-file llama2-7b-int8-woq-config.yaml --archive-format no-archive ``` -2. Move archive to model_store +2. Move the model inside model_store ``` mkdir model_store -mv llama2-7b.mar ./model_store -``` - -3. Start the torch server -``` -torchserve --ncs --start --model-store model_store +mv llama2-7b ./model_store ``` -4. From the client, set up batching parameters. I couldn't make it work by putting the max_batch_size and max_batch_delay in config.properties. +3. Start the torch server ``` -curl -X POST "localhost:8081/models?url=llama2-7b.mar&batch_size=4&max_batch_delay=100" +torchserve --ncs --start --model-store model_store models llama2-7b ``` -5. Test the model status +5. Test the model status ``` curl http://localhost:8081/models/llama2-7b ``` @@ -33,3 +27,27 @@ curl http://localhost:8081/models/llama2-7b ``` curl http://localhost:8080/predictions/llama2-7b -T ./sample_text_0.txt ``` +## Model Config +In addition to usual torchserve configurations, you need to enable ipex specific optimization arguments. + +In order to enable IPEX, ```ipex_enable=true``` in the ```config.parameters``` file. If not enabled it will run with default PyTorch with ```auto_mixed_precision``` if enabled. In order to enable ```auto_mixed_precision```, you need to set ```auto_mixed_precision: true``` in model-config file. + +You can choose either Weight-only Quantization or Smoothquant path for quantizing the model to ```INT8```. If the ```quant_with_amp``` flag is set to ```true```, it'll use a mix of ```INT8``` and ```bfloat16``` precisions, otherwise, it'll use ```INT8``` and ```FP32``` combination. If neither approaches are enabled, the model runs on ```bfloat16``` precision by default as long as ```quant_with_amp``` or ```auto_mixed_precision``` is set to ```true```. 
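+
+For reference, enabling IPEX as described above amounts to a small ```config.properties``` like the one shipped with this example; the CPU launcher arguments are only an illustration and should be adapted to your machine's topology:
+```
+ipex_enable=true
+cpu_launcher_enable=true
+cpu_launcher_args=--node_id 0
+```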
+ +There are 3 different example config files; ```model-config-llama2-7b-int8-sq.yaml``` for quantizing with smooth-quant, ```model-config-llama2-7b-int8-woq.yaml``` for quantizing with weight only quantization, and ```model-config-llama2-7b-bf16.yaml``` for running the text generation on bfloat16 precision. + +### IPEX Weight Only Quantization +
    +
+- ```woq_dtype```: weight data type for weight-only quantization. Options: ```INT8``` or ```INT4```.
+- ```lowp_mode```: low-precision mode for weight-only quantization; it indicates the data type used for computation. See the sketch below for how these keys fit into a model-config file.
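+
+For illustration, the weight-only quantization portion of a model-config file might look like the following sketch. The key names and defaults mirror what ```llm_handler.py``` reads; the specific values shown are assumptions to adapt:
+```
+handler:
+  model_name: "meta-llama/Llama-2-7b-hf"
+  quant_with_amp: true
+  ipex_weight_only_quantization: true
+  woq_dtype: "INT8"
+  lowp_mode: "BF16"
+  act_quant_mode: "PER_IC_BLOCK"  # only relevant for INT4 weight quantization
+  group_size: -1
+```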
+ +### IPEX Smooth Quantization + +
    +
+- ```calibration_dataset``` and ```calibration_split```: the dataset and split used to calibrate the model for quantization
+- ```num_calibration_iters```: number of calibration iterations
+- ```alpha```: a floating point number between 0.0 and 1.0. For more complex smoothquant configurations, explore the IPEX quantization recipes (https://github.com/intel/intel-extension-for-pytorch/blob/main/examples/cpu/inference/python/llm/single_instance/run_quantization.py). A corresponding sketch is shown below.
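+
+For illustration, the smooth-quant portion of a model-config file might look like the following sketch, again mirroring the keys and defaults read by ```llm_handler.py```; the calibration dataset and ```alpha``` value are assumptions to adapt:
+```
+handler:
+  quant_with_amp: true
+  ipex_smooth_quantization: true
+  calibration_dataset: "NeelNanda/pile-10k"
+  calibration_split: "train"
+  num_calibration_iters: 32
+  alpha: 0.9
+```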
+ +Set ```greedy``` to true if you want to perform greedy search decoding. If set false, beam search of size 4 is performed by default. diff --git a/examples/large_models/ipex_llm_int8/config.properties b/examples/large_models/ipex_llm_int8/config.properties new file mode 100644 index 0000000000..9460c6729b --- /dev/null +++ b/examples/large_models/ipex_llm_int8/config.properties @@ -0,0 +1,3 @@ +ipex_enable=true +cpu_launcher_enable=true +cpu_launcher_args=--node_id 0 diff --git a/examples/large_models/ipex_llm_int8/llm_handler.py b/examples/large_models/ipex_llm_int8/llm_handler.py index a32fa4870f..fe6d5a9b1c 100644 --- a/examples/large_models/ipex_llm_int8/llm_handler.py +++ b/examples/large_models/ipex_llm_int8/llm_handler.py @@ -38,7 +38,7 @@ else: logger.warning("IPEX is not enabled, consider enabling it for best performance on Intel hardware") -class IpexLLMHandler(BaseHandler, ABC): +class IpexLLMHandler(BaseHandler): def __init__(self): super(IpexLLMHandler, self).__init__() @@ -50,42 +50,39 @@ def __init__(self): def initialize(self, ctx: Context): model_name = ctx.model_yaml_config["handler"]["model_name"] # path to quantized model, if we are quantizing on the fly, we'll use this path to save the model - self.clear_cache_dir = ctx.model_yaml_config["handler"]["clear_cache_dir"] - self.quantized_model_path = ctx.model_yaml_config["handler"]["quantized_model_path"] - self.example_inputs_mode = ctx.model_yaml_config["handler"]["example_inputs_mode"] - self.to_channels_last = ctx.model_yaml_config["handler"]["to_channels_last"] + self.clear_cache_dir = ctx.model_yaml_config["handler"].get("clear_cache_dir", False) + self.quantized_model_path = ctx.model_yaml_config["handler"].get("quantized_model_path", "best_model.pt") + self.example_inputs_mode = ctx.model_yaml_config["handler"].get("example_inputs_mode", "MASK_KV_POS") + self.to_channels_last = ctx.model_yaml_config["handler"].get("to_channels_last", False) # generation params - self.batch_size = int(ctx.model_yaml_config["handler"]["batch_size"]) - self.input_tokens = int(ctx.model_yaml_config["handler"]["input_tokens"]) - self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"]) + self.batch_size = int(ctx.model_yaml_config["handler"].get("batch_size", "1")) + self.input_tokens = int(ctx.model_yaml_config["handler"].get("input_tokens", "1024")) + self.max_new_tokens = int(ctx.model_yaml_config["handler"].get("max_new_tokens", "128")) # use int8 bf16 mix - self.quant_with_amp = ctx.model_yaml_config["handler"]["quant_with_amp"] + self.quant_with_amp = ctx.model_yaml_config["handler"].get("quant_with_amp", True) # WoQ related optimization params - if "ipex_weight_only_quantization" in ctx.model_yaml_config["handler"]: - self.ipex_weight_only_quantization = ctx.model_yaml_config["handler"]["ipex_weight_only_quantization"] - self.woq_dtype = ctx.model_yaml_config["handler"]["woq_dtype"] - self.lowp_mode = ctx.model_yaml_config["handler"]["lowp_mode"] - self.act_quant_mode = ctx.model_yaml_config["handler"]["act_quant_mode"] # This is only relevant for INT4x2 quantization - self.group_size = ctx.model_yaml_config["handler"]["group_size"] - else: - self.ipex_weight_only_quantization = False + self.ipex_weight_only_quantization = ctx.model_yaml_config["handler"].get("ipex_weight_only_quantization", False) + if self.ipex_weight_only_quantization: + self.woq_dtype = ctx.model_yaml_config["handler"].get("woq_dtype", "INT8") + self.lowp_mode = ctx.model_yaml_config["handler"].get("lowp_mode", "BF16") + 
self.act_quant_mode = ctx.model_yaml_config["handler"].get("act_quant_mode", "PER_IC_BLOCK") # This is only relevant for INT4x2 quantization + self.group_size = int(ctx.model_yaml_config["handler"].get("group_size", "-1")) # SQ related optimization params - if "ipex_smooth_quantization" in ctx.model_yaml_config["handler"]: - self.ipex_smooth_quantization = ctx.model_yaml_config["handler"]["ipex_smooth_quantization"] - self.calib_dataset = ctx.model_yaml_config["handler"]["calibration_dataset"] - self.calib_split = ctx.model_yaml_config["handler"]["calibration_split"] - self.num_calib_iters = int(ctx.model_yaml_config["handler"]["num_calibration_iters"]) - self.alpha = float(ctx.model_yaml_config["handler"]["alpha"]) - else: - self.ipex_smooth_quantization = False + self.ipex_smooth_quantization = ctx.model_yaml_config["handler"].get("ipex_smooth_quantization", False) + if self.ipex_smooth_quantization: + self.num_calib_iters = int(ctx.model_yaml_config["handler"].get("num_calibration_iters", 32)) + self.alpha = float(ctx.model_yaml_config["handler"].get("alpha", 0.9)) + # Keeping outside because we want to use it for tracing as well + self.calib_dataset = ctx.model_yaml_config["handler"].get("calibration_dataset", "NeelNanda/pile-10k") + self.calib_split = ctx.model_yaml_config["handler"].get("calibration_split", "train") # decoding parameters - self.greedy = ctx.model_yaml_config["handler"]["greedy"] + self.greedy = ctx.model_yaml_config["handler"].get("greedy", False) # amp datatype if self.quant_with_amp: @@ -121,7 +118,7 @@ def initialize(self, ctx: Context): # load model and tokenizer, # We need special provision for t5 because it's seq2seq model, and can not be loaded with AutoModelForCausalLM if re.search("t5", config.architectures[0], re.IGNORECASE): - self.user_model = T5ForConditionalGeneration.from_pretrained(model_name, config=config, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.float) + self.user_model = T5ForConditionalGeneration.from_pretrained(model_name, config=config, low_cpu_mem_usage=True, torch_dtype=torch.float) input_ids = torch.ones(32).to(torch.long).unsqueeze(0) attention_mask = torch.ones_like(input_ids) dummy_inputs = self.user_model.prepare_inputs_for_generation(input_ids, attention_mask=attention_mask) @@ -151,7 +148,7 @@ def _get_target_nums(names): if hasattr(self.user_model.config, n): return getattr(self.user_model.config, n) logger.error(f"Not found target {names[0]}") - exit(0) + exit(1) num_heads_names = ["num_attention_heads", "n_head", "num_heads", "n_heads"] num_layers_names = ["num_hidden_layers", "n_layer", "num_layers", "n_layers"] @@ -171,6 +168,176 @@ def _get_target_nums(names): ] logger.info(f"num_attention_heads: {n_heads}, num_hidden_layers: {n_layers}, hidden size: {hidden_size}, head_dim: {head_dim}") + + logger.info("Preparing the dataset for calibration and tracing") + class Evaluator: + def __init__(self, + user_model, + example_inputs_mode, + global_past_key_value, + dataset, tokenizer, + batch_size=1, + num_beams=1, + pad_val=1, + pad_max=512): + self.user_model = user_model + self.example_inputs_mode = example_inputs_mode + self.global_past_key_value = global_past_key_value + self.dataset = dataset + self.tokenizer = tokenizer + self.batch_size = batch_size + self.num_beams = num_beams + + + self.pad_val = pad_val + self.pad_max = pad_max + self.dataset = self.dataset.map(self.tokenize_function, batched=True) + self.dataset.set_format(type="torch", columns=["input_ids"]) + + @torch.no_grad() + def 
tokenize_function(self, examples): + if "text" in examples: + example = self.tokenizer(examples["text"]) + elif "prompt" in examples: + example = self.tokenizer(examples["prompt"]) + elif "code" in examples: + example = self.tokenizer(examples["code"]) + return example + + + @torch.no_grad() + def collate_batch(self, batch): + position_ids_padded = [] + input_ids_padded = [] + last_ind = [] + attention_mask_padded = [] + + for text in batch: + input_ids = text["input_ids"] + last_ind.append(input_ids.shape[0] - 1) + attention_mask = torch.ones(len(input_ids)) + position_ids = torch.arange(len(input_ids)) + + input_ids_padded.append(input_ids) + attention_mask_padded.append(attention_mask) + position_ids_padded.append(position_ids) + + if self.example_inputs_mode == "MASK_POS_KV": + model_inputs = ( + torch.vstack(input_ids_padded), + torch.vstack(attention_mask_padded), + torch.vstack(position_ids_padded), + tuple(self.global_past_key_value), + ) + elif self.example_inputs_mode == "MASK_KV_POS": + model_inputs = ( + torch.vstack(input_ids_padded), + torch.vstack(attention_mask_padded), + tuple(self.global_past_key_value), + torch.vstack(position_ids_padded), + ) + elif self.example_inputs_mode == "KV_MASK": + model_inputs = ( + torch.vstack(input_ids_padded), + tuple(self.global_past_key_value), + torch.vstack(attention_mask_padded), + ) + elif self.example_inputs_mode == "MASK_KV": + model_inputs = ( + torch.vstack(input_ids_padded), + torch.vstack(attention_mask_padded), + tuple(self.global_past_key_value), + ) + elif self.example_inputs_mode == "MASK_KV_ENC": + model_kwargs = { + "attention_mask": torch.vstack(attention_mask_padded), + } + model_kwargs = self.user_model._prepare_encoder_decoder_kwargs_for_generation( + torch.vstack(input_ids_padded), model_kwargs, "input_ids" + ) + input_ids, example_inputs = self.user_model._expand_inputs_for_generation( + input_ids=torch.vstack(input_ids_padded), + expand_size=self.num_beams, + is_encoder_decoder=True, + **model_kwargs, + ) + + # need to recompute these + def _get_target_nums(names): + for n in names: + if hasattr(self.user_model.config, n): + return getattr(self.user_model.config, n) + logger.error(f"Not found target {names[0]}") + exit(1) + + num_heads_names = ["num_attention_heads", "n_head", "num_heads", "n_heads"] + num_layers_names = ["num_hidden_layers", "n_layer", "num_layers", "n_layers"] + hidden_size_names = ["hidden_size", "n_embd"] + n_heads = _get_target_nums(num_heads_names) + n_layers = _get_target_nums(num_layers_names) + hidden_size = _get_target_nums(hidden_size_names) + head_dim = int(hidden_size / n_heads) + + # lets get the inputs + beam_idx_tmp = torch.zeros((2048, int(self.batch_size * self.num_beams)), dtype=torch.long).contiguous() + input_bs = int(self.batch_size * self.num_beams) + last_hidden_state = example_inputs["encoder_outputs"]["last_hidden_state"] + global_past_key_value = tuple( + [ + ( + torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), + torch.zeros([1, n_heads, 1, head_dim]).contiguous(), + torch.zeros([1, n_heads, 1, head_dim]).contiguous(), + beam_idx_tmp, + torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), + self.user_model.decoder.block[i].layer[1].EncDecAttention.k(last_hidden_state).view(input_bs, -1, n_heads, head_dim).transpose(0, 1), + self.user_model.decoder.block[i].layer[1].EncDecAttention.v(last_hidden_state).view(input_bs, -1, n_heads, head_dim).transpose(0, 1), + beam_idx_tmp, + ) + for i in range(n_layers) + ] + ) + + decoder_input_ids = 
(torch.zeros(input_bs).to(torch.long).unsqueeze(1)) + model_inputs = ( + decoder_input_ids, + torch.vstack(attention_mask_padded), + tuple(global_past_key_value), + (last_hidden_state,), + ) + else: + raise RuntimeError("Your model does not match existing example inputs used in ipex smooth quant, exiting...") + + # Some models require extra inputs + if re.search("chatglm", self.user_model.config.architectures[0], re.IGNORECASE): + extra_inputs = (torch.tensor(True),) + model_inputs = model_inputs + extra_inputs + + return (model_inputs, last_ind) + + + + calib_dataset = load_dataset(self.calib_dataset, split=self.calib_split) + logger.info(f"Dataset loaded: {calib_dataset}") + calib_evaluator = Evaluator( + self.user_model, + self.example_inputs_mode, + self.global_past_key_value, + calib_dataset, + self.tokenizer, + batch_size=self.batch_size, + num_beams = self.num_beams, + pad_max = int(self.input_tokens) if re.search("t5", config.architectures[0], re.IGNORECASE) else 512 + ) + logger.info(f"Evaluator built: {calib_evaluator}") + + self.calib_dataloader = DataLoader( + calib_evaluator.dataset, + batch_size=1, + shuffle=False, + collate_fn=calib_evaluator.collate_batch, + ) + logger.info("Dataloader is built successfully!") if IPEX_ENABLE: """ @@ -180,6 +347,19 @@ def _get_target_nums(names): (3) ipex bfloat16 optimization if neither is quantization is enabled (4) throws error if both 1 and 2 are enabled """ + + def trace_and_export(model): + example_inputs = self.get_example_inputs() + + with torch.no_grad(), torch.cpu.amp.autocast( + enabled=self.amp_enabled, + dtype=self.amp_dtype + ): + self_jit = torch.jit.trace(model.eval(), example_inputs, strict=False, check_trace=False) + self_jit = torch.jit.freeze(self_jit.eval()) + + self_jit.save(self.quantized_model_path) + if self.ipex_smooth_quantization and self.ipex_weight_only_quantization: logger.error("Can't enable both SQ and WoQ, enable only one of them") exit(1) @@ -232,175 +412,12 @@ def _get_target_nums(names): deployment_mode=False, ) logger.info("The model conversion completed, now tracing the quantized model") - - example_inputs = self.get_example_inputs() - - with torch.no_grad(), torch.cpu.amp.autocast( - enabled=self.amp_enabled, - dtype=self.amp_dtype - ): - self_jit = torch.jit.trace(self.user_model.eval(), example_inputs, strict=False, check_trace=False) - self_jit = torch.jit.freeze(self_jit.eval()) - - self_jit.save(self.quantized_model_path) + + trace_and_export(self.user_model) logger.info("The IPEX Weight only quantization has been completed successfully") elif self.ipex_smooth_quantization: - class Evaluator: - def __init__(self, example_inputs_mode, global_past_key_value, dataset, tokenizer, batch_size=1, num_beams=1, pad_val=1, pad_max=512): - self.example_inputs_mode = example_inputs_mode - self.global_past_key_value = global_past_key_value - self.dataset = dataset - self.tokenizer = tokenizer - self.batch_size = batch_size - self.num_beams = num_beams - - - self.pad_val = pad_val - self.pad_max = pad_max - self.dataset = self.dataset.map(self.tokenize_function, batched = True, num_proc=2) - self.dataset.set_format(type="torch", columns=["input_ids"]) - - @torch.no_grad() - def tokenize_function(self, examples): - if "prompt" in examples: - example = self.tokenizer(examples["prompt"]) - elif "text" in examples: - example = self.tokenizer(examples["text"]) - elif "code" in examples: - example = self.tokenizer(examples["code"]) - return example - - - @torch.no_grad() - def collate_batch(self, batch): - 
position_ids_padded = [] - input_ids_padded = [] - last_ind = [] - attention_mask_padded = [] - - for text in batch: - input_ids = text["input_ids"] - last_ind.append(input_ids.shape[0] - 1) - attention_mask = torch.ones(len(input_ids)) - position_ids = torch.arange(len(input_ids)) - - input_ids_padded.append(input_ids) - attention_mask_padded.append(attention_mask) - position_ids_padded.append(position_ids) - - if self.example_inputs_mode == "MASK_POS_KV": - model_inputs = ( - torch.vstack(input_ids_padded), - torch.vstack(attention_mask_padded), - torch.vstack(position_ids_padded), - tuple(self.global_past_key_value), - ) - elif self.example_inputs_mode == "MASK_KV_POS": - model_inputs = ( - torch.vstack(input_ids_padded), - torch.vstack(attention_mask_padded), - tuple(self.global_past_key_value), - torch.vstack(position_ids_padded), - ) - elif self.example_inputs_mode == "KV_MASK": - model_inputs = ( - torch.vstack(input_ids_padded), - tuple(self.global_past_key_value), - torch.vstack(attention_mask_padded), - ) - elif self.example_inputs_mode == "MASK_KV": - model_inputs = ( - torch.vstack(input_ids_padded), - torch.vstack(attention_mask_padded), - tuple(self.global_past_key_value), - ) - elif self.example_inputs_mode == "MASK_KV_ENC": - model_kwargs = { - "attention_mask": torch.vstack(attention_mask_padded), - } - model_kwargs = user_model._prepare_encoder_decoder_kwargs_for_generation( - torch.vstack(input_ids_padded), model_kwargs, "input_ids" - ) - input_ids, example_inputs = user_model._expand_inputs_for_generation( - input_ids=torch.vstack(input_ids_padded), - expand_size=self.num_beams, - is_encoder_decoder=True, - **model_kwargs, - ) - - # need to recompute these - def _get_target_nums(names): - for n in names: - if hasattr(self.user_model.config, n): - return getattr(self.user_model.config, n) - logger.error(f"Not found target {names[0]}") - exit(0) - - num_heads_names = ["num_attention_heads", "n_head", "num_heads", "n_heads"] - num_layers_names = ["num_hidden_layers", "n_layer", "num_layers", "n_layers"] - hidden_size_names = ["hidden_size", "n_embd"] - n_heads = _get_target_nums(num_heads_names) - n_layers = _get_target_nums(num_layers_names) - hidden_size = _get_target_nums(hidden_size_names) - head_dim = int(hidden_size / n_heads) - - # lets get the inputs - input_bs = int(self.batch_size * self.num_beams) - last_hidden_state = example_inputs["encoder_outputs"]["last_hidden_state"] - global_past_key_value = tuple( - [ - ( - torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), - torch.zeros([1, n_heads, 1, head_dim]).contiguous(), - torch.zeros([1, n_heads, 1, head_dim]).contiguous(), - self.beam_idx_tmp, - torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), - user_model.decoder.block[i].layer[1].EncDecAttention.k(last_hidden_state).view(input_bs, -1, n_heads, head_dim).transpose(0, 1), - user_model.decoder.block[i].layer[1].EncDecAttention.v(last_hidden_state).view(input_bs, -1, n_heads, head_dim).transpose(0, 1), - self.beam_idx_tmp, - ) - for i in range(n_layers) - ] - ) - - decoder_input_ids = (torch.zeros(input_bs).to(torch.long).unsqueeze(1)) - model_inputs = ( - decoder_input_ids, - torch.vstack(attention_mask_padded), - tuple(global_past_key_value), - (last_hidden_state,), - ) - else: - raise RuntimeError("Your model does not match existing example inputs used in ipex smooth quant, exiting...") - - return (model_inputs, last_ind) - - - - calib_dataset = load_dataset(self.calib_dataset, split=self.calib_split) - logger.info(f"Dataset loaded: 
{calib_dataset}") - calib_evaluator = Evaluator( - self.example_inputs_mode, - self.global_past_key_value, - calib_dataset, - self.tokenizer, - batch_size=self.batch_size, - num_beams = self.num_beams, - pad_max = int(self.input_tokens) if re.search("t5", config.architectures[0], re.IGNORECASE) else 512 - ) - logger.info(f"Evaluator built: {calib_evaluator}") - - calib_dataloader = DataLoader( - calib_evaluator.dataset, - batch_size=1, - shuffle=False, - collate_fn=calib_evaluator.collate_batch, - ) - logger.info("Dataloader ready") - - from intel_extension_for_pytorch.quantization import prepare, convert example_inputs = self.get_example_inputs() qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=self.alpha) @@ -416,7 +433,7 @@ def _get_target_nums(names): logger.info("Model prepared for quantization, observers inserted") - for i, (model_inputs, last_ind) in enumerate(calib_dataloader): + for i, (model_inputs, last_ind) in enumerate(self.calib_dataloader): if i == self.num_calib_iters: break prepared_model(*model_inputs) @@ -424,16 +441,9 @@ def _get_target_nums(names): converted_model = convert(prepared_model.eval(), inplace=True).eval() logger.info("Model converted successfully, exporting the trace") - - with torch.no_grad(), torch.cpu.amp.autocast( - enabled=self.amp_enabled, - dtype=self.amp_dtype - ): - self_jit = torch.jit.trace(converted_model.eval(), example_inputs, strict=False, check_trace=False) - self_jit = torch.jit.freeze(self_jit.eval()) - - self_jit.save(self.quantized_model_path) + trace_and_export(converted_model) + logger.info("IPEX Smooth Quantization has completed successfully") else: @@ -445,16 +455,8 @@ def _get_target_nums(names): inplace=True, deployment_mode=False, ) - - with torch.no_grad(), torch.cpu.amp.autocast( - enabled=self.amp_enabled, - dtype=self.amp_dtype - ): - self_jit = torch.jit.trace(self.user_model.eval(), example_inputs, strict=False, check_trace=False) - self_jit = torch.jit.freeze(self_jit.eval()) - - self_jit.save(self.quantized_model_path) - + + trace_and_export(self.user_model) logger.info("IPEX bf16 optimization is applied successfully") logger.info("Loading the IPEX quantized model") @@ -482,86 +484,10 @@ def _get_target_nums(names): - # Different model need to have their inputs supplied in different order unless we pass dict - # For torchserve sending dict is not always possible - # This function reorders the input ids, masks, and kv cache based on models + # we are going to use data collator we built to generate example input def get_example_inputs(self): - example_inputs = None - input_ids = torch.ones(32).to(torch.long) - attention_mask = torch.ones(len(input_ids)) - if self.example_inputs_mode == "MASK_POS_KV": - position_ids = torch.arange(len(input_ids)) - example_inputs = ( - input_ids.unsqueeze(0), - attention_mask.unsqueeze(0), - position_ids.unsqueeze(0), - tuple(self.global_past_key_value), - ) - elif self.example_inputs_mode == "MASK_KV_POS": - position_ids = torch.arange(len(input_ids)) - example_inputs = ( - input_ids.unsqueeze(0), - attention_mask.unsqueeze(0), - tuple(self.global_past_key_value), - position_ids.unsqueeze(0), - ) - elif self.example_inputs_mode == "KV_MASK": - example_inputs = ( - input_ids.unsqueeze(0), - tuple(self.global_past_key_value), - attention_mask.unsqueeze(0), - ) - elif self.example_inputs_mode == "MASK_KV": - example_inputs = ( - input_ids.unsqueeze(0), - attention_mask.unsqueeze(0), - tuple(self.global_past_key_value), - ) - elif self.example_inputs_mode == "MASK_KV_ENC": - 
last_hidden_state = torch.rand([1, 32, 2048]) - - #need to recompute these - def _get_target_nums(names): - for n in names: - if hasattr(self.user_model.config, n): - return getattr(self.user_model.config, n) - logger.error(f"Not found target {names[0]}") - exit(0) - - num_heads_names = ["num_attention_heads", "n_head", "num_heads", "n_heads"] - num_layers_names = ["num_hidden_layers", "n_layer", "num_layers", "n_layers"] - hidden_size_names = ["hidden_size", "n_embd"] - n_heads = _get_target_nums(num_heads_names) - n_layers = _get_target_nums(num_layers_names) - hidden_size = _get_target_nums(hidden_size_names) - head_dim = int(hidden_size / n_heads) - - global_past_key_value = [ - ( - torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), - torch.zeros([1, n_heads, 1, head_dim]).contiguous(), - torch.zeros([1, n_heads, 1, head_dim]).contiguous(), - self.beam_idx_tmp, - torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), - torch.zeros([32, 1, n_heads, head_dim]).contiguous(), - torch.zeros([32, 1, n_heads, head_dim]).contiguous(), - self.beam_idx_tmp, - ) - for i in range(n_layers) - ] - example_inputs = ( - torch.ones(1).to(torch.long).unsqueeze(0), - attention_mask.unsqueeze(0), - tuple(global_past_key_value), - (last_hidden_state,), - ) - else: - raise RuntimeError("Your model does not match existing example inputs used in ipex quantization, exiting...") - # TODO: Figure out how to provide the extra inputs from config - if re.search("chatglm", self.user_model.config.architectures[0], re.IGNORECASE): - extra_inputs = (torch.tensor(True),) - example_inputs = example_inputs + extra_inputs - return example_inputs + (model_inputs, last_ind) = next(iter(self.calib_dataloader)) + return model_inputs def preprocess(self, requests): input_ids_batch = None diff --git a/examples/large_models/ipex_llm_int8/model-config-llama2-7b-bf16.yaml b/examples/large_models/ipex_llm_int8/model-config-llama2-7b-bf16.yaml index 92dc955b96..f42bc644e5 100644 --- a/examples/large_models/ipex_llm_int8/model-config-llama2-7b-bf16.yaml +++ b/examples/large_models/ipex_llm_int8/model-config-llama2-7b-bf16.yaml @@ -1,20 +1,22 @@ minWorkers: 1 maxWorkers: 1 responseTimeout: 1500 +BatchSize: 4 +maxBatchDelay: 100 handler: model_name: "meta-llama/Llama-2-7b-hf" + clear_cache_dir: true # removes the quantized model if it already exists quantized_model_path: "best_model.pt" example_inputs_mode: "MASK_KV_POS" to_channels_last: false # generation params batch_size: 1 - max_context_length: 2048 input_tokens: 1024 max_new_tokens: 128 - # Use INT8 bf16 mix + # Use INT8+bf16 mix quant_with_amp: true # decoding technique diff --git a/examples/large_models/ipex_llm_int8/model-config-llama2-7b-int8-sq.yaml b/examples/large_models/ipex_llm_int8/model-config-llama2-7b-int8-sq.yaml index a82016abda..a39b8e0345 100644 --- a/examples/large_models/ipex_llm_int8/model-config-llama2-7b-int8-sq.yaml +++ b/examples/large_models/ipex_llm_int8/model-config-llama2-7b-int8-sq.yaml @@ -1,16 +1,18 @@ minWorkers: 1 maxWorkers: 1 responseTimeout: 1500 +batchSize: 4 +maxBatchDelay: 100 handler: model_name: "meta-llama/Llama-2-7b-hf" + clear_cache_dir: true quantized_model_path: "best_model.pt" example_inputs_mode: "MASK_KV_POS" to_channels_last: false # generation params batch_size: 1 - max_context_length: 2048 input_tokens: 1024 max_new_tokens: 128 diff --git a/examples/large_models/ipex_llm_int8/model-config-llama2-7b-int8-woq.yaml b/examples/large_models/ipex_llm_int8/model-config-llama2-7b-int8-woq.yaml index 29cbda64f7..04f43fcfdd 100644 
--- a/examples/large_models/ipex_llm_int8/model-config-llama2-7b-int8-woq.yaml +++ b/examples/large_models/ipex_llm_int8/model-config-llama2-7b-int8-woq.yaml @@ -1,16 +1,18 @@ minWorkers: 1 maxWorkers: 1 responseTimeout: 1500 +batchSize: 4 +maxBatchDelay: 100 handler: model_name: "meta-llama/Llama-2-7b-hf" + clear_cache_dir: true quantized_model_path: "best_model.pt" example_inputs_mode: "MASK_KV_POS" to_channels_last: false # generation params batch_size: 1 - max_context_length: 2048 input_tokens: 1024 max_new_tokens: 128 diff --git a/test/pytest/test_ipex_serving.py b/test/pytest/test_ipex_serving.py index fd00b1c611..8fd7168e26 100644 --- a/test/pytest/test_ipex_serving.py +++ b/test/pytest/test_ipex_serving.py @@ -2,6 +2,7 @@ import sys import json from pathlib import Path +import shutil import subprocess import yaml @@ -23,6 +24,7 @@ REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") snapshot_file_ipex = os.path.join(REPO_ROOT, "test/config_ipex.properties") +default_ts_config = os.path.join(REPO_ROOT, "test/config_ts.properties") prompt_file = os.path.join(REPO_ROOT, "examples/large_models/ipex_llm_int8/sample_text_0.txt") #CURR_FILE_PATH = Path(__file__).parent @@ -53,12 +55,57 @@ ipex_xeon_run_available = xeon_run_cpu_available and ipex_available +@pytest.fixture(scope="module") +def model_name(): + yield "llama2" -# TODO: download each model we want to serve inside this folder (Change with model name) +@pytest.fixture(scope="module") +def work_dir(tmp_path_factory, model_name): + return Path(tmp_path_factory.mktemp(model_name)) -# MODEL_FILE_PATH=HANDLER_PATH/"llama_2" +# @pytest.fixture(scope="module", name="mar_file_path") +def create_mar_file(work_dir, model_archiver, model_name, model_config_yaml_file): + mar_file_path = work_dir.joinpath(model_name + ".mar") + handler_file = os.path.join(HANDLER_PATH, "llm_handler.py") + assert(Path(handler_file).exists()) + + config = ModelArchiverConfig( + model_name=model_name, + version="1.0", + serialized_file=None, + model_file=None, + handler=handler_file, + extra_files=None, + export_path=work_dir, + requirements_file=None, + runtime="python", + force=False, + archive_format="default", + config_file=model_config_yaml_file.as_posix(), + ) + + with patch("archiver.ArgParser.export_model_args_parser", return_value=config): + model_archiver.generate_model_archive() + + assert mar_file_path.exists() + + return mar_file_path.as_posix() + +def run_inference_with_prompt(prompt_file, model_name): + model_url = f"{INFERENCE_API}/predictions/{model_name}" + response = run_inference_using_url_with_data(model_url, prompt_file) + return response + +def start_torchserve(ts_config_file): + + # start the torchserve + test_utils.start_torchserve( + model_store=test_utils.MODEL_STORE, + snapshot_file=ts_config_file, + gen_mar=False + ) LLAMA_DEFAULT_CONFIG = f""" @@ -88,68 +135,225 @@ """ -def test_handler_no_ipex(tmp_path, mocker): - try: - from llm_handler import IpexLLMHandler +LLAMA_CONFIG_WOQ = f""" + minWorkers: 1 + maxWorkers: 1 + responseTimeout: 1500 + batchSize: 4 + maxBatchDelay: 100 + + handler: + model_name: "meta-llama/Llama-2-7b-hf" + clear_cache_dir: true + quantized_model_path: "best_model.pt" + example_inputs_mode: "MASK_KV_POS" + to_channels_last: false + + # generation params + batch_size: 1 + input_tokens: 1024 + max_new_tokens: 128 + + # Use INT8 bf16 mix + quant_with_amp: true + + # Woq params + ipex_weight_only_quantization: true + woq_dtype: "INT8" + lowp_mode: "BF16" + act_quant_mode: "PER_IC_BLOCK" 
+ group_size: -1 + + # decoding technique + greedy: true + """ - handler = IpexLLMHandler() - ctx = MockContext() +LLAMA_CONFIG_SQ = f""" + minWorkers: 1 + maxWorkers: 1 + responseTimeout: 1500 + batchSize: 4 + maxBatchDelay: 100 + + handler: + model_name: "meta-llama/Llama-2-7b-hf" + clear_cache_dir: true + quantized_model_path: "best_model.pt" + example_inputs_mode: "MASK_KV_POS" + to_channels_last: false + + # generation params + batch_size: 1 + input_tokens: 1024 + max_new_tokens: 128 + + # use bf16-int8 mix + quant_with_amp: true + + # SQ quantization params + ipex_smooth_quantization: true + calibration_dataset: "NeelNanda/pile-10k" + calibration_split: "train" + num_calibration_iters: 32 + alpha: 0.9 + + # decoding technique + greedy: true - model_config_yaml = tmp_path/"model-config.yaml" - #config = LLAMA_DEFAULT_CONFIG.substitute( - # {"nproc": "1", "stream": "true", "compile": compile, "ipex_enable":"false"} - #) - model_config_yaml.write_text(LLAMA_DEFAULT_CONFIG) - os.environ["TS_IPEX_ENABLE"] = "false" + """ - with open(model_config_yaml, "r") as f: - config = yaml.safe_load(f) +""" +outline of the tests: + 1. edit the config + 2. create mar file + 3. start torchserve + 4. test connection + 5. test response correctness +""" + +def test_handler_default_pytorch(work_dir, model_archiver): + test_utils.torchserve_cleanup() + # create_mar_file(work_dir, model_archiver, model_name, model_config_yaml_file): + model_config_yaml = work_dir / "model-config.yaml" + model_config_yaml.write_text(LLAMA_DEFAULT_CONFIG) + + # Create mar file + model_name = "llama2_no_ipex" + mar_file_path = create_mar_file(work_dir, model_archiver, model_name, model_config_yaml) + os.makedirs(os.path.dirname(test_utils.MODEL_STORE), exist_ok=True) + shutil.move(mar_file_path, test_utils.MODEL_STORE) + + # start torchserve server + start_torchserve(default_ts_config) + + # load the model + model_url = f"{MANAGEMENT_API}/models?url={model_name}.mar" + requests.post(model_url) + + # query model info + model_url = f"{MANAGEMENT_API}/models/{model_name}" + response = requests.get(model_url) + assert response.status_code == 200, "The Model failed the with default Pytorch" + + # send prompts to the model + model_url = f"{INFERENCE_API}/predictions/{model_name}" + response = requests.post(url=model_url, + data=json.dumps(PROMPTS[0],), + ) + + assert response.status_code == 200, "The model failed to generate text from prompt!" + assert "Paris" in response.text, "The response doesn't seem to be correct!" + - ctx.model_yaml_config = config + test_utils.torchserve_cleanup() - torch.manual_seed(42) - handler.initialize(ctx) - # The model with default ipex routine won't have "trace_graph" attribute - assert hasattr(handler.user_model, "trace_graph") == False, "The default Pytorch module must not have 'trace_graph' attribute" +def test_handler_ipex_bf16(work_dir, model_archiver): + test_utils.torchserve_cleanup() + # create_mar_file(work_dir, model_archiver, model_name, model_config_yaml_file): + model_config_yaml = work_dir / "model-config.yaml" + model_config_yaml.write_text(LLAMA_DEFAULT_CONFIG) - x = handler.preprocess([{"data": json.dumps(PROMPTS[0])}]) - x = handler.inference(x) - x = handler.postprocess(x) - assert "Paris" in x[0], f"The Answer doesn't seem to be correct!" 
+ # Create mar file + model_name = "llama2_ipex_bf16" + mar_file_path = create_mar_file(work_dir, model_archiver, model_name, model_config_yaml) + os.makedirs(os.path.dirname(test_utils.MODEL_STORE), exist_ok=True) + shutil.move(mar_file_path, test_utils.MODEL_STORE) - finally: - del handler.user_model - del handler + # start torchserve server + start_torchserve(snapshot_file_ipex) + + # load the model + model_url = f"{MANAGEMENT_API}/models?url={model_name}.mar" + requests.post(model_url) + + # query model info + model_url = f"{MANAGEMENT_API}/models/{model_name}" + response = requests.get(model_url) + assert response.status_code == 200, "The Model failed the with default Pytorch" + + # send prompts to the model + model_url = f"{INFERENCE_API}/predictions/{model_name}" + response = requests.post(url=model_url, + data=json.dumps(PROMPTS[0],), + ) + + assert response.status_code == 200, "The model failed to generate text from prompt!" + assert "Paris" in response.text, "The response doesn't seem to be correct!" + + + test_utils.torchserve_cleanup() -def test_handler_ipex_bf16(tmp_path, mocker): - try: - os.environ["TS_IPEX_ENABLE"] = "true" - from llm_handler import IpexLLMHandler - handler = IpexLLMHandler() - ctx = MockContext() +def test_handler_ipex_int8_woq(work_dir, model_archiver): + test_utils.torchserve_cleanup() + # create_mar_file(work_dir, model_archiver, model_name, model_config_yaml_file): + model_config_yaml = work_dir / "model-config.yaml" + model_config_yaml.write_text(LLAMA_CONFIG_WOQ) - model_config_yaml = tmp_path/"model-config.yaml" - #config = LLAMA_DEFAULT_CONFIG.substitute( - # {"nproc": "1", "stream": "true", "compile": compile, "ipex_enable":"false"} - #) - model_config_yaml.write_text(LLAMA_DEFAULT_CONFIG) + # Create mar file + model_name = "llama2_ipex_int8_woq" + mar_file_path = create_mar_file(work_dir, model_archiver, model_name, model_config_yaml) + os.makedirs(os.path.dirname(test_utils.MODEL_STORE), exist_ok=True) + shutil.move(mar_file_path, test_utils.MODEL_STORE) - with open(model_config_yaml, "r") as f: - config = yaml.safe_load(f) + # start torchserve server + start_torchserve(snapshot_file_ipex) - ctx.model_yaml_config = config + # load the model + model_url = f"{MANAGEMENT_API}/models?url={model_name}.mar" + requests.post(model_url) + + # query model info + model_url = f"{MANAGEMENT_API}/models/{model_name}" + response = requests.get(model_url) + assert response.status_code == 200, "The Model failed the with default Pytorch" + + # send prompts to the model + model_url = f"{INFERENCE_API}/predictions/{model_name}" + response = requests.post(url=model_url, + data=json.dumps(PROMPTS[0],), + ) + + assert response.status_code == 200, "The model failed to generate text from prompt!" + assert "Paris" in response.text, "The response doesn't seem to be correct!" + - torch.manual_seed(42) - handler.initialize(ctx) - assert hasattr(handler.user_model, "trace_graph") == True, "IPEX optimized bf16 module must have 'trace_graph' attribute" + test_utils.torchserve_cleanup() - x = handler.preprocess([{"data": json.dumps(PROMPTS[0])}]) - x = handler.inference(x) - x = handler.postprocess(x) - assert "Paris" in x[0], f"The Answer doesn't seem to be correct!" 
- finally: - del handler.user_model - del handler +def test_handler_ipex_int8_sq(work_dir, model_archiver): + test_utils.torchserve_cleanup() + # create_mar_file(work_dir, model_archiver, model_name, model_config_yaml_file): + model_config_yaml = work_dir / "model-config.yaml" + model_config_yaml.write_text(LLAMA_CONFIG_SQ) + + # Create mar file + model_name = "llama2_ipex_int8_sq" + mar_file_path = create_mar_file(work_dir, model_archiver, model_name, model_config_yaml) + os.makedirs(os.path.dirname(test_utils.MODEL_STORE), exist_ok=True) + shutil.move(mar_file_path, test_utils.MODEL_STORE) + + # start torchserve server + start_torchserve(snapshot_file_ipex) + + # load the model + model_url = f"{MANAGEMENT_API}/models?url={model_name}.mar" + requests.post(model_url) + + # query model info + model_url = f"{MANAGEMENT_API}/models/{model_name}" + response = requests.get(model_url) + assert response.status_code == 200, "The Model failed the with default Pytorch" + + # send prompts to the model + model_url = f"{INFERENCE_API}/predictions/{model_name}" + response = requests.post(url=model_url, + data=json.dumps(PROMPTS[0],), + ) + + assert response.status_code == 200, "The model failed to generate text from prompt!" + assert "Paris" in response.text, "The response doesn't seem to be correct!" + + + test_utils.torchserve_cleanup() From 85ba194dfb1e87fffbe73d398d92fa87b1ce47de Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 15 May 2024 05:59:15 +0000 Subject: [PATCH 06/17] adding auto_mixed_precision flag to config --- examples/large_models/ipex_llm_int8/llm_handler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/large_models/ipex_llm_int8/llm_handler.py b/examples/large_models/ipex_llm_int8/llm_handler.py index fe6d5a9b1c..0f2505b2b7 100644 --- a/examples/large_models/ipex_llm_int8/llm_handler.py +++ b/examples/large_models/ipex_llm_int8/llm_handler.py @@ -59,6 +59,9 @@ def initialize(self, ctx: Context): self.batch_size = int(ctx.model_yaml_config["handler"].get("batch_size", "1")) self.input_tokens = int(ctx.model_yaml_config["handler"].get("input_tokens", "1024")) self.max_new_tokens = int(ctx.model_yaml_config["handler"].get("max_new_tokens", "128")) + + # enable auto mix precision + self.auto_mixed_precision = ctx.model_yaml_config["handler"].get("auto_mixed_precision", True) # use int8 bf16 mix self.quant_with_amp = ctx.model_yaml_config["handler"].get("quant_with_amp", True) @@ -85,7 +88,7 @@ def initialize(self, ctx: Context): self.greedy = ctx.model_yaml_config["handler"].get("greedy", False) # amp datatype - if self.quant_with_amp: + if self.quant_with_amp or self.auto_mixed_precision: self.amp_enabled = True self.amp_dtype = torch.bfloat16 else: From eca2b0a9887e932fe4691f83d1efe1b6ada5a23f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 15 May 2024 06:15:32 +0000 Subject: [PATCH 07/17] Removing min_new_tokens from generation config --- examples/large_models/ipex_llm_int8/llm_handler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/large_models/ipex_llm_int8/llm_handler.py b/examples/large_models/ipex_llm_int8/llm_handler.py index 0f2505b2b7..67172436f6 100644 --- a/examples/large_models/ipex_llm_int8/llm_handler.py +++ b/examples/large_models/ipex_llm_int8/llm_handler.py @@ -103,7 +103,6 @@ def initialize(self, ctx: Context): temperature=0.9, num_beams=self.num_beams, max_new_tokens=self.max_new_tokens, - min_new_tokens=self.max_new_tokens, ) # device From 37f16c94f0cea57ed03d1317a93e2c486fafb3f6 Mon Sep 17 00:00:00 2001 From: lxning 
Date: Wed, 15 May 2024 12:58:46 -0700 Subject: [PATCH 08/17] fix lint --- ts_scripts/spellcheck_conf/wordlist.txt | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index 142909f201..3f6a8d5d8a 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -1239,3 +1239,11 @@ vllm sql TimeUnit Aopen +Smoothquant +iters +lowp +precisions +quant +quantizing +smoothquant +woq From 2e852f596f4a9de16b4c5b3082bc3a8b35b913c0 Mon Sep 17 00:00:00 2001 From: lxning Date: Wed, 15 May 2024 12:59:56 -0700 Subject: [PATCH 09/17] lint --- examples/large_models/ipex_llm_int8/llm_handler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/large_models/ipex_llm_int8/llm_handler.py b/examples/large_models/ipex_llm_int8/llm_handler.py index 67172436f6..a9258bdaa5 100644 --- a/examples/large_models/ipex_llm_int8/llm_handler.py +++ b/examples/large_models/ipex_llm_int8/llm_handler.py @@ -545,3 +545,4 @@ def inference(self, input_batch): def postprocess(self, inference_output): return inference_output + From 12b75ccce8041f1d5e74cd97e110e8ba876ce073 Mon Sep 17 00:00:00 2001 From: lxning Date: Wed, 15 May 2024 13:01:39 -0700 Subject: [PATCH 10/17] lint --- examples/large_models/ipex_llm_int8/llm_handler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/large_models/ipex_llm_int8/llm_handler.py b/examples/large_models/ipex_llm_int8/llm_handler.py index a9258bdaa5..36ff76ce5e 100644 --- a/examples/large_models/ipex_llm_int8/llm_handler.py +++ b/examples/large_models/ipex_llm_int8/llm_handler.py @@ -1,7 +1,5 @@ import os import logging -from abc import ABC -from pathlib import Path import re import torch From 03f8be855aaf72d86e8f2d3be5d9291de35b2e68 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 15 May 2024 20:11:22 +0000 Subject: [PATCH 11/17] Fixing unit tests with different model that doesn't require license --- test/pytest/test_ipex_serving.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/test/pytest/test_ipex_serving.py b/test/pytest/test_ipex_serving.py index 8fd7168e26..4bf582301b 100644 --- a/test/pytest/test_ipex_serving.py +++ b/test/pytest/test_ipex_serving.py @@ -108,7 +108,7 @@ def start_torchserve(ts_config_file): ) -LLAMA_DEFAULT_CONFIG = f""" +DEFAULT_CONFIG = f""" minWorkers: 1 maxWorkers: 1 responseTimeout: 1500 @@ -116,7 +116,7 @@ def start_torchserve(ts_config_file): maxBatchDelay: 100 handler: - model_name: "meta-llama/Llama-2-7b-hf" + model_name: "baichuan-inc/Baichuan2-7B-Chat" clear_cache_dir: true quantized_model_path: "best_model.pt" example_inputs_mode: "MASK_KV_POS" @@ -135,7 +135,7 @@ def start_torchserve(ts_config_file): """ -LLAMA_CONFIG_WOQ = f""" +CONFIG_WOQ = f""" minWorkers: 1 maxWorkers: 1 responseTimeout: 1500 @@ -143,7 +143,7 @@ def start_torchserve(ts_config_file): maxBatchDelay: 100 handler: - model_name: "meta-llama/Llama-2-7b-hf" + model_name: "baichuan-inc/Baichuan2-7B-Chat" clear_cache_dir: true quantized_model_path: "best_model.pt" example_inputs_mode: "MASK_KV_POS" @@ -168,7 +168,7 @@ def start_torchserve(ts_config_file): greedy: true """ -LLAMA_CONFIG_SQ = f""" +CONFIG_SQ = f""" minWorkers: 1 maxWorkers: 1 responseTimeout: 1500 @@ -176,7 +176,7 @@ def start_torchserve(ts_config_file): maxBatchDelay: 100 handler: - model_name: "meta-llama/Llama-2-7b-hf" + model_name: "baichuan-inc/Baichuan2-7B-Chat" clear_cache_dir: true quantized_model_path: "best_model.pt" 
example_inputs_mode: "MASK_KV_POS" @@ -215,7 +215,7 @@ def test_handler_default_pytorch(work_dir, model_archiver): test_utils.torchserve_cleanup() # create_mar_file(work_dir, model_archiver, model_name, model_config_yaml_file): model_config_yaml = work_dir / "model-config.yaml" - model_config_yaml.write_text(LLAMA_DEFAULT_CONFIG) + model_config_yaml.write_text(DEFAULT_CONFIG) # Create mar file model_name = "llama2_no_ipex" @@ -233,7 +233,7 @@ def test_handler_default_pytorch(work_dir, model_archiver): # query model info model_url = f"{MANAGEMENT_API}/models/{model_name}" response = requests.get(model_url) - assert response.status_code == 200, "The Model failed the with default Pytorch" + assert response.status_code == 200, "The default PyTorch Model failed to load" # send prompts to the model model_url = f"{INFERENCE_API}/predictions/{model_name}" @@ -252,7 +252,7 @@ def test_handler_ipex_bf16(work_dir, model_archiver): test_utils.torchserve_cleanup() # create_mar_file(work_dir, model_archiver, model_name, model_config_yaml_file): model_config_yaml = work_dir / "model-config.yaml" - model_config_yaml.write_text(LLAMA_DEFAULT_CONFIG) + model_config_yaml.write_text(DEFAULT_CONFIG) # Create mar file model_name = "llama2_ipex_bf16" @@ -270,7 +270,7 @@ def test_handler_ipex_bf16(work_dir, model_archiver): # query model info model_url = f"{MANAGEMENT_API}/models/{model_name}" response = requests.get(model_url) - assert response.status_code == 200, "The Model failed the with default Pytorch" + assert response.status_code == 200, "The IPEX bFloat16 model failed to initialize" # send prompts to the model model_url = f"{INFERENCE_API}/predictions/{model_name}" @@ -289,7 +289,7 @@ def test_handler_ipex_int8_woq(work_dir, model_archiver): test_utils.torchserve_cleanup() # create_mar_file(work_dir, model_archiver, model_name, model_config_yaml_file): model_config_yaml = work_dir / "model-config.yaml" - model_config_yaml.write_text(LLAMA_CONFIG_WOQ) + model_config_yaml.write_text(CONFIG_WOQ) # Create mar file model_name = "llama2_ipex_int8_woq" @@ -307,7 +307,7 @@ def test_handler_ipex_int8_woq(work_dir, model_archiver): # query model info model_url = f"{MANAGEMENT_API}/models/{model_name}" response = requests.get(model_url) - assert response.status_code == 200, "The Model failed the with default Pytorch" + assert response.status_code == 200, "The IPEX weight-only quantization Model failed to initialize" # send prompts to the model model_url = f"{INFERENCE_API}/predictions/{model_name}" @@ -326,7 +326,7 @@ def test_handler_ipex_int8_sq(work_dir, model_archiver): test_utils.torchserve_cleanup() # create_mar_file(work_dir, model_archiver, model_name, model_config_yaml_file): model_config_yaml = work_dir / "model-config.yaml" - model_config_yaml.write_text(LLAMA_CONFIG_SQ) + model_config_yaml.write_text(CONFIG_SQ) # Create mar file model_name = "llama2_ipex_int8_sq" @@ -344,7 +344,7 @@ def test_handler_ipex_int8_sq(work_dir, model_archiver): # query model info model_url = f"{MANAGEMENT_API}/models/{model_name}" response = requests.get(model_url) - assert response.status_code == 200, "The Model failed the with default Pytorch" + assert response.status_code == 200, "The IPEX smoothquant quantized Model failed to load" # send prompts to the model model_url = f"{INFERENCE_API}/predictions/{model_name}" From 3586115d120ba242e7466adab21733662a7332c3 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Wed, 15 May 2024 20:28:36 +0000 Subject: [PATCH 12/17] Fix lint error 
--- .../large_models/ipex_llm_int8/llm_handler.py | 415 +++++++++++------- 1 file changed, 264 insertions(+), 151 deletions(-) diff --git a/examples/large_models/ipex_llm_int8/llm_handler.py b/examples/large_models/ipex_llm_int8/llm_handler.py index 36ff76ce5e..90fbc93546 100644 --- a/examples/large_models/ipex_llm_int8/llm_handler.py +++ b/examples/large_models/ipex_llm_int8/llm_handler.py @@ -1,14 +1,17 @@ -import os import logging +import os import re import torch import transformers -from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig -from transformers import T5ForConditionalGeneration - from datasets import load_dataset from torch.utils.data import DataLoader +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + T5ForConditionalGeneration, +) from ts.context import Context from ts.torch_handler.base_handler import BaseHandler @@ -21,6 +24,7 @@ if os.environ.get("TS_IPEX_ENABLE", "false") == "true": try: import intel_extension_for_pytorch as ipex + try: ipex._C.disable_jit_linear_repack() torch._C._jit_set_texpr_fuser_enabled(False) @@ -31,62 +35,95 @@ logger.info("IPEX version %s", ipex.__version__) except ImportError as error: - logger.warning("IPEX is enabled but intel-extension-for-pytorch cannot be imported. Proceeding without IPEX") + logger.warning( + "IPEX is enabled but intel-extension-for-pytorch cannot be imported. Proceeding without IPEX" + ) IPEX_ENABLE = False else: - logger.warning("IPEX is not enabled, consider enabling it for best performance on Intel hardware") + logger.warning( + "IPEX is not enabled, consider enabling it for best performance on Intel hardware" + ) -class IpexLLMHandler(BaseHandler): +class IpexLLMHandler(BaseHandler): def __init__(self): super(IpexLLMHandler, self).__init__() - + # for streaming the generated texts back to client self.output_streamer = None - def initialize(self, ctx: Context): model_name = ctx.model_yaml_config["handler"]["model_name"] # path to quantized model, if we are quantizing on the fly, we'll use this path to save the model - self.clear_cache_dir = ctx.model_yaml_config["handler"].get("clear_cache_dir", False) - self.quantized_model_path = ctx.model_yaml_config["handler"].get("quantized_model_path", "best_model.pt") - self.example_inputs_mode = ctx.model_yaml_config["handler"].get("example_inputs_mode", "MASK_KV_POS") - self.to_channels_last = ctx.model_yaml_config["handler"].get("to_channels_last", False) - + self.clear_cache_dir = ctx.model_yaml_config["handler"].get( + "clear_cache_dir", False + ) + self.quantized_model_path = ctx.model_yaml_config["handler"].get( + "quantized_model_path", "best_model.pt" + ) + self.example_inputs_mode = ctx.model_yaml_config["handler"].get( + "example_inputs_mode", "MASK_KV_POS" + ) + self.to_channels_last = ctx.model_yaml_config["handler"].get( + "to_channels_last", False + ) + # generation params self.batch_size = int(ctx.model_yaml_config["handler"].get("batch_size", "1")) - self.input_tokens = int(ctx.model_yaml_config["handler"].get("input_tokens", "1024")) - self.max_new_tokens = int(ctx.model_yaml_config["handler"].get("max_new_tokens", "128")) + self.input_tokens = int( + ctx.model_yaml_config["handler"].get("input_tokens", "1024") + ) + self.max_new_tokens = int( + ctx.model_yaml_config["handler"].get("max_new_tokens", "128") + ) # enable auto mix precision - self.auto_mixed_precision = ctx.model_yaml_config["handler"].get("auto_mixed_precision", True) - - # use int8 bf16 mix - self.quant_with_amp = 
ctx.model_yaml_config["handler"].get("quant_with_amp", True) - - # WoQ related optimization params - self.ipex_weight_only_quantization = ctx.model_yaml_config["handler"].get("ipex_weight_only_quantization", False) + self.auto_mixed_precision = ctx.model_yaml_config["handler"].get( + "auto_mixed_precision", True + ) + + # use int8 bf16 mix + self.quant_with_amp = ctx.model_yaml_config["handler"].get( + "quant_with_amp", True + ) + + # WoQ related optimization params + self.ipex_weight_only_quantization = ctx.model_yaml_config["handler"].get( + "ipex_weight_only_quantization", False + ) if self.ipex_weight_only_quantization: self.woq_dtype = ctx.model_yaml_config["handler"].get("woq_dtype", "INT8") self.lowp_mode = ctx.model_yaml_config["handler"].get("lowp_mode", "BF16") - self.act_quant_mode = ctx.model_yaml_config["handler"].get("act_quant_mode", "PER_IC_BLOCK") # This is only relevant for INT4x2 quantization - self.group_size = int(ctx.model_yaml_config["handler"].get("group_size", "-1")) + self.act_quant_mode = ctx.model_yaml_config["handler"].get( + "act_quant_mode", "PER_IC_BLOCK" + ) # This is only relevant for INT4x2 quantization + self.group_size = int( + ctx.model_yaml_config["handler"].get("group_size", "-1") + ) - # SQ related optimization params - self.ipex_smooth_quantization = ctx.model_yaml_config["handler"].get("ipex_smooth_quantization", False) + # SQ related optimization params + self.ipex_smooth_quantization = ctx.model_yaml_config["handler"].get( + "ipex_smooth_quantization", False + ) if self.ipex_smooth_quantization: - self.num_calib_iters = int(ctx.model_yaml_config["handler"].get("num_calibration_iters", 32)) + self.num_calib_iters = int( + ctx.model_yaml_config["handler"].get("num_calibration_iters", 32) + ) self.alpha = float(ctx.model_yaml_config["handler"].get("alpha", 0.9)) - + # Keeping outside because we want to use it for tracing as well - self.calib_dataset = ctx.model_yaml_config["handler"].get("calibration_dataset", "NeelNanda/pile-10k") - self.calib_split = ctx.model_yaml_config["handler"].get("calibration_split", "train") + self.calib_dataset = ctx.model_yaml_config["handler"].get( + "calibration_dataset", "NeelNanda/pile-10k" + ) + self.calib_split = ctx.model_yaml_config["handler"].get( + "calibration_split", "train" + ) - # decoding parameters + # decoding parameters self.greedy = ctx.model_yaml_config["handler"].get("greedy", False) - # amp datatype - if self.quant_with_amp or self.auto_mixed_precision: + # amp datatype + if self.quant_with_amp or self.auto_mixed_precision: self.amp_enabled = True self.amp_dtype = torch.bfloat16 else: @@ -97,19 +134,21 @@ def initialize(self, ctx: Context): self.num_beams = 1 if self.greedy else 4 # donot use min number of tokens on demo mode, only use it on benchmark mode self.generate_kwargs = dict( - do_sample=False, - temperature=0.9, - num_beams=self.num_beams, + do_sample=False, + temperature=0.9, + num_beams=self.num_beams, max_new_tokens=self.max_new_tokens, - ) - - # device + ) + + # device device = torch.device("cpu") - - # model config - config = AutoConfig.from_pretrained(model_name, torchscript=True, trust_remote_code=True) - - # set up max context + + # model config + config = AutoConfig.from_pretrained( + model_name, torchscript=True, trust_remote_code=True + ) + + # set up max context if not hasattr(config, "text_max_length"): config.text_max_length = int(self.input_tokens) + int(self.max_new_tokens) if "mpt" in model_name and not hasattr(config, "max_seq_len"): @@ -118,31 +157,48 @@ def 
initialize(self, ctx: Context): # load model and tokenizer, # We need special provision for t5 because it's seq2seq model, and can not be loaded with AutoModelForCausalLM if re.search("t5", config.architectures[0], re.IGNORECASE): - self.user_model = T5ForConditionalGeneration.from_pretrained(model_name, config=config, low_cpu_mem_usage=True, torch_dtype=torch.float) + self.user_model = T5ForConditionalGeneration.from_pretrained( + model_name, + config=config, + low_cpu_mem_usage=True, + torch_dtype=torch.float, + ) input_ids = torch.ones(32).to(torch.long).unsqueeze(0) attention_mask = torch.ones_like(input_ids) - dummy_inputs = self.user_model.prepare_inputs_for_generation(input_ids, attention_mask=attention_mask) + dummy_inputs = self.user_model.prepare_inputs_for_generation( + input_ids, attention_mask=attention_mask + ) if dummy_inputs.get("position_ids", None) is not None: self.example_inputs_mode = "MASK_KV_POS" - # we also need to update generation kwargs + # we also need to update generation kwargs self.generate_kwargs["max_length"] = self.generate_kwargs["max_new_tokens"] self.generate_kwargs.pop("max_new_tokens") else: - self.user_model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.float) + self.user_model = AutoModelForCausalLM.from_pretrained( + model_name, + config=config, + trust_remote_code=True, + low_cpu_mem_usage=True, + torch_dtype=torch.float, + ) - self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, trust_remote_code=True + ) logger.info("Data type of the model: %s", self.user_model.dtype) - + if self.to_channels_last: self.user_model = self.user_model.to(memory_format=torch.channels_last) self.user_model.eval() - # dummy past key value - self.beam_idx_tmp = torch.zeros((2048, int(self.batch_size * self.num_beams)), dtype=torch.long).contiguous() + self.beam_idx_tmp = torch.zeros( + (2048, int(self.batch_size * self.num_beams)), dtype=torch.long + ).contiguous() + def _get_target_nums(names): for n in names: if hasattr(self.user_model.config, n): @@ -167,19 +223,25 @@ def _get_target_nums(names): for i in range(n_layers) ] - logger.info(f"num_attention_heads: {n_heads}, num_hidden_layers: {n_layers}, hidden size: {hidden_size}, head_dim: {head_dim}") - + logger.info( + f"num_attention_heads: {n_heads}, num_hidden_layers: {n_layers}, hidden size: {hidden_size}, head_dim: {head_dim}" + ) + logger.info("Preparing the dataset for calibration and tracing") + class Evaluator: - def __init__(self, - user_model, - example_inputs_mode, - global_past_key_value, - dataset, tokenizer, - batch_size=1, - num_beams=1, - pad_val=1, - pad_max=512): + def __init__( + self, + user_model, + example_inputs_mode, + global_past_key_value, + dataset, + tokenizer, + batch_size=1, + num_beams=1, + pad_val=1, + pad_max=512, + ): self.user_model = user_model self.example_inputs_mode = example_inputs_mode self.global_past_key_value = global_past_key_value @@ -188,11 +250,10 @@ def __init__(self, self.batch_size = batch_size self.num_beams = num_beams - self.pad_val = pad_val - self.pad_max = pad_max + self.pad_max = pad_max self.dataset = self.dataset.map(self.tokenize_function, batched=True) - self.dataset.set_format(type="torch", columns=["input_ids"]) + self.dataset.set_format(type="torch", columns=["input_ids"]) @torch.no_grad() def tokenize_function(self, examples): @@ -204,7 +265,6 @@ def tokenize_function(self, 
examples): example = self.tokenizer(examples["code"]) return example - @torch.no_grad() def collate_batch(self, batch): position_ids_padded = [] @@ -252,16 +312,21 @@ def collate_batch(self, batch): model_kwargs = { "attention_mask": torch.vstack(attention_mask_padded), } - model_kwargs = self.user_model._prepare_encoder_decoder_kwargs_for_generation( - torch.vstack(input_ids_padded), model_kwargs, "input_ids" + model_kwargs = ( + self.user_model._prepare_encoder_decoder_kwargs_for_generation( + torch.vstack(input_ids_padded), model_kwargs, "input_ids" + ) ) - input_ids, example_inputs = self.user_model._expand_inputs_for_generation( + ( + input_ids, + example_inputs, + ) = self.user_model._expand_inputs_for_generation( input_ids=torch.vstack(input_ids_padded), expand_size=self.num_beams, is_encoder_decoder=True, **model_kwargs, ) - + # need to recompute these def _get_target_nums(names): for n in names: @@ -270,18 +335,32 @@ def _get_target_nums(names): logger.error(f"Not found target {names[0]}") exit(1) - num_heads_names = ["num_attention_heads", "n_head", "num_heads", "n_heads"] - num_layers_names = ["num_hidden_layers", "n_layer", "num_layers", "n_layers"] + num_heads_names = [ + "num_attention_heads", + "n_head", + "num_heads", + "n_heads", + ] + num_layers_names = [ + "num_hidden_layers", + "n_layer", + "num_layers", + "n_layers", + ] hidden_size_names = ["hidden_size", "n_embd"] n_heads = _get_target_nums(num_heads_names) n_layers = _get_target_nums(num_layers_names) hidden_size = _get_target_nums(hidden_size_names) head_dim = int(hidden_size / n_heads) - + # lets get the inputs - beam_idx_tmp = torch.zeros((2048, int(self.batch_size * self.num_beams)), dtype=torch.long).contiguous() + beam_idx_tmp = torch.zeros( + (2048, int(self.batch_size * self.num_beams)), dtype=torch.long + ).contiguous() input_bs = int(self.batch_size * self.num_beams) - last_hidden_state = example_inputs["encoder_outputs"]["last_hidden_state"] + last_hidden_state = example_inputs["encoder_outputs"][ + "last_hidden_state" + ] global_past_key_value = tuple( [ ( @@ -290,15 +369,25 @@ def _get_target_nums(names): torch.zeros([1, n_heads, 1, head_dim]).contiguous(), beam_idx_tmp, torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), - self.user_model.decoder.block[i].layer[1].EncDecAttention.k(last_hidden_state).view(input_bs, -1, n_heads, head_dim).transpose(0, 1), - self.user_model.decoder.block[i].layer[1].EncDecAttention.v(last_hidden_state).view(input_bs, -1, n_heads, head_dim).transpose(0, 1), + self.user_model.decoder.block[i] + .layer[1] + .EncDecAttention.k(last_hidden_state) + .view(input_bs, -1, n_heads, head_dim) + .transpose(0, 1), + self.user_model.decoder.block[i] + .layer[1] + .EncDecAttention.v(last_hidden_state) + .view(input_bs, -1, n_heads, head_dim) + .transpose(0, 1), beam_idx_tmp, ) for i in range(n_layers) ] ) - decoder_input_ids = (torch.zeros(input_bs).to(torch.long).unsqueeze(1)) + decoder_input_ids = ( + torch.zeros(input_bs).to(torch.long).unsqueeze(1) + ) model_inputs = ( decoder_input_ids, torch.vstack(attention_mask_padded), @@ -306,28 +395,32 @@ def _get_target_nums(names): (last_hidden_state,), ) else: - raise RuntimeError("Your model does not match existing example inputs used in ipex smooth quant, exiting...") - - # Some models require extra inputs - if re.search("chatglm", self.user_model.config.architectures[0], re.IGNORECASE): + raise RuntimeError( + "Your model does not match existing example inputs used in ipex smooth quant, exiting..." 
+ ) + + # Some models require extra inputs + if re.search( + "chatglm", self.user_model.config.architectures[0], re.IGNORECASE + ): extra_inputs = (torch.tensor(True),) model_inputs = model_inputs + extra_inputs return (model_inputs, last_ind) - - calib_dataset = load_dataset(self.calib_dataset, split=self.calib_split) logger.info(f"Dataset loaded: {calib_dataset}") calib_evaluator = Evaluator( self.user_model, self.example_inputs_mode, self.global_past_key_value, - calib_dataset, - self.tokenizer, - batch_size=self.batch_size, - num_beams = self.num_beams, - pad_max = int(self.input_tokens) if re.search("t5", config.architectures[0], re.IGNORECASE) else 512 + calib_dataset, + self.tokenizer, + batch_size=self.batch_size, + num_beams=self.num_beams, + pad_max=int(self.input_tokens) + if re.search("t5", config.architectures[0], re.IGNORECASE) + else 512, ) logger.info(f"Evaluator built: {calib_evaluator}") @@ -341,40 +434,45 @@ def _get_target_nums(names): if IPEX_ENABLE: """ - Ipex is enabled, we'll use - (1) weight only quantization if ipex_weight_only_quantization is enabled - (2) ipex smooth quantization if ipex_smooth_quantization is enabled - (3) ipex bfloat16 optimization if neither is quantization is enabled - (4) throws error if both 1 and 2 are enabled + Ipex is enabled, we'll use + (1) weight only quantization if ipex_weight_only_quantization is enabled + (2) ipex smooth quantization if ipex_smooth_quantization is enabled + (3) ipex bfloat16 optimization if neither is quantization is enabled + (4) throws error if both 1 and 2 are enabled """ - + def trace_and_export(model): example_inputs = self.get_example_inputs() with torch.no_grad(), torch.cpu.amp.autocast( - enabled=self.amp_enabled, - dtype=self.amp_dtype + enabled=self.amp_enabled, dtype=self.amp_dtype ): - self_jit = torch.jit.trace(model.eval(), example_inputs, strict=False, check_trace=False) + self_jit = torch.jit.trace( + model.eval(), example_inputs, strict=False, check_trace=False + ) self_jit = torch.jit.freeze(self_jit.eval()) self_jit.save(self.quantized_model_path) - + if self.ipex_smooth_quantization and self.ipex_weight_only_quantization: logger.error("Can't enable both SQ and WoQ, enable only one of them") exit(1) - - # Clear the cache dir if needed + + # Clear the cache dir if needed if self.clear_cache_dir and os.path.exists(self.quantized_model_path): os.remove(self.quantized_model_path) if os.path.exists(self.quantized_model_path): - # this skips all the optimizations and goes to end where we load the model - logger.info("A previously quantized model is loaded, if you want to re-quantize the model, enable clear_cache_dir on model config file") + # this skips all the optimizations and goes to end where we load the model + logger.info( + "A previously quantized model is loaded, if you want to re-quantize the model, enable clear_cache_dir on model config file" + ) - # lets implement the WOQ + # lets implement the WOQ elif self.ipex_weight_only_quantization: - weight_dtype = torch.quint4x2 if self.woq_dtype == "INT4" else torch.qint8 + weight_dtype = ( + torch.quint4x2 if self.woq_dtype == "INT4" else torch.qint8 + ) if self.lowp_mode == "INT8": lowp_mode = ipex.quantization.WoqLowpMode.INT8 @@ -400,8 +498,8 @@ def trace_and_export(model): act_quant_mode=act_quant_mode_dict[self.act_quant_mode], group_size=self.group_size, ) - - # low precision checkpoint can be loaded, but we're considering there isn't any + + # low precision checkpoint can be loaded, but we're considering there isn't any 
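
The weight-only-quantization branch in this hunk continues just below with the `low_precision_checkpoint` handling and the `ipex.llm.optimize` call. Condensed into a standalone sketch (placeholder model id; qconfig values matching the `woq_dtype: "INT8"`, `lowp_mode: "BF16"`, `group_size: -1` settings used elsewhere in this patch; the qconfig helper and `ipex.llm.optimize` arguments follow the calls the handler itself makes; the trace/freeze/save step is left as a comment), the whole path is roughly:

```python
# Condensed sketch of the weight-only quantization (WoQ) path, under the
# assumptions stated above; not a drop-in replacement for the handler code.
import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder model, not taken from the patch

config = AutoConfig.from_pretrained(model_id, torchscript=True, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, config=config, low_cpu_mem_usage=True, torch_dtype=torch.float
).eval()

# INT8 weights computed through a bfloat16 low-precision path
# (torch.quint4x2 would be used for INT4 weights instead)
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
    weight_dtype=torch.qint8,
    lowp_mode=ipex.quantization.WoqLowpMode.BF16,
    group_size=-1,  # per-channel weight scales
)

model = ipex.llm.optimize(
    model,
    dtype=torch.bfloat16,           # quant_with_amp=True -> INT8/BF16 mix
    quantization_config=qconfig,
    low_precision_checkpoint=None,  # quantize on the fly; no pre-quantized weights
    inplace=True,
    deployment_mode=False,          # tracing is done separately, as in trace_and_export
)

# The handler then runs torch.jit.trace / torch.jit.freeze on example inputs drawn
# from its calibration collator and saves the frozen graph to quantized_model_path.
```
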
low_precision_checkpoint = None self.user_model = ipex.llm.optimize( self.user_model.eval(), @@ -411,16 +509,23 @@ def trace_and_export(model): low_precision_checkpoint=low_precision_checkpoint, deployment_mode=False, ) - logger.info("The model conversion completed, now tracing the quantized model") - + logger.info( + "The model conversion completed, now tracing the quantized model" + ) + trace_and_export(self.user_model) - logger.info("The IPEX Weight only quantization has been completed successfully") + logger.info( + "The IPEX Weight only quantization has been completed successfully" + ) elif self.ipex_smooth_quantization: - from intel_extension_for_pytorch.quantization import prepare, convert + from intel_extension_for_pytorch.quantization import convert, prepare + example_inputs = self.get_example_inputs() - qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=self.alpha) + qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping( + alpha=self.alpha + ) user_model = ipex.llm.optimize( self.user_model.eval(), dtype=self.amp_dtype, @@ -429,10 +534,14 @@ def trace_and_export(model): deployment_mode=False, ) - prepared_model = prepare(user_model.eval(), qconfig, example_inputs=example_inputs, inplace=True) + prepared_model = prepare( + user_model.eval(), + qconfig, + example_inputs=example_inputs, + inplace=True, + ) logger.info("Model prepared for quantization, observers inserted") - for i, (model_inputs, last_ind) in enumerate(self.calib_dataloader): if i == self.num_calib_iters: break @@ -443,7 +552,7 @@ def trace_and_export(model): logger.info("Model converted successfully, exporting the trace") trace_and_export(converted_model) - + logger.info("IPEX Smooth Quantization has completed successfully") else: @@ -451,14 +560,14 @@ def trace_and_export(model): example_inputs = self.get_example_inputs() self.user_model = ipex.llm.optimize( self.user_model.eval(), - dtype = self.amp_dtype, + dtype=self.amp_dtype, inplace=True, deployment_mode=False, ) - + trace_and_export(self.user_model) logger.info("IPEX bf16 optimization is applied successfully") - + logger.info("Loading the IPEX quantized model") try: self_jit = torch.jit.load(self.quantized_model_path) @@ -466,25 +575,28 @@ def trace_and_export(model): except Exception as e: logger.error("Error: loading the quantized model failed.", e) exit(0) - + setattr(self.user_model, "trace_graph", self_jit) - logger.info(f"Successfully loaded the Model {model_name} with Intel® Extension for PyTorch*") + logger.info( + f"Successfully loaded the Model {model_name} with Intel® Extension for PyTorch*" + ) else: # No optimization is applied, but if amx is enabled, it'll be applied during generation routine - logger.warning("No IPEX optimization is applied, Pytorch default autocast will be applied if enabled") - - + logger.warning( + "No IPEX optimization is applied, Pytorch default autocast will be applied if enabled" + ) + # set PAD token if self.tokenizer.pad_token is None: - if re.search("qwen", self.user_model.config.architectures[0], re.IGNORECASE): - self.tokenizer.pad_token = '<|endoftext|>' + if re.search( + "qwen", self.user_model.config.architectures[0], re.IGNORECASE + ): + self.tokenizer.pad_token = "<|endoftext|>" else: - self.tokenizer.pad_token=self.tokenizer.eos_token + self.tokenizer.pad_token = self.tokenizer.eos_token - - - # we are going to use data collator we built to generate example input + # we are going to use data collator we built to generate example input def get_example_inputs(self): (model_inputs, 
last_ind) = next(iter(self.calib_dataloader)) return model_inputs @@ -498,20 +610,18 @@ def preprocess(self, requests): input_text = data.get("body") if isinstance(input_text, (bytes, bytearray)): input_text = input_text.decode("utf-8") - + with torch.inference_mode(), torch.no_grad(), torch.autocast( - device_type="cpu", - enabled=self.amp_enabled, - dtype=self.amp_dtype + device_type="cpu", enabled=self.amp_enabled, dtype=self.amp_dtype ): inputs = self.tokenizer( - input_text, - pad_to_max_length=True, - add_special_tokens=True, - return_tensors="pt", - #max_length=int(self.max_length), - ) - + input_text, + pad_to_max_length=True, + add_special_tokens=True, + return_tensors="pt", + # max_length=int(self.max_length), + ) + input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] # making a batch out of the recieved requests @@ -521,26 +631,29 @@ def preprocess(self, requests): attention_mask_batch = attention_mask else: input_ids_batch = torch.cat((input_ids_batch, input_ids), 0) - attention_mask_batch = torch.cat((attention_mask_batch, attention_mask), 0) + attention_mask_batch = torch.cat( + (attention_mask_batch, attention_mask), 0 + ) return (input_ids_batch, attention_mask_batch) - + def inference(self, input_batch): input_ids_batch, attention_mask_batch = input_batch inferences = [] # total_list = [] - + with torch.inference_mode(), torch.no_grad(), torch.autocast( - device_type="cpu", - enabled=self.amp_enabled, - dtype=self.amp_dtype + device_type="cpu", enabled=self.amp_enabled, dtype=self.amp_dtype ): - outputs = self.user_model.generate(input_ids_batch, attention_mask=attention_mask_batch, **self.generate_kwargs) + outputs = self.user_model.generate( + input_ids_batch, + attention_mask=attention_mask_batch, + **self.generate_kwargs, + ) inferences = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) - #for i, x in enumerate(outputs): + # for i, x in enumerate(outputs): # inferences.append(self.tokenizer.decode(outputs[i], skip_special_tokens=True)) return inferences def postprocess(self, inference_output): return inference_output - From 02f885f97b2c221cceda67f8a8a680b67863dc4c Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Wed, 15 May 2024 20:32:17 +0000 Subject: [PATCH 13/17] Fix lint error in test --- test/pytest/test_ipex_serving.py | 177 ++++++++++++++++--------------- 1 file changed, 92 insertions(+), 85 deletions(-) diff --git a/test/pytest/test_ipex_serving.py b/test/pytest/test_ipex_serving.py index 4bf582301b..d29505e8b5 100644 --- a/test/pytest/test_ipex_serving.py +++ b/test/pytest/test_ipex_serving.py @@ -1,40 +1,34 @@ -import os -import sys import json -from pathlib import Path +import logging +import os import shutil import subprocess -import yaml +import sys +from pathlib import Path +from unittest.mock import patch import pytest import requests import test_utils -import torch -#from test_handler import run_inference_using_url_with_data - -from unittest.mock import patch -from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext - -from string import Template -import logging - from model_archiver.model_archiver_config import ModelArchiverConfig -from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext - REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") snapshot_file_ipex = os.path.join(REPO_ROOT, "test/config_ipex.properties") default_ts_config = os.path.join(REPO_ROOT, "test/config_ts.properties") -prompt_file = 
os.path.join(REPO_ROOT, "examples/large_models/ipex_llm_int8/sample_text_0.txt") +prompt_file = os.path.join( + REPO_ROOT, "examples/large_models/ipex_llm_int8/sample_text_0.txt" +) -#CURR_FILE_PATH = Path(__file__).parent +# CURR_FILE_PATH = Path(__file__).parent HANDLER_PATH = os.path.join(REPO_ROOT, "examples/large_models/ipex_llm_int8/") sys.path.append(HANDLER_PATH) logger = logging.Logger(__name__) -PROMPTS = ["The capital of France is ",] +PROMPTS = [ + "The capital of France is ", +] MANAGEMENT_API = "http://localhost:8081" INFERENCE_API = "http://localhost:8080" @@ -55,6 +49,7 @@ ipex_xeon_run_available = xeon_run_cpu_available and ipex_available + @pytest.fixture(scope="module") def model_name(): yield "llama2" @@ -64,13 +59,14 @@ def model_name(): def work_dir(tmp_path_factory, model_name): return Path(tmp_path_factory.mktemp(model_name)) + # @pytest.fixture(scope="module", name="mar_file_path") def create_mar_file(work_dir, model_archiver, model_name, model_config_yaml_file): mar_file_path = work_dir.joinpath(model_name + ".mar") handler_file = os.path.join(HANDLER_PATH, "llm_handler.py") - assert(Path(handler_file).exists()) - + assert Path(handler_file).exists() + config = ModelArchiverConfig( model_name=model_name, version="1.0", @@ -93,18 +89,17 @@ def create_mar_file(work_dir, model_archiver, model_name, model_config_yaml_file return mar_file_path.as_posix() + def run_inference_with_prompt(prompt_file, model_name): model_url = f"{INFERENCE_API}/predictions/{model_name}" response = run_inference_using_url_with_data(model_url, prompt_file) return response + def start_torchserve(ts_config_file): - - # start the torchserve + # start the torchserve test_utils.start_torchserve( - model_store=test_utils.MODEL_STORE, - snapshot_file=ts_config_file, - gen_mar=False + model_store=test_utils.MODEL_STORE, snapshot_file=ts_config_file, gen_mar=False ) @@ -114,22 +109,22 @@ def start_torchserve(ts_config_file): responseTimeout: 1500 batchSize: 4 maxBatchDelay: 100 - + handler: model_name: "baichuan-inc/Baichuan2-7B-Chat" clear_cache_dir: true quantized_model_path: "best_model.pt" example_inputs_mode: "MASK_KV_POS" to_channels_last: false - + # generation params batch_size: 1 # this batch size is mostly used for calibration, you can leave it as 1 input_tokens: 1024 max_new_tokens: 128 - + # Use INT8 bf16 mix quant_with_amp: true - + # decoding technique greedy: true @@ -141,29 +136,29 @@ def start_torchserve(ts_config_file): responseTimeout: 1500 batchSize: 4 maxBatchDelay: 100 - + handler: model_name: "baichuan-inc/Baichuan2-7B-Chat" clear_cache_dir: true quantized_model_path: "best_model.pt" example_inputs_mode: "MASK_KV_POS" to_channels_last: false - + # generation params batch_size: 1 input_tokens: 1024 max_new_tokens: 128 - + # Use INT8 bf16 mix quant_with_amp: true - + # Woq params ipex_weight_only_quantization: true woq_dtype: "INT8" lowp_mode: "BF16" act_quant_mode: "PER_IC_BLOCK" group_size: -1 - + # decoding technique greedy: true """ @@ -174,42 +169,34 @@ def start_torchserve(ts_config_file): responseTimeout: 1500 batchSize: 4 maxBatchDelay: 100 - + handler: model_name: "baichuan-inc/Baichuan2-7B-Chat" clear_cache_dir: true quantized_model_path: "best_model.pt" example_inputs_mode: "MASK_KV_POS" to_channels_last: false - + # generation params batch_size: 1 input_tokens: 1024 max_new_tokens: 128 - + # use bf16-int8 mix quant_with_amp: true - + # SQ quantization params ipex_smooth_quantization: true calibration_dataset: "NeelNanda/pile-10k" calibration_split: "train" 
num_calibration_iters: 32 alpha: 0.9 - + # decoding technique greedy: true """ -""" -outline of the tests: - 1. edit the config - 2. create mar file - 3. start torchserve - 4. test connection - 5. test response correctness -""" def test_handler_default_pytorch(work_dir, model_archiver): test_utils.torchserve_cleanup() @@ -217,9 +204,11 @@ def test_handler_default_pytorch(work_dir, model_archiver): model_config_yaml = work_dir / "model-config.yaml" model_config_yaml.write_text(DEFAULT_CONFIG) - # Create mar file + # Create mar file model_name = "llama2_no_ipex" - mar_file_path = create_mar_file(work_dir, model_archiver, model_name, model_config_yaml) + mar_file_path = create_mar_file( + work_dir, model_archiver, model_name, model_config_yaml + ) os.makedirs(os.path.dirname(test_utils.MODEL_STORE), exist_ok=True) shutil.move(mar_file_path, test_utils.MODEL_STORE) @@ -229,21 +218,23 @@ def test_handler_default_pytorch(work_dir, model_archiver): # load the model model_url = f"{MANAGEMENT_API}/models?url={model_name}.mar" requests.post(model_url) - + # query model info model_url = f"{MANAGEMENT_API}/models/{model_name}" response = requests.get(model_url) assert response.status_code == 200, "The default PyTorch Model failed to load" - + # send prompts to the model model_url = f"{INFERENCE_API}/predictions/{model_name}" - response = requests.post(url=model_url, - data=json.dumps(PROMPTS[0],), - ) - + response = requests.post( + url=model_url, + data=json.dumps( + PROMPTS[0], + ), + ) + assert response.status_code == 200, "The model failed to generate text from prompt!" assert "Paris" in response.text, "The response doesn't seem to be correct!" - test_utils.torchserve_cleanup() @@ -254,9 +245,11 @@ def test_handler_ipex_bf16(work_dir, model_archiver): model_config_yaml = work_dir / "model-config.yaml" model_config_yaml.write_text(DEFAULT_CONFIG) - # Create mar file + # Create mar file model_name = "llama2_ipex_bf16" - mar_file_path = create_mar_file(work_dir, model_archiver, model_name, model_config_yaml) + mar_file_path = create_mar_file( + work_dir, model_archiver, model_name, model_config_yaml + ) os.makedirs(os.path.dirname(test_utils.MODEL_STORE), exist_ok=True) shutil.move(mar_file_path, test_utils.MODEL_STORE) @@ -266,21 +259,23 @@ def test_handler_ipex_bf16(work_dir, model_archiver): # load the model model_url = f"{MANAGEMENT_API}/models?url={model_name}.mar" requests.post(model_url) - + # query model info model_url = f"{MANAGEMENT_API}/models/{model_name}" response = requests.get(model_url) assert response.status_code == 200, "The IPEX bFloat16 model failed to initialize" - + # send prompts to the model model_url = f"{INFERENCE_API}/predictions/{model_name}" - response = requests.post(url=model_url, - data=json.dumps(PROMPTS[0],), - ) - + response = requests.post( + url=model_url, + data=json.dumps( + PROMPTS[0], + ), + ) + assert response.status_code == 200, "The model failed to generate text from prompt!" assert "Paris" in response.text, "The response doesn't seem to be correct!" 
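
The smooth-quant test configuration above points the handler at the `NeelNanda/pile-10k` train split for calibration. Inside `llm_handler.py`, that dataset is tokenized and padded by the `Evaluator` class before the calibration loop runs. Below is a minimal, runnable sketch of just that data plumbing; the `gpt2` tokenizer is only a stand-in so the snippet needs no gated weights, and the real collator additionally builds position ids and dummy past-key-value tensors, which are omitted here:

```python
# Minimal sketch of the calibration-data plumbing (tokenize -> pad -> batch).
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer, not from the patch
dataset = load_dataset("NeelNanda/pile-10k", split="train")  # calibration_dataset / split

def tokenize_function(examples):
    # pile-10k exposes a "text" column; the handler's tokenize_function
    # also handles other column names (e.g. "code")
    return tokenizer(examples["text"])

dataset = dataset.map(tokenize_function, batched=True)
dataset.set_format(type="torch", columns=["input_ids"])

def collate_batch(batch, pad_max=512, pad_val=1):
    # pad every sample to pad_max so the traced graph always sees the same shape
    ids_list, mask_list = [], []
    for sample in batch:
        ids = sample["input_ids"][:pad_max]
        pad_len = pad_max - ids.shape[0]
        ids_list.append(torch.nn.functional.pad(ids, (0, pad_len), value=pad_val))
        mask_list.append(torch.nn.functional.pad(torch.ones_like(ids), (0, pad_len), value=0))
    return torch.vstack(ids_list), torch.vstack(mask_list)

calib_dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_batch)
input_ids, attention_mask = next(iter(calib_dataloader))
print(input_ids.shape, attention_mask.shape)  # torch.Size([1, 512]) for both
```
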
- test_utils.torchserve_cleanup() @@ -291,9 +286,11 @@ def test_handler_ipex_int8_woq(work_dir, model_archiver): model_config_yaml = work_dir / "model-config.yaml" model_config_yaml.write_text(CONFIG_WOQ) - # Create mar file + # Create mar file model_name = "llama2_ipex_int8_woq" - mar_file_path = create_mar_file(work_dir, model_archiver, model_name, model_config_yaml) + mar_file_path = create_mar_file( + work_dir, model_archiver, model_name, model_config_yaml + ) os.makedirs(os.path.dirname(test_utils.MODEL_STORE), exist_ok=True) shutil.move(mar_file_path, test_utils.MODEL_STORE) @@ -303,21 +300,25 @@ def test_handler_ipex_int8_woq(work_dir, model_archiver): # load the model model_url = f"{MANAGEMENT_API}/models?url={model_name}.mar" requests.post(model_url) - + # query model info model_url = f"{MANAGEMENT_API}/models/{model_name}" response = requests.get(model_url) - assert response.status_code == 200, "The IPEX weight-only quantization Model failed to initialize" - + assert ( + response.status_code == 200 + ), "The IPEX weight-only quantization Model failed to initialize" + # send prompts to the model model_url = f"{INFERENCE_API}/predictions/{model_name}" - response = requests.post(url=model_url, - data=json.dumps(PROMPTS[0],), - ) - + response = requests.post( + url=model_url, + data=json.dumps( + PROMPTS[0], + ), + ) + assert response.status_code == 200, "The model failed to generate text from prompt!" assert "Paris" in response.text, "The response doesn't seem to be correct!" - test_utils.torchserve_cleanup() @@ -328,9 +329,11 @@ def test_handler_ipex_int8_sq(work_dir, model_archiver): model_config_yaml = work_dir / "model-config.yaml" model_config_yaml.write_text(CONFIG_SQ) - # Create mar file + # Create mar file model_name = "llama2_ipex_int8_sq" - mar_file_path = create_mar_file(work_dir, model_archiver, model_name, model_config_yaml) + mar_file_path = create_mar_file( + work_dir, model_archiver, model_name, model_config_yaml + ) os.makedirs(os.path.dirname(test_utils.MODEL_STORE), exist_ok=True) shutil.move(mar_file_path, test_utils.MODEL_STORE) @@ -340,20 +343,24 @@ def test_handler_ipex_int8_sq(work_dir, model_archiver): # load the model model_url = f"{MANAGEMENT_API}/models?url={model_name}.mar" requests.post(model_url) - + # query model info model_url = f"{MANAGEMENT_API}/models/{model_name}" response = requests.get(model_url) - assert response.status_code == 200, "The IPEX smoothquant quantized Model failed to load" - + assert ( + response.status_code == 200 + ), "The IPEX smoothquant quantized Model failed to load" + # send prompts to the model model_url = f"{INFERENCE_API}/predictions/{model_name}" - response = requests.post(url=model_url, - data=json.dumps(PROMPTS[0],), - ) - + response = requests.post( + url=model_url, + data=json.dumps( + PROMPTS[0], + ), + ) + assert response.status_code == 200, "The model failed to generate text from prompt!" assert "Paris" in response.text, "The response doesn't seem to be correct!" 
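
All four tests above repeat the same round-trip against a locally running TorchServe: register the archive through the management API, confirm it loaded, then post a prompt to the inference API. Outside pytest, roughly the same client code looks like the sketch below; the ports are the TorchServe defaults the tests use, the model name and prompt are placeholders, and worker counts and timeouts come from whatever `config.properties` the server was started with:

```python
# Standalone client mirroring the register -> describe -> predict round-trip above.
import json

import requests

MANAGEMENT_API = "http://localhost:8081"  # TorchServe defaults, as in the tests
INFERENCE_API = "http://localhost:8080"
model_name = "llama2_ipex_int8_woq"  # placeholder: any archive present in the model store

# Register the .mar; for the quantized configs this also triggers on-the-fly
# quantization and tracing inside the handler, so it can take a while.
requests.post(f"{MANAGEMENT_API}/models?url={model_name}.mar")

# Confirm the model is registered and describable.
status = requests.get(f"{MANAGEMENT_API}/models/{model_name}")
assert status.status_code == 200, status.text

# Send a prompt and print the generated continuation.
response = requests.post(
    f"{INFERENCE_API}/predictions/{model_name}",
    data=json.dumps("The capital of France is "),
)
print(response.text)
```
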
- test_utils.torchserve_cleanup() From 07d8ca9988140a4dfc643dcd2afd64e543f5e2e9 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 15 May 2024 21:38:40 +0000 Subject: [PATCH 14/17] Adding requirements.txt --- examples/large_models/ipex_llm_int8/requirements.txt | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 examples/large_models/ipex_llm_int8/requirements.txt diff --git a/examples/large_models/ipex_llm_int8/requirements.txt b/examples/large_models/ipex_llm_int8/requirements.txt new file mode 100644 index 0000000000..5a862f37af --- /dev/null +++ b/examples/large_models/ipex_llm_int8/requirements.txt @@ -0,0 +1,5 @@ +datasets==2.18.0 +intel_extension_for_pytorch==2.4.0+gite2d4be3 +torch==2.4.0.dev20240328+cpu +torchserve==0.10.0 +transformers==4.38.2 From b14d40278e39886717511e47853a259a9338fed2 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 15 May 2024 22:24:49 +0000 Subject: [PATCH 15/17] adding datasets to the requirements --- requirements/developer.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/developer.txt b/requirements/developer.txt index 5387bf1de0..efdd44119b 100644 --- a/requirements/developer.txt +++ b/requirements/developer.txt @@ -20,3 +20,4 @@ onnxruntime==1.17.1 googleapis-common-protos onnx==1.16.0 orjson +datasets From 3cebfebe5a5d1326b0f33bca8b7df0b8648df892 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 15 May 2024 23:20:58 +0000 Subject: [PATCH 16/17] upgrading the ipex version to 2.3.0 to match that of pytorch --- requirements/developer.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/developer.txt b/requirements/developer.txt index efdd44119b..4aa2450d1b 100644 --- a/requirements/developer.txt +++ b/requirements/developer.txt @@ -15,7 +15,7 @@ pre-commit==3.3.2 twine==4.0.2 mypy==1.3.0 torchpippy==0.1.1 -intel_extension_for_pytorch==2.2.0; sys_platform != 'win32' and sys_platform != 'darwin' and platform_machine != 'aarch64' +intel_extension_for_pytorch==2.3.0; sys_platform != 'win32' and sys_platform != 'darwin' and platform_machine != 'aarch64' onnxruntime==1.17.1 googleapis-common-protos onnx==1.16.0 From c6ad7a6be68e67d4bc72b40a28b179df21ad1318 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Thu, 16 May 2024 04:05:26 +0000 Subject: [PATCH 17/17] Skipping ipex llm tests if accelerate is not present --- .../ipex_llm_int8/requirements.txt | 1 + test/pytest/test_ipex_serving.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/examples/large_models/ipex_llm_int8/requirements.txt b/examples/large_models/ipex_llm_int8/requirements.txt index 5a862f37af..864af8dc8d 100644 --- a/examples/large_models/ipex_llm_int8/requirements.txt +++ b/examples/large_models/ipex_llm_int8/requirements.txt @@ -3,3 +3,4 @@ intel_extension_for_pytorch==2.4.0+gite2d4be3 torch==2.4.0.dev20240328+cpu torchserve==0.10.0 transformers==4.38.2 +accelerate diff --git a/test/pytest/test_ipex_serving.py b/test/pytest/test_ipex_serving.py index d29505e8b5..405e05c614 100644 --- a/test/pytest/test_ipex_serving.py +++ b/test/pytest/test_ipex_serving.py @@ -12,6 +12,12 @@ import test_utils from model_archiver.model_archiver_config import ModelArchiverConfig +ACCELERATE_UNAVAILABLE = False +try: + import accelerate # nopycln: import +except ImportError: + ACCELERATE_UNAVAILABLE = True + REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") snapshot_file_ipex = os.path.join(REPO_ROOT, "test/config_ipex.properties") default_ts_config = 
os.path.join(REPO_ROOT, "test/config_ts.properties") @@ -198,6 +204,9 @@ def start_torchserve(ts_config_file): """ +@pytest.mark.skipif( + ACCELERATE_UNAVAILABLE, reason="HF accelerate library not available" +) def test_handler_default_pytorch(work_dir, model_archiver): test_utils.torchserve_cleanup() # create_mar_file(work_dir, model_archiver, model_name, model_config_yaml_file): @@ -239,6 +248,9 @@ def test_handler_default_pytorch(work_dir, model_archiver): test_utils.torchserve_cleanup() +@pytest.mark.skipif( + ACCELERATE_UNAVAILABLE, reason="HF accelerate library not available" +) def test_handler_ipex_bf16(work_dir, model_archiver): test_utils.torchserve_cleanup() # create_mar_file(work_dir, model_archiver, model_name, model_config_yaml_file): @@ -280,6 +292,9 @@ def test_handler_ipex_bf16(work_dir, model_archiver): test_utils.torchserve_cleanup() +@pytest.mark.skipif( + ACCELERATE_UNAVAILABLE, reason="HF accelerate library not available" +) def test_handler_ipex_int8_woq(work_dir, model_archiver): test_utils.torchserve_cleanup() # create_mar_file(work_dir, model_archiver, model_name, model_config_yaml_file): @@ -323,6 +338,9 @@ def test_handler_ipex_int8_woq(work_dir, model_archiver): test_utils.torchserve_cleanup() +@pytest.mark.skipif( + ACCELERATE_UNAVAILABLE, reason="HF accelerate library not available" +) def test_handler_ipex_int8_sq(work_dir, model_archiver): test_utils.torchserve_cleanup() # create_mar_file(work_dir, model_archiver, model_name, model_config_yaml_file):