forked from michiyasunaga/qagnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpreprocess.py
249 lines (234 loc) · 16.1 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import os
import argparse
from multiprocessing import cpu_count
from utils.convert_csqa import convert_to_entailment
from utils.convert_csqa_extract import convert_to_extraction_entailment
from utils.convert_obqa import convert_to_obqa_statement
from utils.convert_squad import convert_to_squad_statement
from utils.conceptnet import extract_english, construct_graph
from utils.grounding import create_matcher_patterns, ground
from utils.graph import generate_adj_data_from_grounded_concepts__use_LM
# Source files read by the preprocessing routines, keyed by dataset name and
# then by split ('train' / 'dev' / 'test'). The 'cpnet' entry is keyed by file
# kind instead of split. Paths are relative to the working directory.
input_paths = {
    # --- csqa -------------------------------------------------------------
    'csqa': {
        'train': './data/csqa/train_rand_split.jsonl',
        'dev': './data/csqa/dev_rand_split.jsonl',
        'test': './data/csqa/test_rand_split_no_answers.jsonl',
    },
    # --- csqa-extract (same file names, separate data directory) ----------
    'csqa-extract': {
        'train': './data/csqa_extract/train_rand_split.jsonl',
        'dev': './data/csqa_extract/dev_rand_split.jsonl',
        'test': './data/csqa_extract/test_rand_split_no_answers.jsonl',
    },
    # --- obqa: files under the OpenBookQA "Main" directory ----------------
    'obqa': {
        'train': './data/obqa/OpenBookQA-V1-Sep2018/Data/Main/train.jsonl',
        'dev': './data/obqa/OpenBookQA-V1-Sep2018/Data/Main/dev.jsonl',
        'test': './data/obqa/OpenBookQA-V1-Sep2018/Data/Main/test.jsonl',
    },
    # --- obqa-fact: files under the OpenBookQA "Additional" directory -----
    'obqa-fact': {
        'train': './data/obqa/OpenBookQA-V1-Sep2018/Data/Additional/train_complete.jsonl',
        'dev': './data/obqa/OpenBookQA-V1-Sep2018/Data/Additional/dev_complete.jsonl',
        'test': './data/obqa/OpenBookQA-V1-Sep2018/Data/Additional/test_complete.jsonl',
    },
    # --- squad / squad1: only 'train' and 'dev' are listed ----------------
    'squad': {
        'train': './data/squad/train.json',
        'dev': './data/squad/dev.json',
    },
    'squad1': {
        'train': './data/squad1/train.json',
        'dev': './data/squad1/dev.json',
    },
    # --- cpnet: the ConceptNet assertions CSV ------------------------------
    'cpnet': {
        'csv': './data/cpnet/conceptnet-assertions-5.6.0.csv',
    },
}
# Files used as outputs (and as inputs to later steps) by the routines in
# main(). Every dataset entry is keyed by stage and then by split:
#   'statement' -> 'grounded' -> 'graph' (graph keys carry an 'adj-' prefix).
# The 'cpnet' entry is flat and keyed by artifact name.
output_paths = {
    # --- cpnet: shared artifacts used by every dataset routine -------------
    'cpnet': {
        'csv': './data/cpnet/conceptnet.en.csv',
        'vocab': './data/cpnet/concept.txt',
        'patterns': './data/cpnet/matcher_patterns.json',
        'unpruned-graph': './data/cpnet/conceptnet.en.unpruned.graph',
        'pruned-graph': './data/cpnet/conceptnet.en.pruned.graph',
    },
    # --- csqa ----------------------------------------------------------------
    'csqa': {
        'statement': {
            'train': './data/csqa/statement/train.statement.jsonl',
            'dev': './data/csqa/statement/dev.statement.jsonl',
            'test': './data/csqa/statement/test.statement.jsonl',
        },
        'grounded': {
            'train': './data/csqa/grounded/train.grounded.jsonl',
            'dev': './data/csqa/grounded/dev.grounded.jsonl',
            'test': './data/csqa/grounded/test.grounded.jsonl',
        },
        'graph': {
            'adj-train': './data/csqa/graph/train.graph.adj.pk',
            'adj-dev': './data/csqa/graph/dev.graph.adj.pk',
            'adj-test': './data/csqa/graph/test.graph.adj.pk',
        },
    },
    # --- csqa-extract --------------------------------------------------------
    'csqa-extract': {
        'statement': {
            'train': './data/csqa_extract/statement/train.statement.jsonl',
            'dev': './data/csqa_extract/statement/dev.statement.jsonl',
            'test': './data/csqa_extract/statement/test.statement.jsonl',
        },
        'grounded': {
            'train': './data/csqa_extract/grounded/train.grounded.jsonl',
            'dev': './data/csqa_extract/grounded/dev.grounded.jsonl',
            'test': './data/csqa_extract/grounded/test.grounded.jsonl',
        },
        'graph': {
            'adj-train': './data/csqa_extract/graph/train.graph.adj.pk',
            'adj-dev': './data/csqa_extract/graph/dev.graph.adj.pk',
            'adj-test': './data/csqa_extract/graph/test.graph.adj.pk',
        },
    },
    # --- obqa: 'statement' also lists '<split>-fairseq' paths ---------------
    'obqa': {
        'statement': {
            'train': './data/obqa/statement/train.statement.jsonl',
            'dev': './data/obqa/statement/dev.statement.jsonl',
            'test': './data/obqa/statement/test.statement.jsonl',
            'train-fairseq': './data/obqa/fairseq/official/train.jsonl',
            'dev-fairseq': './data/obqa/fairseq/official/valid.jsonl',
            'test-fairseq': './data/obqa/fairseq/official/test.jsonl',
        },
        'grounded': {
            'train': './data/obqa/grounded/train.grounded.jsonl',
            'dev': './data/obqa/grounded/dev.grounded.jsonl',
            'test': './data/obqa/grounded/test.grounded.jsonl',
        },
        'graph': {
            'adj-train': './data/obqa/graph/train.graph.adj.pk',
            'adj-dev': './data/obqa/graph/dev.graph.adj.pk',
            'adj-test': './data/obqa/graph/test.graph.adj.pk',
        },
    },
    # --- obqa-fact: only the 'statement' stage is listed --------------------
    'obqa-fact': {
        'statement': {
            'train': './data/obqa/statement/train-fact.statement.jsonl',
            'dev': './data/obqa/statement/dev-fact.statement.jsonl',
            'test': './data/obqa/statement/test-fact.statement.jsonl',
            'train-fairseq': './data/obqa/fairseq/official/train-fact.jsonl',
            'dev-fairseq': './data/obqa/fairseq/official/valid-fact.jsonl',
            'test-fairseq': './data/obqa/fairseq/official/test-fact.jsonl',
        },
    },
    # --- squad / squad1: only 'train' and 'dev' are listed ------------------
    'squad': {
        'statement': {
            'train': './data/squad/statement/train.statement.jsonl',
            'dev': './data/squad/statement/dev.statement.jsonl',
        },
        'grounded': {
            'train': './data/squad/grounded/train.grounded.jsonl',
            'dev': './data/squad/grounded/dev.grounded.jsonl',
        },
        'graph': {
            'adj-train': './data/squad/graph/train.graph.adj.pk',
            'adj-dev': './data/squad/graph/dev.graph.adj.pk',
        },
    },
    'squad1': {
        'statement': {
            'train': './data/squad1/statement/train.statement.jsonl',
            'dev': './data/squad1/statement/dev.statement.jsonl',
        },
        'grounded': {
            'train': './data/squad1/grounded/train.grounded.jsonl',
            'dev': './data/squad1/grounded/dev.grounded.jsonl',
        },
        'graph': {
            'adj-train': './data/squad1/graph/train.graph.adj.pk',
            'adj-dev': './data/squad1/graph/dev.graph.adj.pk',
        },
    },
}
def _convert_steps(func, dataset, splits, with_fairseq=False):
    """Build one statement-conversion step per split of *dataset*.

    Each step calls ``func(input_file, statement_file)``. When *with_fairseq*
    is true, the ``'<split>-fairseq'`` statement path is passed as a third
    positional argument.
    """
    sources = input_paths[dataset]
    statements = output_paths[dataset]['statement']
    steps = []
    for split in splits:
        step_args = (sources[split], statements[split])
        if with_fairseq:
            step_args += (statements[split + '-fairseq'],)
        steps.append({'func': func, 'args': step_args})
    return steps


def _ground_steps(dataset, splits, nprocs):
    """Build one ``ground`` step per split of *dataset*."""
    cpnet = output_paths['cpnet']
    paths = output_paths[dataset]
    return [
        {'func': ground,
         'args': (paths['statement'][split], cpnet['vocab'], cpnet['patterns'],
                  paths['grounded'][split], nprocs)}
        for split in splits
    ]


def _graph_steps(dataset, splits, nprocs):
    """Build one adjacency-data generation step per split of *dataset*."""
    cpnet = output_paths['cpnet']
    paths = output_paths[dataset]
    return [
        {'func': generate_adj_data_from_grounded_concepts__use_LM,
         'args': (paths['grounded'][split], cpnet['pruned-graph'], cpnet['vocab'],
                  paths['graph']['adj-' + split], nprocs)}
        for split in splits
    ]


def main():
    """Parse the command line and run the selected preprocessing routines.

    ``--run`` takes one or more routine names (default: ``common``). Each
    routine is an ordered list of ``{'func': callable, 'args': tuple}`` steps
    that are executed sequentially, in the order the names were given.

    Every requested name is checked against the routine table *before* any
    step runs: several names accepted by ``--run`` have no routine defined
    here, and previously that surfaced as a bare ``KeyError`` only once the
    loop reached the name -- possibly after earlier routines had already
    completed. Such names are now reported through ``parser.error``.

    Raises:
        NotImplementedError: if ``--debug`` is passed.
        SystemExit: on invalid arguments, or when a requested routine has no
            implementation (raised by ``argparse``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--run', default=['common'], choices=['common', 'csqa', 'csqa-extract', 'squad', 'squad1', 'hswag', 'anli', 'exp', 'scitail', 'phys', 'socialiqa', 'obqa', 'obqa-fact', 'make_word_vocab'], nargs='+')
    # NOTE: --path_prune_threshold, --max_node_num and --seed are parsed but
    # never read anywhere in this function.
    parser.add_argument('--path_prune_threshold', type=float, default=0.12, help='threshold for pruning paths')
    parser.add_argument('--max_node_num', type=int, default=200, help='maximum number of nodes per graph')
    parser.add_argument('-p', '--nprocs', type=int, default=cpu_count(), help='number of processes to use')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--debug', action='store_true', help='enable debug mode')
    args = parser.parse_args()

    if args.debug:
        raise NotImplementedError('debug mode is not implemented')

    all_splits = ('train', 'dev', 'test')
    # Statement conversion for the squad datasets runs 'dev' before 'train'.
    squad_splits = ('dev', 'train')
    cpnet = output_paths['cpnet']

    routines = {
        'common': [
            {'func': extract_english, 'args': (input_paths['cpnet']['csv'], cpnet['csv'], cpnet['vocab'])},
            {'func': construct_graph, 'args': (cpnet['csv'], cpnet['vocab'], cpnet['unpruned-graph'], False)},
            {'func': construct_graph, 'args': (cpnet['csv'], cpnet['vocab'], cpnet['pruned-graph'], True)},
            {'func': create_matcher_patterns, 'args': (cpnet['vocab'], cpnet['patterns'])},
        ],
        'csqa': (
            _convert_steps(convert_to_entailment, 'csqa', all_splits)
            + _ground_steps('csqa', all_splits, args.nprocs)
            + _graph_steps('csqa', all_splits, args.nprocs)
        ),
        'csqa-extract': (
            _convert_steps(convert_to_extraction_entailment, 'csqa-extract', all_splits)
            + _ground_steps('csqa-extract', all_splits, args.nprocs)
            + _graph_steps('csqa-extract', all_splits, args.nprocs)
        ),
        # NOTE: grounding and graph generation are disabled for 'squad' and
        # 'squad1'; only statement conversion runs. To enable them, append
        # _ground_steps(...) + _graph_steps(...) with squad_splits.
        'squad': _convert_steps(convert_to_squad_statement, 'squad', squad_splits),
        'squad1': _convert_steps(convert_to_squad_statement, 'squad1', squad_splits),
        'obqa': (
            _convert_steps(convert_to_obqa_statement, 'obqa', all_splits, with_fairseq=True)
            + _ground_steps('obqa', all_splits, args.nprocs)
            + _graph_steps('obqa', all_splits, args.nprocs)
        ),
    }

    # Fail fast: refuse to start if any requested routine is missing, instead
    # of crashing part-way through a long preprocessing run.
    unimplemented = [name for name in args.run if name not in routines]
    if unimplemented:
        parser.error('no preprocessing routine is implemented for: {}'.format(', '.join(unimplemented)))

    for rt in args.run:
        for rt_dic in routines[rt]:
            rt_dic['func'](*rt_dic['args'])

    print('Successfully run {}'.format(' '.join(args.run)))
# Script entry point: run the preprocessing CLI only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
    main()