diff --git a/intel_extension_for_transformers/llm/runtime/graph/README.md b/intel_extension_for_transformers/llm/runtime/graph/README.md
index 53a14495c9e..2ce902a83ca 100644
--- a/intel_extension_for_transformers/llm/runtime/graph/README.md
+++ b/intel_extension_for_transformers/llm/runtime/graph/README.md
@@ -235,3 +235,25 @@ Argument description of inference.py:
 ### 3. Tensor Parallelism cross nodes/sockets
 
 We support tensor parallelism strategy for distributed inference/training on multi-node and multi-socket. You can refer to [tensor_parallelism.md](./tensor_parallelism.md) to enable this feature.
+
+### 4. Chat with LLaMA2
+```python
+from transformers import AutoTokenizer, TextStreamer
+from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
+
+model_name = "meta-llama/Llama-2-7b-chat-hf" # or local path to model
+woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+streamer = TextStreamer(tokenizer)
+model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)
+
+while True:
+    print("> ", end="")
+    prompt = input().strip()
+    if prompt == "quit":
+        break
+    b_prompt = "[INST]{}[/INST]".format(prompt) # prompt template for llama2
+    inputs = tokenizer(b_prompt, return_tensors="pt").input_ids
+    outputs = model.generate(inputs, streamer=streamer, interactive=True, ignore_prompt=True,
+                             num_beams=1, max_new_tokens=512, ctx_size = 512, do_sample=True, threads=28, repetition_penalty=1.1)
+```
\ No newline at end of file
diff --git a/intel_extension_for_transformers/llm/runtime/graph/__init__.py b/intel_extension_for_transformers/llm/runtime/graph/__init__.py
index 2ab6a34e279..aaeab8d16a7 100644
--- a/intel_extension_for_transformers/llm/runtime/graph/__init__.py
+++ b/intel_extension_for_transformers/llm/runtime/graph/__init__.py
@@ -84,7 +84,6 @@ def init(self, model_name, **kwargs):
         # clean
         os.remove(fp32_bin)
 
-
     def init_from_bin(self, model_name, model_path, **kwargs):
         self.__import_package(model_name)
         self.model = self.module.Model()
@@ -95,6 +94,7 @@ def quant_model(self, model_name, model_path, out_path, **kwargs):
         self.module.Model.quant_model(model_path = model_path,
                                       out_path = out_path, **kwargs)
 
+
     def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, **kwargs):
         if self.model is None:
             self.init_from_bin(self.model_type, self.bin_file, **kwargs)
@@ -120,8 +120,12 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
                 sys.exit(1)
             if self.generate_round == 0 and not ignore_prompt:
                 streamer.put(input_ids)
+            if interactive:
+                self.model.reset_token_end()
             while not self.is_token_end():
                 out = self.model.generate(input_ids = input_ids.tolist()[0])
+                if len(out) == 0:
+                    break
                 streamer.put(torch.tensor([out]))
                 ret[0].extend(out)
             streamer.end()
diff --git a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp
index a497623024d..ed013f60d5c 100644
--- a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp
+++ b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp
@@ -65,6 +65,11 @@ class Model {
   static int quant_model(const std::string& model_path, const std::string& out_path, const std::string& weight_dtype,
                          const std::string& alg, int group_size, const std::string& scale_dtype,
                          const std::string& compute_dtype, bool use_ggml);
+  void reset_token_end() {
+    token_eos = false;
+    curr_input_ids.clear();
+    generate_count = 0;
+  }
 
 private:
   model_context* ctx = nullptr;
@@ -75,6 +80,7 @@ class Model {
   int n_ctx = 0;
   std::vector<model_token> last_n_tokens;
   bool token_eos = false;
+  long int generate_count = 0;
 
   model_token post_process(float* logits);
   model_token post_greedy_search(float* logits);
@@ -133,10 +139,15 @@ void Model::reinit() {
   curr_input_ids.clear();
   ctx->n_sample = 0;
   ctx->t_sample_us = 0;
+  generate_count = 0;
 }
 
 std::vector<model_token> Model::generate(const std::vector<model_token>& input_ids) {
   if (curr_input_ids.empty()) {
+    if (input_ids.size() > n_ctx - 4) {
+      fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, input_ids.size(), n_ctx - 4);
+      return {};
+    }
     curr_input_ids = input_ids;
   }
   for (auto item : curr_input_ids) {
@@ -161,7 +172,11 @@ std::vector<model_token> Model::generate(const std::vector<model_token>& input_i
   model_token next_token_id = post_process(logits);
   curr_input_ids = {next_token_id};
-  if (next_token_id == ctx->vocab.eos_token_id || n_past - input_ids.size() >= params.n_predict) {
+  generate_count++;
+  if (next_token_id == ctx->vocab.eos_token_id) {
+    token_eos = true;
+  }
+  if (params.n_predict > 0 && generate_count >= params.n_predict) {
     token_eos = true;
   }
 
   return {next_token_id};
@@ -173,6 +188,10 @@ std::vector<model_token> Model::generate_tokens(const std::vector<model_token>&
   std::vector<model_token> output_ids;
 
   if (curr_input_ids.empty()) {
+    if (input_ids.size() > n_ctx - 4) {
+      fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, input_ids.size(), n_ctx - 4);
+      return output_ids;
+    }
     curr_input_ids = input_ids;
   }
 
@@ -203,7 +222,12 @@ std::vector<model_token> Model::generate_tokens(const std::vector<model_token>&
     model_token next_token_id = post_process(logits);
     curr_input_ids = {next_token_id};
     output_ids.push_back(next_token_id);
-    if (next_token_id == ctx->vocab.eos_token_id || n_past - input_ids.size() >= params.n_predict) {
+    generate_count++;
+    if (next_token_id == ctx->vocab.eos_token_id) {
+      token_eos = true;
+      break;
+    }
+    if (params.n_predict > 0 && generate_count >= params.n_predict) {
       token_eos = true;
       break;
     }
@@ -383,5 +407,6 @@ PYBIND11_MODULE(mistral_cpp, m)
            py::arg("weight_dtype") = "int4", py::arg("alg") = "sym", py::arg("group_size") = 32,
            py::arg("scale_dtype") = "fp32", py::arg("compute_dtype") = "ggml", py::arg("use_ggml") = false)
       .def("is_token_end", &Model::is_token_end)
+      .def("reset_token_end", &Model::reset_token_end)
       .def("reinit", &Model::reinit);
 }
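For context, the sketch below shows how the new `reset_token_end()` path and the empty-output guard are expected to be exercised from Python across multiple chat turns. It is a minimal illustration rather than part of the patch: the two example prompts and the generation parameter values are placeholders, and it assumes the graph runtime has been built and the model quantized as described in the README snippet above.

```python
# Minimal multi-turn sketch (illustrative only, not part of this patch).
# The model name comes from the README example; prompts and parameter
# values below are placeholders.
from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

model_name = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
streamer = TextStreamer(tokenizer)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4"),
    trust_remote_code=True,
)

for prompt in ["Hello!", "Now tell me a joke."]:
    inputs = tokenizer("[INST]{}[/INST]".format(prompt), return_tensors="pt").input_ids
    # With interactive=True, the Python wrapper calls reset_token_end() on the
    # C++ model before streaming, so the EOS reached in the previous turn does
    # not immediately end this turn. If the prompt exceeds n_ctx - 4 tokens,
    # the C++ side now logs an error and returns an empty list, and the
    # wrapper's streaming loop stops instead of continuing.
    model.generate(inputs, streamer=streamer, interactive=True, ignore_prompt=True,
                   max_new_tokens=128, ctx_size=512)
```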