-
Notifications
You must be signed in to change notification settings - Fork 0
/
pruning_module.py
157 lines (133 loc) · 5.91 KB
/
pruning_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import sys
import json
import yaml
import logging
import random
from tqdm import tqdm
from pathlib import Path
from pruning_model import PruningModel
class PruningModule:
    """Wraps a seq2seq mention-prediction model (PruningModel) and uses the
    predicted mentions to prune entity/relation linkings produced by CLOCQ.

    The underlying model is created eagerly but its weights are loaded lazily
    on first use (see _load).
    """

    def __init__(self, config_path="config.yml"):
        """Initialize the module from the given .yml config file.

        The model object is constructed here, but weights are only loaded
        on the first call that needs them.
        """
        self.config = self.get_config(config_path)
        self.model = PruningModel(self.config)
        self.model_loaded = False  # lazy-load flag checked in _load()

    def get_config(self, path):
        """Load the config dict from the given .yml file."""
        with open(path, "r") as fp:
            config = yaml.safe_load(fp)
        return config

    def train(self, train_path, dev_path):
        """Train the underlying seq2seq model on the given train/dev data."""
        self.model.train(train_path, dev_path)

    def inference(self, input_path, output_path):
        """Run inference on a JSON list of instances and write augmented data.

        For each instance, stores the original CLOCQ linkings, the predicted
        mentions, and the pruned entity ids under new keys, then dumps the
        whole list (indent=4) to output_path.
        """
        # load model (if not done already)
        self._load()
        with open(input_path, "r") as fp:
            data = json.load(fp)
        for instance in tqdm(data):
            clocq_linkings = instance["kb_item_tuple"]
            instance["clocq_linkings"] = clocq_linkings
            entity_linkings, predicted_mentions = self.get_entity_linkings(
                instance["question"], clocq_linkings, self.config["clocq_k"]
            )
            instance["predicted_mentions"] = predicted_mentions
            instance["entities"] = [linking["item"]["id"] for linking in entity_linkings]
        # store results
        with open(output_path, "w") as fp:
            fp.write(json.dumps(data, indent=4))

    def inference_on_question(self, question):
        """Predict the relevant mentions for a single question string."""
        # load model (if not done already)
        self._load()
        predicted_mentions = self.model.inference(question)
        return predicted_mentions

    def get_entity_linkings(self, question, linkings, k="AUTO"):
        """Prune the entity linkings provided by the original CLOCQ method
        using the predicted mentions.

        Returns a tuple (pruned entity linkings, predicted mentions as list).
        A linking is kept when its question word matches a predicted mention
        exactly, or via a substring match in either direction.
        """
        # load model (if not done already)
        self._load()
        # predict the relevant mentions (via seq2seq model); lower-cased for matching
        predicted_mentions = self.model.inference(question)
        predicted_mentions = {mention.lower() for mention in predicted_mentions}
        # keep only entity linkings (Wikidata Q-items); guard against empty ids,
        # consistent with get_relation_linkings
        entity_linkings = [
            linking
            for linking in linkings
            if linking["item"]["id"] and linking["item"]["id"][0] == "Q"
        ]
        # prune irrelevant entity linkings
        output_linkings = list()
        for linking in entity_linkings:
            # apply k (don't consider rank>=k results for integer k; "AUTO" keeps all)
            if isinstance(k, int) and linking["rank"] >= k:
                continue
            mention = linking["question_word"].lower()
            # exact match, or relaxed substring match in either direction.
            # FIX: each linking is appended at most once — the original code
            # could append duplicates when several predicted mentions matched
            # the same linking.
            if mention in predicted_mentions or any(
                mention in predicted or predicted in mention
                for predicted in predicted_mentions
            ):
                output_linkings.append({
                    "item": linking["item"],
                    "mention": linking["question_word"]
                })
        return output_linkings, list(predicted_mentions)

    def get_relation_linkings(self, question, linkings, top_ranked=True):
        """Prune the relation linkings provided by the original CLOCQ method.

        Keeps only relation (P-item) linkings; when top_ranked is True, only
        the first linking per mention survives. Returns a tuple
        (pruned relation linkings, mentions as list).
        """
        # keep only relation linkings (Wikidata P-items)
        relation_linkings = [
            (kb_item["item"], kb_item["question_word"])
            for kb_item in linkings
            if kb_item["item"]["id"] and kb_item["item"]["id"][0] == "P"
        ]
        # prune lower ranked relations
        mentions = set()
        output_linkings = list()  # output
        linkings_dict = dict()  # dictionary to keep track of linkings per mention
        for relation, mention in relation_linkings:
            # skip 2nd ranked linking for same mention
            if top_ranked and mention in linkings_dict:
                continue
            mentions.add(mention)
            linkings_dict[mention] = True
            output_linkings.append({
                "item": relation,
                "mention": mention
            })
        return output_linkings, list(mentions)

    def _load(self):
        """Load the model weights and switch to eval mode (idempotent)."""
        if not self.model_loaded:
            self.model.load()
            self.model.set_eval_mode()
            self.model_loaded = True
#######################################################################################################################
#######################################################################################################################
if __name__ == "__main__":
    # Both modes read sys.argv[2] and sys.argv[3], so at least 4 argv entries
    # are required. FIX: the original check was `< 3`, which let e.g.
    # `python pruning_module.py --train x` slip past and crash with an
    # IndexError on sys.argv[3] instead of showing the usage message.
    if len(sys.argv) < 4:
        raise Exception(
            "Usage: python pruning_module.py --train <PATH_TO_TRAIN> <PATH_TO_DEV> [<PATH_TO_CONFIG>]\nOR python pruning_module.py --inference <PATH_TO_INPUT> <PATH_TO_OUTPUT> [<PATH_TO_CONFIG>]"
        )
    # load params
    function = sys.argv[1]
    # optional 4th positional argument overrides the default config path
    config_path = sys.argv[4] if len(sys.argv) > 4 else "config.yml"
    pruning_module = PruningModule(config_path)
    # train: train model
    if function == "--train":
        train_path = sys.argv[2]
        dev_path = sys.argv[3]
        pruning_module.train(train_path, dev_path)
    # inference: add predictions to data
    elif function == "--inference":
        input_path = sys.argv[2]
        output_path = sys.argv[3]
        pruning_module.inference(input_path, output_path)