From 91511d2f22c8ccca6746f5c116f5e73bd8790d46 Mon Sep 17 00:00:00 2001
From: liuzhenwei <109187816+zhenwei-intel@users.noreply.github.com>
Date: Wed, 15 Nov 2023 10:56:07 +0800
Subject: [PATCH] [LLM Runtime] Refine Python API (#665)

---
 .../llm/runtime/graph/README.md               |  82 +++++++++++---
 .../llm/runtime/graph/__init__.py             | 101 +++++++++++-------
 .../runtime/graph/application/main_pybind.cpp |  58 +++++++---
 .../transformers/modeling/modeling_auto.py    |   2 +
 .../transformers/utils/quantization_config.py |   4 +
 tests/test_llm_runtime.py                     |  12 +--
 6 files changed, 185 insertions(+), 74 deletions(-)

diff --git a/intel_extension_for_transformers/llm/runtime/graph/README.md b/intel_extension_for_transformers/llm/runtime/graph/README.md
index a98c35aed69..22a3ac5009c 100644
--- a/intel_extension_for_transformers/llm/runtime/graph/README.md
+++ b/intel_extension_for_transformers/llm/runtime/graph/README.md
@@ -98,30 +98,43 @@ outputs = model.generate(inputs, streamer=streamer, max_new_tokens=300, ctx_size
 
 https://github.com/intel/intel-extension-for-transformers/assets/109187816/1698dcda-c9ec-4f44-b159-f4e9d67ab15b
 
+Argument description of WeightOnlyQuantConfig:
+| Argument      | Type   | Description                                                                    |
+| ------------- | ------ | ------------------------------------------------------------------------------ |
+| compute_dtype | String | Data type of GEMM computation: int8/bf16/fp32 (default: int8)                  |
+| weight_dtype  | String | Data type of quantized weight: int4/int8 (default: int4)                       |
+| alg           | String | Quantization algorithm: sym/asym (default: sym)                                |
+| group_size    | Int    | Group size (default: 32)                                                       |
+| scale_dtype   | String | Data type of scales: fp32/bf16 (default: fp32)                                 |
+| use_ggml      | Bool   | Enable ggml for quantization and inference (default: False)                    |
+| not_quant     | Bool   | Skip quantization and run the converted FP32 model (default: False)            |
+| use_cache     | Bool   | Reuse the locally converted/quantized model file if it exists (default: False) |
+
 Argument description of generate function:
 | Argument           | Type       | Description                                                               |
 | ------------------ | ---------- | ------------------------------------------------------------------------- |
 | inputs             | Lists[Int] | Input ids after tokenizer                                                 |
-| streamer           | Class      | Streamer object that will be used to stream the generated sequences. (default: None) |
 | interactive        | Bool       | Interactive mode, use history commands when True (default: False)         |
+| n_keep             | Int        | Number of tokens to keep from the initial prompt (default: 0, -1 = all)   |
+| n_discard          | Int        | Number of tokens to discard when the context is full (default: -1, -1 = discard half of the tokens) |
+| shift_roped_k      | Bool       | Use a ring buffer for the KV cache to avoid re-computation after reaching ctx_size (default: False) |
 | ignore_prompt      | Bool       | Generate outputs w/o prompt (default: False)                               |
-| max_new_tokens     | Int        | Number of tokens to predict (default: -1, -1 = infinity)                   |
 | batch_size         | Int        | Batch size for prompt processing (default: 512)                            |
 | ctx_size           | Int        | Size of the prompt context (default: 512)                                  |
 | seed               | Int        | RNG seed (default: -1, use random seed for < 0)                            |
 | threads            | Int        | Number of threads to use during computation (default: 8)                   |
-| repetition_penalty | Float      | Penalize repeat sequence of tokens (default: 1.1, 1.0 = disabled)          |
-| num_beams          | Int        | Number of beams for beam_search (default: 1)                               |
-| do_sample          | Int        | Whether or not to use sampling ; use greedy decoding otherwise. (default: False) |
-| top_k              | Int        | Top-k sampling (default: 40, 0 = disabled)                                 |
-| top_p              | Int        | Top-p sampling (default: 0.95, 1.0 = disabled)                             |
-| temperature        | Float      | Temperature (default: 0.8)                                                 |
-| min_new_tokens     | Int        | The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt. |
-| length_penalty     | Float      | Exponential penalty to the length that is used with beam-based generation. |
-| early_stopping     | Bool       | Controls the stopping condition for beam-based methods, like beam-search.  |
-| n_keep             | Int        | Number of tokens to keep from the initial prompt (default: 0, -1 = all)    |
-| n_discard          | Int        | Number of tokens will be discarded (default: -1, -1 = half of tokens will be discarded) |
-| shift_roped_k      | Bool       | Use ring-buffer and thus do not re-computing after reaching ctx_size (default: False) |
+| repetition_penalty | Float      | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
+| num_beams          | Int        | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
+| do_sample          | Bool       | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
+| top_k              | Int        | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
+| top_p              | Float      | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
+| temperature        | Float      | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
+| min_new_tokens     | Int        | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
+| length_penalty     | Float      | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
+| early_stopping     | Bool       | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
+| max_new_tokens     | Int        | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
+| streamer           | Class      | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
+| stopping_criteria  | Class      | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
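
To make the two argument tables above concrete, here is a minimal sketch that strings them together. The model id, prompt, and argument values are placeholders chosen for illustration only, not recommendations from this patch:

```python
from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

model_name = "EleutherAI/gpt-j-6b"  # placeholder: any supported HF model id or a local path
woq_config = WeightOnlyQuantConfig(
    compute_dtype="int8",  # GEMM computation dtype: int8/bf16/fp32
    weight_dtype="int4",   # quantized weight dtype: int4/int8
    group_size=32,
    use_cache=True,        # reuse a previously converted/quantized binary if one exists
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
streamer = TextStreamer(tokenizer)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)

inputs = tokenizer("Once upon a time", return_tensors="pt").input_ids
# generate() accepts the runtime arguments from the table above alongside the usual sampling knobs.
outputs = model.generate(inputs, streamer=streamer, max_new_tokens=300, ctx_size=512,
                         do_sample=True, temperature=0.8)
```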
 
 ### 3. Multi-Round Chat
 
@@ -130,7 +143,8 @@ Chat with LLaMA2:
 from transformers import AutoTokenizer, TextStreamer
 from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
 
-model_name = "meta-llama/Llama-2-7b-chat-hf"  # or local path to model
+# Please change this to a local model path; LLaMA2 does not currently support online conversion.
+model_name = "meta-llama/Llama-2-7b-chat-hf"
 woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
 tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 streamer = TextStreamer(tokenizer)
@@ -316,3 +330,41 @@ We support tensor parallelism strategy for distributed inference/training on mul
 ### 4.
Contribution You can consider adding your own models via [graph developer document](./developer_document.md). + +### 5. Custom Stopping Criteria + +You can customize the stopping criteria according to your own needs by processing the input_ids to determine if text generation needs to be stopped. +Here is a simple example, which requires a minimum generation length of 80 tokens. Once the `min_length` is met, encountering a terminator `eos_token_id` will end the generation. + +```python +import torch +from typing import List +from transformers import StoppingCriteria, StoppingCriteriaList + +class StopOnTokens(StoppingCriteria): + def __init__(self, min_length: int, start_length: int, stop_token_id: List[int]): + self.min_length = min_length + self.start_length = start_length + self.stop_token_id = stop_token_id + + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs + ) -> bool: + if input_ids.shape[-1] - self.start_length > self.min_length: + for stop_id in self.stop_token_id: + if input_ids[0][input_ids.shape[-1] - 1] == stop_id: + return True + return False + +stopping_criteria = StoppingCriteriaList( + [ + StopOnTokens( + min_length=80, + start_length=inputs.shape[1], + stop_token_id=[tokenizer.eos_token_id], + ) + ] +) + +outputs = model.generate(inputs, streamer=streamer, stopping_criteria=stopping_criteria) +``` diff --git a/intel_extension_for_transformers/llm/runtime/graph/__init__.py b/intel_extension_for_transformers/llm/runtime/graph/__init__.py index 0265809bde9..b9f2e30b112 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/__init__.py +++ b/intel_extension_for_transformers/llm/runtime/graph/__init__.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from transformers import AutoConfig +from transformers import AutoConfig, AutoTokenizer from intel_extension_for_transformers.llm.runtime.graph.scripts.convert import convert_model import torch model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder"} @@ -61,44 +61,62 @@ def __import_package(self, model_name): raise TypeError("Unspported model type {}!".format(model_name)) self.module = cpp_model - def init(self, model_name, **kwargs): - config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) - model_type = model_maps.get(config.model_type, config.model_type) - if model_type == "chatglm" and "chatglm2" in config._name_or_path: + def init(self, model_name, not_quant=False, use_cache=False, **quant_kwargs): + self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + model_type = model_maps.get(self.config.model_type, self.config.model_type) + if model_type == "chatglm" and "chatglm2" in self.config._name_or_path: model_type = "chatglm2" self.__import_package(model_type) - # 1. convert model - fp32_bin = "ne_{}_f32.bin".format(model_type) + # check cache and quantization + output_path = "runtime_outs" + if not os.path.exists(output_path): + os.makedirs(output_path) + fp32_bin = "{}/ne_{}_f32.bin".format(output_path, model_type) + quant_bin = "{}/ne_{}_q.bin".format(output_path, model_type) + + if not_quant: + self.bin_file = fp32_bin + else: + self.bin_file = quant_bin + if use_cache and os.path.exists(self.bin_file): + return + convert_model(model_name, fp32_bin, "f32") assert os.path.exists(fp32_bin), "Fail to convert pytorch model" - # 2. 
quant model - quant_bin = "ne_{}_q.bin".format(model_type) - self.module.Model.quant_model(model_path = fp32_bin, out_path = quant_bin, **kwargs) + if not_quant: + print("FP32 model will be used.") + return + self.module.Model.quant_model(model_path = fp32_bin, out_path = quant_bin, **quant_kwargs) assert os.path.exists(quant_bin), "Fail to quantize model" - self.model_type = model_type - self.bin_file = quant_bin - # clean os.remove(fp32_bin) - def init_from_bin(self, model_name, model_path, **kwargs): + def init_from_bin(self, model_name, model_path, **generate_kwargs): self.__import_package(model_name) self.model = self.module.Model() - self.model.init_model(model_path, **kwargs) + if "threads" not in generate_kwargs: + threads = os.getenv("OMP_NUM_THREADS") + if threads is None: + generate_kwargs["threads"] = len(os.sched_getaffinity(0)) + else: + generate_kwargs["threads"] = int(threads) + self.model.init_model(model_path, **generate_kwargs) - def quant_model(self, model_name, model_path, out_path, **kwargs): + def quant_model(self, model_name, model_path, out_path, **quant_kwargs): self.__import_package(model_name) self.module.Model.quant_model(model_path = model_path, - out_path = out_path, **kwargs) + out_path = out_path, **quant_kwargs) - def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, **kwargs): + def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs): + max_new_tokens = generate_kwargs.get("max_new_tokens", -1) if self.model is None: self.init_from_bin(self.model_type, self.bin_file, batch_size=input_ids.shape[0], - **kwargs) + **generate_kwargs) self.generate_round = 0 elif not interactive: self.model.reinit() @@ -109,34 +127,41 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa ret = input_ids.tolist() beam_search = False - if ("num_beams" in kwargs and kwargs["num_beams"] > 1) and not \ - kwargs.get("do_sample", False): + if ("num_beams" in generate_kwargs and generate_kwargs["num_beams"] > 1) and not \ + generate_kwargs.get("do_sample", False): beam_search = True if not beam_search: # TODO support multi batch assert input_ids.shape[0] == 1, "Unsupport multi-batch input ids." + if streamer: - if beam_search: - print("ERROR, can not use streamer when use beam search for generation!") - import sys - sys.exit(1) + assert input_ids.shape[0] == 1, "Streamer only supports batch size 1." + assert beam_search == False, "ERROR, can not use streamer when use beam search for generation! \ + Make sure that `num_beams` is set to 1." 
if self.generate_round == 0 and not ignore_prompt: streamer.put(input_ids) - if interactive: - self.model.reset_token_end() - while not self.is_token_end(): - out = self.model.generate(input_ids = input_ids.tolist()[0]) - if len(out) == 0: - break - streamer.put(torch.tensor([out])) - ret[0].extend(out) - streamer.end() - else: - response = self.model.generate_tokens(input_ids = input_ids.tolist()) - assert (len(ret) == len(response)) + + if interactive: + self.model.reset_token_end() + out_count = 0 + while True: + response = self.model.generate(input_ids = input_ids.tolist()) + if len(response) == 0: + break + if streamer: + streamer.put(torch.tensor([response[0]])) for i in range(len(response)): ret[i].extend(response[i]) - + if stopping_criteria is not None: + if stopping_criteria(torch.tensor(ret), None): + break + elif ret[0][-1] == self.tokenizer.eos_token_id or \ + (max_new_tokens != -1 and out_count > max_new_tokens): + break + out_count += 1 + if streamer: + streamer.end() + self.generate_round += 1 return ret diff --git a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp index b6e4d7378d7..467fce54e9f 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp @@ -60,7 +60,7 @@ class Model { int min_new_tokens, float length_penalty, bool early_stopping, int n_keep, int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token); void reinit(); - std::vector generate(const std::vector& input_ids); + std::vector> generate(const std::vector>& input_ids); std::vector> generate_tokens(const std::vector>& input_ids); bool is_token_end() { return token_eos; } static int quant_model(const std::string& model_path, const std::string& out_path, const std::string& weight_dtype, @@ -149,17 +149,52 @@ void Model::reinit() { generate_count = 0; } -std::vector Model::generate(const std::vector& input_ids) { +std::vector> Model::generate(const std::vector>& input_ids) { + int n_remain = params.n_predict; + std::vector> rets; + if (ctx->beam_search) { + MODEL_ASSERT(input_ids.size() == ctx->batch_size); + if (ctx->batch_size > 1 && ctx->vocab.pad_token_id == -1) { + fprintf(stderr, "\nERROR: please set pad_token for beam search multi-batch generation!\n"); + return rets; + } + std::vector inputs; + for (int bs = 0; bs < input_ids.size(); ++bs) { + uint32_t count = 0; + model_vocab::id pad_token_id = ctx->vocab.pad_token_id; + auto iter = std::find_if(input_ids[bs].begin(), input_ids[bs].end(), + [&pad_token_id](model_token t) { return (t != pad_token_id); }); + if (iter == input_ids[bs].end()) fprintf(stderr, "\nERROR: there are all pad tokens in batch %d!\n", bs); + count = std::distance(input_ids[bs].begin(), iter); + inputs.push_back(model_input{ + /*.tokens =*/input_ids[bs].data(), + /*.n_tokens =*/(uint32_t)input_ids[bs].size(), + /*.n_prompt_tokens =*/0, + /*.n_past =*/0, + /*.n_total =*/0, + /*.request_idx =*/bs, + /*.beam_idx =*/0, + /*.padding_side =*/0, + /*n_padding =*/count, + }); + } + return post_beam_search(ctx, n_remain, inputs, params.n_threads); + } + if (input_ids.size() > 1) { + fprintf(stderr, "\nERROR: Only beam search supports multi-batch generation!\n"); + return rets; + } if (curr_input_ids.empty()) { - if (input_ids.size() > n_ctx - 4) { + if (input_ids[0].size() > n_ctx - 4) { fprintf(stderr, "\n%s: Warning: prompt is too long (%d 
tokens, max %d), will be truncated\n", __func__, - input_ids.size(), n_ctx - 4); + input_ids[0].size(), n_ctx - 4); curr_input_ids.resize(n_ctx - 4); - std::copy(input_ids.end() - n_ctx - 4, input_ids.end(), curr_input_ids.begin()); + std::copy(input_ids[0].end() - n_ctx - 4, input_ids[0].end(), curr_input_ids.begin()); } else { - curr_input_ids = input_ids; + curr_input_ids = input_ids[0]; } } + for (auto item : curr_input_ids) { last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(item); @@ -199,14 +234,7 @@ std::vector Model::generate(const std::vector& input_i curr_input_ids = {next_token_id}; generate_count++; - if (next_token_id == ctx->vocab.eos_token_id) { - token_eos = true; - } - if (params.n_predict > 0 && generate_count >= params.n_predict) { - token_eos = true; - } - - return {next_token_id}; + return {{next_token_id}}; } std::vector> Model::generate_tokens(const std::vector>& input_ids) { @@ -480,7 +508,7 @@ PYBIND11_MODULE(mistral_cpp, m) .def("generate_tokens", &Model::generate_tokens, "Generate tokens with input ids", py::arg("input_ids")) .def_static("quant_model", &Model::quant_model, "Quantize model", py::arg("model_path"), py::arg("out_path"), py::arg("weight_dtype") = "int4", py::arg("alg") = "sym", py::arg("group_size") = 32, - py::arg("scale_dtype") = "fp32", py::arg("compute_dtype") = "ggml", py::arg("use_ggml") = false) + py::arg("scale_dtype") = "fp32", py::arg("compute_dtype") = "int8", py::arg("use_ggml") = false) .def("is_token_end", &Model::is_token_end) .def("reset_token_end", &Model::reset_token_end) .def("reinit", &Model::reinit); diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py index f37979d6817..50f0de0a08f 100644 --- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py +++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py @@ -142,6 +142,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): scale_dtype=quantization_config.scale_dtype, compute_dtype=quantization_config.compute_dtype, use_ggml=quantization_config.use_ggml, + not_quant=quantization_config.not_quant, + use_cache=quantization_config.use_cache, ) return model else: diff --git a/intel_extension_for_transformers/transformers/utils/quantization_config.py b/intel_extension_for_transformers/transformers/utils/quantization_config.py index ae74391f2d8..27a67cf3998 100644 --- a/intel_extension_for_transformers/transformers/utils/quantization_config.py +++ b/intel_extension_for_transformers/transformers/utils/quantization_config.py @@ -40,6 +40,8 @@ def __init__( scheme="sym", algorithm="RTN", use_ggml=False, + not_quant=False, + use_cache=False, **kwargs, ): from intel_extension_for_transformers.llm.quantization.utils import convert_dtype_2_str @@ -59,6 +61,8 @@ def __init__( self.calib_dataloader = kwargs.pop("calib_dataloader", None) self.calib_iters = kwargs.pop("calib_iters", 100) self.use_ggml = use_ggml + self.not_quant = not_quant + self.use_cache = use_cache if compute_dtype is None: self.compute_dtype = "fp32" diff --git a/tests/test_llm_runtime.py b/tests/test_llm_runtime.py index 315c848224c..75f904fff36 100644 --- a/tests/test_llm_runtime.py +++ b/tests/test_llm_runtime.py @@ -31,7 +31,10 @@ def test_llm_runtime(self): streamer = TextStreamer(tokenizer) model = AutoModel.from_pretrained(model_name, quantization_config=woq_config, use_llm_runtime=True, trust_remote_code=True) - gen_tokens = 
model.generate(input_ids, streamer=streamer, max_new_tokens=300) + gen_tokens = model.generate(input_ids, streamer=streamer, max_new_tokens=300, seed=1) + outputs = tokenizer.batch_decode(gen_tokens) + print(outputs) + self.assertTrue("小明" in outputs[0]) def test_beam_search(self): model_name = "/tf_dataset2/models/pytorch/gpt-j-6B" # or local path to model @@ -57,11 +60,8 @@ def test_beam_search(self): pt_generate_ids = pt_model.generate(**inputs, max_new_tokens=128, min_new_tokens=30, early_stopping=True, num_beams=4).tolist() # llm runtime fp32 - convert_model(model_name, "gptj_fp32.bin", "f32") - itrex_model = Model() - itrex_model.init_from_bin("gptj", "gptj_fp32.bin", batch_size=4, num_beams=4, - max_new_tokens=128, min_new_tokens=30, early_stopping=True, - pad_token=pad_token) + woq_config = WeightOnlyQuantConfig(not_quant=True) + itrex_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True) itrex_generate_ids = itrex_model.generate(inputs.input_ids, batch_size=4, num_beams=4, max_new_tokens=128, min_new_tokens=30, early_stopping=True, pad_token=pad_token)
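
The beam-search test above exercises the new multi-batch generation path added in `main_pybind.cpp`. For reference, a sketch of the same flow outside the test harness is shown below; the checkpoint path and prompts are placeholders, and the left-padding / pad-token handling is an assumption based on the padding logic in the C++ code:

```python
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

model_name = "/path/to/gpt-j-6B"  # placeholder: local GPT-J checkpoint
prompts = [
    "What is the meaning of life?",
    "Tell me something about jazz music.",
    "She opened the door and",
    "To be, or not to be, that is the question:",
]

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")  # batched prompts are left-padded
tokenizer.pad_token = tokenizer.eos_token  # GPT-J has no pad token by default
inputs = tokenizer(prompts, padding=True, return_tensors="pt")

woq_config = WeightOnlyQuantConfig(not_quant=True)  # convert to the runtime's FP32 format, skip quantization
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)

# Multi-batch generation currently goes through the beam-search path and needs an explicit pad token id.
outputs = model.generate(inputs.input_ids, batch_size=4, num_beams=4, max_new_tokens=128,
                         min_new_tokens=30, early_stopping=True, pad_token=tokenizer.pad_token_id)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```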