
[LLM Runtime] Enable interactive mode of python api #548

Merged 15 commits on Oct 27, 2023

@@ -95,6 +95,7 @@ def quant_model(self, model_name, model_path, out_path, **kwargs):
         self.module.Model.quant_model(model_path = model_path,
                                       out_path = out_path, **kwargs)
 
+
     def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, **kwargs):
         if self.model is None:
             self.init_from_bin(self.model_type, self.bin_file, **kwargs)
@@ -120,8 +121,12 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, **kwargs):
                 sys.exit(1)
             if self.generate_round == 0 and not ignore_prompt:
                 streamer.put(input_ids)
+            if interactive:
+                self.model.reset_token_end()
             while not self.is_token_end():
                 out = self.model.generate(input_ids = input_ids.tolist()[0])
+                if len(out) == 0:
+                    break
                 streamer.put(torch.tensor([out]))
                 ret[0].extend(out)
             streamer.end()
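
Note: the flags added above are meant to be driven turn by turn. Passing interactive=True clears the end-of-text state through the new reset_token_end() before the streaming loop runs, and ignore_prompt=True keeps the prompt itself out of the streamed output. A minimal two-turn sketch against the high-level API, mirroring the Llama-2 example added later in this PR; the model name, chat template and generation kwargs are placeholders:

from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

model_name = "meta-llama/Llama-2-7b-chat-hf"  # placeholder; any supported chat model
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
streamer = TextStreamer(tokenizer)
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)

# Turn 1: tokens stream until EOS or the max_new_tokens budget is reached.
ids = tokenizer("[INST]Hi, who are you?[/INST]", return_tensors="pt").input_ids
model.generate(ids, streamer=streamer, interactive=True, ignore_prompt=True,
               max_new_tokens=256, ctx_size=2048, do_sample=True)

# Turn 2: interactive=True calls reset_token_end() under the hood, so the
# while-not-is_token_end() loop above runs again for the follow-up prompt.
ids = tokenizer("[INST]Answer in one sentence this time.[/INST]", return_tensors="pt").input_ids
model.generate(ids, streamer=streamer, interactive=True, ignore_prompt=True,
               max_new_tokens=256, ctx_size=2048, do_sample=True)
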
@@ -65,6 +65,11 @@ class Model {
   static int quant_model(const std::string& model_path, const std::string& out_path, const std::string& weight_dtype,
                          const std::string& alg, int group_size, const std::string& scale_dtype,
                          const std::string& compute_dtype, bool use_ggml);
+  void reset_token_end() {
+    token_eos = false;
+    curr_input_ids.clear();
+    generate_count = 0;
+  }
 
  private:
   model_context* ctx = nullptr;
@@ -75,6 +80,7 @@ class Model {
   int n_ctx = 0;
   std::vector<model_token> last_n_tokens;
   bool token_eos = false;
+  long int generate_count = 0;
 
   model_token post_process(float* logits);
   model_token post_greedy_search(float* logits);
@@ -133,10 +139,15 @@ void Model::reinit() {
   curr_input_ids.clear();
   ctx->n_sample = 0;
   ctx->t_sample_us = 0;
+  generate_count = 0;
 }
 
 std::vector<model_token> Model::generate(const std::vector<model_token>& input_ids) {
   if (curr_input_ids.empty()) {
+    if (input_ids.size() > n_ctx - 4) {
+      fprintf(stderr, "%s: error: prompt is too long (%zu tokens, max %d)\n", __func__, input_ids.size(), n_ctx - 4);
+      return {};
+    }
     curr_input_ids = input_ids;
   }
   for (auto item : curr_input_ids) {
@@ -161,7 +172,11 @@ std::vector<model_token> Model::generate(const std::vector<model_token>& input_ids) {
   model_token next_token_id = post_process(logits);
   curr_input_ids = {next_token_id};
 
-  if (next_token_id == ctx->vocab.eos_token_id || n_past - input_ids.size() >= params.n_predict) {
+  generate_count++;
+  if (next_token_id == ctx->vocab.eos_token_id) {
     token_eos = true;
   }
+  if (params.n_predict > 0 && generate_count >= params.n_predict) {
+    token_eos = true;
+  }
 
Expand All @@ -173,6 +188,10 @@ std::vector<model_token> Model::generate_tokens(const std::vector<model_token>&
std::vector<model_token> output_ids;

if (curr_input_ids.empty()) {
if (input_ids.size() > n_ctx - 4) {
fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, input_ids.size(), n_ctx - 4);
return output_ids;
}
curr_input_ids = input_ids;
}

@@ -203,7 +222,12 @@ std::vector<model_token> Model::generate_tokens(const std::vector<model_token>& input_ids) {
     model_token next_token_id = post_process(logits);
     curr_input_ids = {next_token_id};
     output_ids.push_back(next_token_id);
-    if (next_token_id == ctx->vocab.eos_token_id || n_past - input_ids.size() >= params.n_predict) {
+    generate_count++;
+    if (next_token_id == ctx->vocab.eos_token_id) {
       token_eos = true;
       break;
     }
+    if (params.n_predict > 0 && generate_count >= params.n_predict) {
+      token_eos = true;
+      break;
+    }
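
The same stopping rule now appears in both Model::generate and Model::generate_tokens: the end-of-sequence token and the max_new_tokens budget are checked separately, with the budget tracked in generate_count rather than derived from n_past (which keeps growing across interactive turns and is not touched by reset_token_end). A plain-Python restatement of the rule, purely illustrative, with names mirroring the C++ members:

def should_stop(next_token_id: int, eos_token_id: int,
                generate_count: int, n_predict: int) -> bool:
    """Illustrative mirror of the stopping rule added in this diff."""
    if next_token_id == eos_token_id:
        return True                      # model emitted end-of-sequence
    if n_predict > 0 and generate_count >= n_predict:
        return True                      # max_new_tokens (n_predict) budget exhausted
    return False
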
@@ -379,5 +403,6 @@ PYBIND11_MODULE(polyglot_cpp, m)
            py::arg("weight_dtype") = "int4", py::arg("alg") = "sym", py::arg("group_size") = 32,
            py::arg("scale_dtype") = "fp32", py::arg("compute_dtype") = "ggml", py::arg("use_ggml") = false)
       .def("is_token_end", &Model::is_token_end)
+      .def("reset_token_end", &Model::reset_token_end)
       .def("reinit", &Model::reinit);
 }
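
With reset_token_end bound alongside reinit, a caller of the graph API has two ways to prepare for the next call. A rough sketch of the difference, assuming init_from_bin populates the wrapper's model attribute with the pybind object (as the generate() code at the top of this diff implies); the bin file name is taken from the ChatGLM2 example below:

from intel_extension_for_transformers.llm.runtime.graph import Model

model = Model()
model.init_from_bin("chatglm2", "ne_chatglm2_q.bin", max_new_tokens=128)

# Interactive turn boundary: clear only token_eos, the pending curr_input_ids
# and generate_count, so the next generate() continues the same session.
model.model.reset_token_end()

# Full reset: reinit() additionally clears the sampling counters shown above,
# so the next generate() starts as a fresh request.
model.model.reinit()
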
@@ -0,0 +1,35 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

model_name = "meta-llama/Llama-2-7b-chat-hf" # or local path to model
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
streamer = TextStreamer(tokenizer)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)

while True:
    print("> ", end="")
    prompt = input().strip()
    if prompt == "quit":
        break
    b_prompt = "[INST]{}[/INST]".format(prompt)
    inputs = tokenizer(b_prompt, return_tensors="pt").input_ids
    outputs = model.generate(inputs, streamer=streamer, interactive=True, ignore_prompt=True,
                             num_beams=1, max_new_tokens=1024, ctx_size=2048, do_sample=True, threads=28, repetition_penalty=1.1)
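
Since the runtime now rejects prompts longer than ctx_size - 4 (it returns an empty token list, which makes the wrapper's streaming loop stop), a front end like this one may want to truncate long user input first. A small sketch; clamp_prompt is a hypothetical helper, and the 4-token reserve mirrors the n_ctx - 4 check in the C++ changes:

def clamp_prompt(input_ids, ctx_size):
    """Keep only the most recent tokens so the prompt fits within ctx_size - 4."""
    max_prompt_len = ctx_size - 4
    if input_ids.shape[1] > max_prompt_len:
        input_ids = input_ids[:, -max_prompt_len:]
    return input_ids

# Usage inside the loop above (ctx_size must match the value passed to generate):
# inputs = clamp_prompt(tokenizer(b_prompt, return_tensors="pt").input_ids, ctx_size=2048)
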
@@ -0,0 +1,51 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.llm.runtime.graph import Model

model_name = "/mnt/disk1/data2/zhenweil/models/chatglm2-6b"  # local path to a chatglm2-6b checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
streamer = TextStreamer(tokenizer)

model = Model()
model.init_from_bin("chatglm2", "ne_chatglm2_q.bin",
                    num_beams=1, max_new_tokens=512, ctx_size=512, do_sample=True, threads=28,
                    repetition_penalty=1.1)  # other options: n_keep=4, n_discard=1, temperature=0.001, top_k=1, top_p=0.95

count = 1
while True:
    print("> ", end="")
    prompt = input()
    b_prompt = "[Round {}]\n\n问:{}\n\n答:".format(count, prompt)
    inputs = tokenizer(b_prompt, return_tensors="pt").input_ids
    outputs = model.generate(inputs, streamer=streamer, interactive=True, ignore_prompt=True)
    count += 1