Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
[LLM Runtime] Enable interactive mode of python api (#548)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenwei-intel authored Oct 27, 2023
1 parent 8cf36a1 commit 6e32ca6
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 3 deletions.
22 changes: 22 additions & 0 deletions intel_extension_for_transformers/llm/runtime/graph/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,25 @@ Argument description of inference.py:
### 3. Tensor Parallelism cross nodes/sockets

We support tensor parallelism strategy for distributed inference/training on multi-node and multi-socket. You can refer to [tensor_parallelism.md](./tensor_parallelism.md) to enable this feature.

### 4. Chat with LLaMA2
```python
from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

model_name = "meta-llama/Llama-2-7b-chat-hf" # or local path to model
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
streamer = TextStreamer(tokenizer)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)

while True:
print("> ", end="")
prompt = input().strip()
if prompt == "quit":
break
b_prompt = "[INST]{}[/INST]".format(prompt) # prompt template for llama2
inputs = tokenizer(b_prompt, return_tensors="pt").input_ids
outputs = model.generate(inputs, streamer=streamer, interactive=True, ignore_prompt=True,
num_beams=1, max_new_tokens=512, ctx_size = 512, do_sample=True, threads=28, repetition_penalty=1.1)
```
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def init(self, model_name, **kwargs):
# clean
os.remove(fp32_bin)


def init_from_bin(self, model_name, model_path, **kwargs):
self.__import_package(model_name)
self.model = self.module.Model()
Expand All @@ -95,6 +94,7 @@ def quant_model(self, model_name, model_path, out_path, **kwargs):
self.module.Model.quant_model(model_path = model_path,
out_path = out_path, **kwargs)


def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, **kwargs):
if self.model is None:
self.init_from_bin(self.model_type, self.bin_file, **kwargs)
Expand All @@ -120,8 +120,12 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
sys.exit(1)
if self.generate_round == 0 and not ignore_prompt:
streamer.put(input_ids)
if interactive:
self.model.reset_token_end()
while not self.is_token_end():
out = self.model.generate(input_ids = input_ids.tolist()[0])
if len(out) == 0:
break
streamer.put(torch.tensor([out]))
ret[0].extend(out)
streamer.end()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ class Model {
static int quant_model(const std::string& model_path, const std::string& out_path, const std::string& weight_dtype,
const std::string& alg, int group_size, const std::string& scale_dtype,
const std::string& compute_dtype, bool use_ggml);
void reset_token_end() {
token_eos = false;
curr_input_ids.clear();
generate_count = 0;
}

private:
model_context* ctx = nullptr;
Expand All @@ -75,6 +80,7 @@ class Model {
int n_ctx = 0;
std::vector<model_token> last_n_tokens;
bool token_eos = false;
long int generate_count = 0;

model_token post_process(float* logits);
model_token post_greedy_search(float* logits);
Expand Down Expand Up @@ -133,10 +139,15 @@ void Model::reinit() {
curr_input_ids.clear();
ctx->n_sample = 0;
ctx->t_sample_us = 0;
generate_count = 0;
}

std::vector<model_token> Model::generate(const std::vector<model_token>& input_ids) {
if (curr_input_ids.empty()) {
if (input_ids.size() > n_ctx - 4) {
fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, input_ids.size(), n_ctx - 4);
return {};
}
curr_input_ids = input_ids;
}
for (auto item : curr_input_ids) {
Expand All @@ -161,7 +172,11 @@ std::vector<model_token> Model::generate(const std::vector<model_token>& input_i
model_token next_token_id = post_process(logits);
curr_input_ids = {next_token_id};

if (next_token_id == ctx->vocab.eos_token_id || n_past - input_ids.size() >= params.n_predict) {
generate_count++;
if (next_token_id == ctx->vocab.eos_token_id) {
token_eos = true;
}
if (params.n_predict > 0 && generate_count >= params.n_predict) {
token_eos = true;
}

Expand All @@ -173,6 +188,10 @@ std::vector<model_token> Model::generate_tokens(const std::vector<model_token>&
std::vector<model_token> output_ids;

if (curr_input_ids.empty()) {
if (input_ids.size() > n_ctx - 4) {
fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, input_ids.size(), n_ctx - 4);
return output_ids;
}
curr_input_ids = input_ids;
}

Expand Down Expand Up @@ -203,7 +222,12 @@ std::vector<model_token> Model::generate_tokens(const std::vector<model_token>&
model_token next_token_id = post_process(logits);
curr_input_ids = {next_token_id};
output_ids.push_back(next_token_id);
if (next_token_id == ctx->vocab.eos_token_id || n_past - input_ids.size() >= params.n_predict) {
generate_count++;
if (next_token_id == ctx->vocab.eos_token_id) {
token_eos = true;
break;
}
if (params.n_predict > 0 && generate_count >= params.n_predict) {
token_eos = true;
break;
}
Expand Down Expand Up @@ -383,5 +407,6 @@ PYBIND11_MODULE(mistral_cpp, m)
py::arg("weight_dtype") = "int4", py::arg("alg") = "sym", py::arg("group_size") = 32,
py::arg("scale_dtype") = "fp32", py::arg("compute_dtype") = "ggml", py::arg("use_ggml") = false)
.def("is_token_end", &Model::is_token_end)
.def("reset_token_end", &Model::reset_token_end)
.def("reinit", &Model::reinit);
}

0 comments on commit 6e32ca6

Please sign in to comment.