inference_demo.py
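"""Minimal inference demo: load a quantized model checkpoint, optionally apply
dynamic-rank pruning to the expert weights, and generate text from one prompt."""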
import os

import torch
from transformers import AutoTokenizer

from inference import load_quantized_model
from expert_weight import replace_with_dynamic_rank

# Pin the process to a single GPU; generation falls back to CPU below if CUDA
# is unavailable.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Keyword arguments forwarded to the model loader. torch_dtype must be the
# dtype object itself, not the string "torch.float16".
kwargs = {
    "device_map": "auto",
    "torch_dtype": torch.float16,
}
######## Set save_dir to the directory holding the quantized checkpoint ########
save_dir = ""
model = load_quantized_model(save_dir, kwargs)
######### Optional: dynamic pruning (delete these two lines to disable) #########
args = None  # pruning options passed through to replace_with_dynamic_rank
model = replace_with_dynamic_rank(model, args, block_range=10)
tokenizer = AutoTokenizer.from_pretrained(save_dir)

prompt = "You are a writer. Please write a short story about two llamas in a forest."
inputs = tokenizer(prompt, return_tensors="pt")

# Move the input tensors to the same device as the model.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
inputs = inputs.to(device)
# Generate; pass the attention mask explicitly so generate() does not have to
# infer it from pad_token_id.
outputs = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=512,
    pad_token_id=tokenizer.eos_token_id,
    repetition_penalty=1.1,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
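# The decode above echoes the prompt followed by the completion. A minimal
# sketch for printing only the newly generated text, assuming the standard
# decoder-only layout where outputs[0] begins with the prompt ids:
#   new_tokens = outputs[0][inputs.input_ids.shape[1]:]
#   print(tokenizer.decode(new_tokens, skip_special_tokens=True))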