-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathinstructions_YAGO.py
111 lines (81 loc) · 4.07 KB
/
instructions_YAGO.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import json
import numpy as np
def is_json(myjson):
    """Return True if *myjson* parses as a JSON document, otherwise False.

    Args:
        myjson: Candidate JSON text (``str``, ``bytes`` or ``bytearray``).

    Returns:
        ``True`` when ``json.loads`` accepts the value. ``False`` when the
        text is malformed, or when the value is of a type ``json.loads``
        cannot decode at all (for example ``None`` or an ``int``).
    """
    try:
        json.loads(myjson)
    except (ValueError, TypeError):
        # ValueError covers json.JSONDecodeError (malformed text);
        # TypeError is what json.loads raises for non-text input, which is
        # "not JSON" rather than a reason to crash the caller.
        return False
    return True
def _load_id2text(path):
    """Read a tab-separated file into an ``{id: text}`` dictionary.

    Every non-blank line is stripped and split on tabs; the first field is
    used as the key and the second field as the value. If a key occurs more
    than once, the last occurrence wins.

    Args:
        path: Path of the UTF-8 encoded file to read.

    Returns:
        A dict mapping the first column to the second column.

    Raises:
        IndexError: If a non-blank line has no second tab-separated field.
    """
    mapping = {}
    with open(path, "r", encoding="utf-8") as f:
        # Iterate the file object directly so the whole file is never held
        # in memory as a list of lines.
        for line in f:
            fields = line.strip().split("\t")
            if fields == [""]:
                # Blank line (e.g. a stray trailing newline): nothing to
                # record, and indexing fields[1] would raise IndexError.
                continue
            mapping[fields[0]] = fields[1]
    return mapping


# Identifier -> text lookups for entities and relations.
ent2txt = _load_id2text("data/YAGO3-10/entity2text.txt")
rel2txt = _load_id2text("data/YAGO3-10/relation2text.txt")

# Entity identifiers in first-seen order.
ent_list = list(ent2txt)
def _append_instruction(prompt, answer, glm_lines, llama_lines):
    """Append one prompt/answer pair to the GLM and LLaMA record lists.

    The GLM record is a one-line JSON object assembled by plain string
    concatenation (no escaping). If that string does not parse as JSON
    (checked with ``is_json``), nothing is appended to either list.

    Args:
        prompt: Instruction text.
        answer: Expected response text.
        glm_lines: List receiving the newline-terminated GLM record.
        llama_lines: List receiving the multi-line LLaMA record.
    """
    glm_record = "{\"prompt\": \"" + prompt + "\", \"response\": \"" + answer + "\"}"
    if not is_json(glm_record):
        return
    glm_lines.append(glm_record + "\n")
    # NOTE(review): assumes the LLaMA record was gated by the same validity
    # check as the GLM record - confirm against the original indentation.
    llama_lines.append(
        "{\n\"instruction\": \"" + prompt + "\",\n \"input\": \"\",\n \"output\": \"" + answer + "\"\n}"
    )


# Output buffers, one GLM/LLaMA pair per task: tail, head and relation.
tail_lines_to_write_glm = []
tail_lines_to_write_llama_lora = []
head_lines_to_write_glm = []
head_lines_to_write_llama_lora = []
rel_lines_to_write_glm = []
rel_lines_to_write_llama_lora = []

# The candidate-relation list is identical for every triple, so build it
# once instead of re-joining it on every loop iteration.
options = "|".join(rel2txt.values())

with open("data/YAGO3-10/train.tsv", "r", encoding="utf-8") as f:
    for line in f:
        tmp = line.strip().split("\t")
        if tmp == [""]:
            # Skip blank lines; they carry no triple.
            continue

        # Each line holds: head id, relation id, tail id.
        head_txt = ent2txt[tmp[0]]
        rel_txt = rel2txt[tmp[1]]
        tail_txt = ent2txt[tmp[2]]

        # Tail task: "<head> <relation>" -> tail.
        prompt = head_txt + " " + rel_txt
        _append_instruction(
            prompt, tail_txt,
            tail_lines_to_write_glm, tail_lines_to_write_llama_lora,
        )

        # Head task: question built from relation and tail -> head.
        prompt = "What/Who/When/Where/Why" + " " + rel_txt + " " + tail_txt + "?"
        _append_instruction(
            prompt, head_txt,
            head_lines_to_write_glm, head_lines_to_write_llama_lora,
        )

        # Relation task: multiple choice over every known relation text.
        prompt = "What is the relationship between" + " " + head_txt + " and " + tail_txt + "?"
        easy_prompt = prompt + " Please choose your answer from: " + options + "."
        _append_instruction(
            easy_prompt, rel_txt,
            rel_lines_to_write_glm, rel_lines_to_write_llama_lora,
        )
def _dump_lines(path, rows):
    """Write already newline-terminated rows to *path* exactly as given."""
    with open(path, "w", encoding="utf-8") as out:
        out.write("".join(rows))


def _dump_array(path, records):
    """Write *records* to *path* as a single bracketed, comma-separated array."""
    body = ",\n".join(records)
    with open(path, "w", encoding="utf-8") as out:
        out.write("[\n" + body + "]")


# GLM format: one record per line, for the tail and head tasks.
_dump_lines("data/YAGO3-10/train_instructions_glm_tail.json", tail_lines_to_write_glm)
_dump_lines("data/YAGO3-10/train_instructions_glm_head.json", head_lines_to_write_glm)

# Tail + head records combined, then shuffled in place with a fixed seed.
lines_to_write_glm_merge = tail_lines_to_write_glm + head_lines_to_write_glm
np.random.seed(42)
np.random.shuffle(lines_to_write_glm_merge)
_dump_lines("data/YAGO3-10/train_instructions_glm_merge.json", lines_to_write_glm_merge)

# LLaMA format: all records wrapped in one array per file.
_dump_array("data/YAGO3-10/train_instructions_llama_tail.json", tail_lines_to_write_llama_lora)
_dump_array("data/YAGO3-10/train_instructions_llama_head.json", head_lines_to_write_llama_lora)

# Re-seed with the same value before shuffling the LLaMA merge, so this
# shuffle draws the same random sequence as the GLM one.
lines_to_write_llama_lora = tail_lines_to_write_llama_lora + head_lines_to_write_llama_lora
np.random.seed(42)
np.random.shuffle(lines_to_write_llama_lora)
_dump_array("data/YAGO3-10/train_instructions_llama_merge.json", lines_to_write_llama_lora)

# Relation task, written unshuffled in both formats.
_dump_lines("data/YAGO3-10/train_instructions_glm_rel.json", rel_lines_to_write_glm)
_dump_array("data/YAGO3-10/train_instructions_llama_rel.json", rel_lines_to_write_llama_lora)