-
Notifications
You must be signed in to change notification settings - Fork 9
/
reweight_ensemble.py
62 lines (56 loc) · 2.17 KB
/
reweight_ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Pick one output sentence per input row from several generators' outputs.
#
# Usage: python reweight_ensemble.py INPUT_TSV OUTPUT_TXT
#
# INPUT_TSV is tab-separated; its first row is a header and is skipped. Per the
# header this script was written against, the columns are:
#   0 Events, 1 Templates, 2 Template Score, 3 RetEdit, 4 RetEdit Score,
#   5 Monte Carlo, 6 Mcts score, 7 FSM, 8 Vanilla
#
# For each data row one sentence is chosen by a fixed cascade (choose_sentence)
# and written to OUTPUT_TXT, one per line. The same sentences, each prefixed with
# the name of the system that produced it, go to a sibling file with "_verbose"
# inserted before the extension (out.txt -> out_verbose.txt).
import os
import sys

# Cascade thresholds (the values the original script hard-coded).
RETEDIT_MAX_DIST = 0.1
TEMPLATE_MAX_SCORE = 0.3
MCTS_MIN_SCORE = 0.2
# Value in the FSM column meaning "no FSM sentence available".
FSM_EMPTY = '<pad>'
# Number of tab-separated fields each data row must have (indices 0..8).
NUM_COLUMNS = 9


def verbose_output_path(out_path):
    """Return *out_path* with ``_verbose`` inserted before its extension.

    ``out.txt`` -> ``out_verbose.txt``; ``./res/run.v2.txt`` ->
    ``./res/run.v2_verbose.txt``; a path without an extension simply gets
    ``_verbose`` appended.
    """
    root, ext = os.path.splitext(out_path)
    return root + '_verbose' + ext


def choose_sentence(parts, retedit_max_dist=RETEDIT_MAX_DIST,
                    template_max_score=TEMPLATE_MAX_SCORE,
                    mcts_min_score=MCTS_MIN_SCORE):
    """Select the sentence for one input row.

    The first matching rule wins:
      1. RetEdit score (col 4) < retedit_max_dist   -> RetEdit sentence (col 3)
      2. Template score (col 2) < template_max_score -> Template sentence (col 1)
      3. MCTS score (col 6) > mcts_min_score        -> MCTS sentence (col 5)
      4. FSM sentence (col 7) is not FSM_EMPTY      -> FSM sentence
      5. otherwise                                  -> Vanilla sentence (col 8)

    Args:
        parts: the row's tab-separated fields.
        retedit_max_dist, template_max_score, mcts_min_score: cascade
            thresholds; the defaults are the script's original values.

    Returns:
        ``(source, sentence)`` where *source* is one of ``'RETEDIT'``,
        ``'TEMPLATES'``, ``'MCTS'``, ``'FSM'`` or ``'VANILLA'``.

    Raises:
        ValueError: if the row has fewer than NUM_COLUMNS fields or a score
            column cannot be parsed as a float.
    """
    if len(parts) < NUM_COLUMNS:
        raise ValueError('expected at least %d tab-separated fields, got %d'
                         % (NUM_COLUMNS, len(parts)))
    # Parse every score up front (as the original did) so a malformed row is
    # rejected no matter which rule would have matched.
    template_score = float(parts[2])
    retedit_dist = float(parts[4])
    mcts_score = float(parts[6])

    if retedit_dist < retedit_max_dist:
        return 'RETEDIT', parts[3]
    if template_score < template_max_score:
        return 'TEMPLATES', parts[1]
    if mcts_score > mcts_min_score:
        return 'MCTS', parts[5]
    if parts[7] != FSM_EMPTY:
        return 'FSM', parts[7]
    return 'VANILLA', parts[8]


def main(argv=None):
    """Run the selection on ``argv[1]`` and write results next to ``argv[2]``.

    Args:
        argv: argument vector; defaults to ``sys.argv``.

    Raises:
        SystemExit: if fewer than two path arguments are given.
        ValueError: if a non-blank data row is malformed; the message names
            the input file and line number.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 3:
        sys.exit('usage: reweight_ensemble.py INPUT_TSV OUTPUT_TXT')
    in_path, out_path = argv[1], argv[2]

    # Read all input before opening any output, so the output can never
    # truncate the input (e.g. when the same path is passed twice).
    with open(in_path) as fh:
        rows = fh.read().splitlines()[1:]  # drop the header row

    with open(verbose_output_path(out_path), 'w') as verbose_outf, \
            open(out_path, 'w') as sent_outf:
        # start=2: line 1 of the file is the header.
        for lineno, line in enumerate(rows, start=2):
            if not line.strip():
                continue  # tolerate blank lines (e.g. trailing ones)
            try:
                source, sentence = choose_sentence(line.split('\t'))
            except ValueError as err:
                raise ValueError('%s:%d: %s' % (in_path, lineno, err))
            verbose_outf.write(source + ': ' + sentence + '\n')
            sent_outf.write(sentence + '\n')


if __name__ == '__main__':
    main()