forked from ymcui/Chinese-LLaMA-Alpaca
-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference_hf.py
151 lines (136 loc) · 5.82 KB
/
inference_hf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel
import argparse
import json, os
# Command-line interface. Parsed at import time; the resulting ``args`` is
# read (and, for tokenizer_path, filled in) by the __main__ section.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--base_model', default=None, type=str, required=True,
    help="Base LLaMA model passed to LlamaForCausalLM.from_pretrained (required).")
parser.add_argument(
    '--lora_model', default=None, type=str,
    help="If None, perform inference on the base model")
parser.add_argument(
    '--tokenizer_path', default=None, type=str,
    help="Tokenizer to load; defaults to --lora_model, or to --base_model "
         "when no LoRA model is given.")
parser.add_argument(
    '--data_file', default=None, type=str,
    help="file that contains instructions (one instruction per line).")
parser.add_argument(
    '--with_prompt', action='store_true',
    help="Wrap each input in the '### Instruction / ### Response' prompt "
         "template and keep only the text after '### Response:'.")
parser.add_argument(
    '--interactive', action='store_true',
    help="Read instructions from stdin one at a time; an empty line exits.")
parser.add_argument(
    '--predictions_file', default='./predictions.json', type=str,
    help="JSON file for batch-mode predictions; generation_config.json is "
         "written to the same directory.")
args = parser.parse_args()
# Decoding settings. They are splatted into every model.generate() call and,
# in batch mode, also dumped to generation_config.json beside the predictions.
generation_config = {
    "temperature": 0.2,
    "top_k": 40,
    "top_p": 0.9,
    "do_sample": True,
    "num_beams": 1,
    "repetition_penalty": 1.3,
    "max_new_tokens": 400,
}
# Prompt template borrowed from llama.cpp. It differs slightly from the
# template used during training, but the authors found it gives better
# results. The ``{instruction}`` field is filled in by generate_prompt().
prompt_input = "".join([
    "Below is an instruction that describes a task. ",
    "Write a response that appropriately completes the request.",
    "\n\n### Instruction:\n\n{instruction}",
    "\n\n### Response:\n\n",
])

# Fallback instruction used when --data_file is not given
# (Chinese for: "Why should we reduce pollution and protect the environment?").
sample_data = ["为什么要减少污染,保护环境?"]
def generate_prompt(instruction, input=None):
    """Render the module-level ``prompt_input`` template for *instruction*.

    When *input* is truthy it is appended to the instruction, separated by a
    newline, before the template is filled in. Returns the prompt string.
    """
    if not input:
        text = instruction
    else:
        text = instruction + '\n' + input
    return prompt_input.format_map(dict(instruction=text))
if __name__ == '__main__':
load_type = torch.float16
if torch.cuda.is_available():
device = torch.device(0)
else:
device = torch.device('cpu')
if args.tokenizer_path is None:
args.tokenizer_path = args.lora_model
if args.lora_model is None:
args.tokenizer_path = args.base_model
tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path)
base_model = LlamaForCausalLM.from_pretrained(
args.base_model,
load_in_8bit=False,
torch_dtype=load_type,
low_cpu_mem_usage=True,
)
model_vocab_size = base_model.get_input_embeddings().weight.size(0)
tokenzier_vocab_size = len(tokenizer)
print(f"Vocab of the base model: {model_vocab_size}")
print(f"Vocab of the tokenizer: {tokenzier_vocab_size}")
if model_vocab_size!=tokenzier_vocab_size:
assert tokenzier_vocab_size > model_vocab_size
print("Resize model embeddings to fit tokenizer")
base_model.resize_token_embeddings(tokenzier_vocab_size)
if args.lora_model is not None:
print("loading peft model")
model = PeftModel.from_pretrained(base_model, args.lora_model,torch_dtype=load_type)
else:
model = base_model
if device==torch.device('cpu'):
model.float()
# test data
if args.data_file is None:
examples = sample_data
else:
with open(args.data_file,'r') as f:
examples = [l.strip() for l in f.readlines()]
print("first 10 examples:")
for example in examples[:10]:
print(example)
model.to(device)
model.eval()
with torch.no_grad():
if args.interactive:
while True:
raw_input_text = input("Input:")
if len(raw_input_text.strip())==0:
break
if args.with_prompt:
input_text = generate_prompt(instruction=raw_input_text)
else:
input_text = raw_input_text
inputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ?
generation_output = model.generate(
input_ids = inputs["input_ids"].to(device),
attention_mask = inputs['attention_mask'].to(device),
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
**generation_config
)
s = generation_output[0]
output = tokenizer.decode(s,skip_special_tokens=True)
if args.with_prompt:
response = output.split("### Response:")[1].strip()
else:
response = output
print("Response: ",response)
print("\n")
else:
results = []
for index, example in enumerate(examples):
if args.with_prompt is True:
input_text = generate_prompt(instruction=example)
else:
input_text = example
inputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ?
generation_output = model.generate(
input_ids = inputs["input_ids"].to(device),
attention_mask = inputs['attention_mask'].to(device),
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
**generation_config
)
s = generation_output[0]
output = tokenizer.decode(s,skip_special_tokens=True)
if args.with_prompt:
response = output.split("### Response:")[1].strip()
else:
response = output
print(f"======={index}=======")
print(f"Input: {example}\n")
print(f"Output: {response}\n")
results.append({"Input":input_text,"Output":response})
dirname = os.path.dirname(args.predictions_file)
os.makedirs(dirname,exist_ok=True)
with open(args.predictions_file,'w') as f:
json.dump(results,f,ensure_ascii=False,indent=2)
with open(dirname+'/generation_config.json','w') as f:
json.dump(generation_config,f,ensure_ascii=False,indent=2)