inference_kopa.py
import os
import json
import torch
import transformers
from peft import PeftModel
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
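
# Inference script for KoPA-style triple classification (a brief description of
# what the code below does, inferred from the code itself): a knowledge-graph
# embedding module saved as embeddings.pth next to the LoRA weights produces a
# soft prefix that is prepended to the prompt's token embeddings, and the
# LoRA-adapted LLaMA model answers "True" or "False" for each test triple.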
base_path = 'YOUR LLM PATH'
prompt_template = """
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
Given a triple from a knowledge graph. Each triple consists of a head entity, a relation, and a tail entity. Please determine the correctness of the triple and response True or False.
### Input:
{}
### Response:
"""
def load_test_dataset(path):
    test_dataset = json.load(open(path, "r"))
    return test_dataset


if __name__ == "__main__":
    cuda = "cuda:0"
    lora_weights = "YOUR SAVE PATH"
    test_data_path = "data/UMLS-test.json"
    embedding_path = "{}/embeddings.pth".format(lora_weights)

    # Load the test triples and the KG embedding module saved alongside the LoRA weights.
    test_dataset = load_test_dataset(test_data_path)
    kg_embeddings = torch.load(embedding_path).to(cuda)

    # Load the base LLaMA model in fp16 and wrap it with the trained LoRA adapter.
    tokenizer = LlamaTokenizer.from_pretrained(base_path)
    model = LlamaForCausalLM.from_pretrained(
        base_path,
        torch_dtype=torch.float16
    ).to(cuda)
    model = PeftModel.from_pretrained(
        model,
        lora_weights,
        torch_dtype=torch.float16,
    ).to(cuda)
    # unwind broken decapoda-research config
    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2
    model = model.eval()

    # Score every test triple: the KG embeddings of (head, relation, tail) form a
    # soft prefix that is concatenated in front of the prompt's token embeddings.
    result = []
    for data in test_dataset:
        ent = data["input"]
        ans = data["output"]
        ids = data["embedding_ids"]
        ids = torch.LongTensor(ids).reshape(1, -1).to(cuda)
        prefix = kg_embeddings(ids)
        prompt = prompt_template.format(ent)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs.input_ids.to(cuda)
        token_embeds = model.model.model.embed_tokens(input_ids)
        input_embeds = torch.cat((prefix, token_embeds), dim=1)
        generate_ids = model.generate(
            inputs_embeds=input_embeds,
            max_new_tokens=16
        )
        # Remove the decoded prompt text from the response (if it is echoed) so
        # only the model's answer remains.
        context = tokenizer.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        response = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        response = response.replace(context, "").strip()
        print(response + '\n')
        result.append(
            {
                "answer": ans,
                "predict": response
            }
        )

    # Map "True"/"False" answers and predictions to binary labels and report metrics.
    answer = []
    predict = []
    for data in result:
        if "True" in data["answer"]:
            answer.append(1)
        else:
            answer.append(0)
        if "True" in data["predict"]:
            predict.append(1)
        else:
            predict.append(0)
    acc = accuracy_score(y_true=answer, y_pred=predict)
    p = precision_score(y_true=answer, y_pred=predict)
    r = recall_score(y_true=answer, y_pred=predict)
    f1 = f1_score(y_true=answer, y_pred=predict)
    print(acc, p, r, f1)
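
# Rough usage sketch (assuming base_path and lora_weights above are filled in
# and data/UMLS-test.json exists):
#   python inference_kopa.py
# The final line prints accuracy, precision, recall, and F1 in that order.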