-
Notifications
You must be signed in to change notification settings - Fork 18
/
test_crcnn.py
107 lines (84 loc) · 3.4 KB
/
test_crcnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
'''
Created on 1 March 2018
@author: Bhanu
'''
import tensorflow as tf
import os
import numpy as np
import yaml
from train_crcnn import load_sents_data_semeval2010, build_model,\
build_data_streams, Vocab
import pickle
from sklearn.metrics.classification import f1_score, classification_report
import argparse
import sys
FLAGS = None
def main(_):
    """Restore a trained CRCNN checkpoint and evaluate it on a SemEval-2010 test file.

    Reads the checkpoint prefix and YAML config path from the module-level
    ``FLAGS``, rebuilds the model graph, feeds the entire test set as one
    batch, and — when gold labels are present — prints micro/macro F1 and a
    per-class classification report.

    Side effects: reads config/data/model files from disk, prints to stdout.
    """
    # NOTE(review): the module-level import `sklearn.metrics.classification`
    # is a deprecated private path removed in scikit-learn 0.24; it should be
    # `from sklearn.metrics import f1_score, classification_report`.
    test_model = FLAGS.model_name
    config_file = FLAGS.config_file

    with open(config_file, 'r') as rf:
        # BUG FIX: yaml.load() without an explicit Loader is deprecated and
        # can execute arbitrary Python on crafted input; safe_load covers
        # plain scalar/mapping config files like this one.
        params = yaml.safe_load(rf)

    seed = params.get('seed')
    tf.set_random_seed(seed)

    test_data_filename = params.get('test_file')
    data_dir = params.get('data_dir')
    model_dir = params.get('model_dir')

    print("loading data...", flush=True)
    test_data_file = os.path.join(data_dir, test_data_filename)
    dftest = load_sents_data_semeval2010(test_data_file, testset=True)

    # Label encoder fitted during training; maps class names <-> int ids.
    with open(os.path.join(model_dir, params.get('label_encoder_file')), 'rb') as rf:
        le = pickle.load(rf)
    print(le.classes_)

    print("loading vocab...", flush=True)
    with open(os.path.join(model_dir, params.get('vocab_file')), 'rb') as rf:
        vocab = pickle.load(rf)

    # True iff the test dataframe carries any gold labels (assumes the
    # class_ column holds falsy placeholders when unlabeled — TODO confirm).
    is_test_labels = bool(dftest.class_.any())

    # build input data streams
    teststream = build_data_streams(dftest, vocab.dict,
                                    params.get('sent_length'), le)

    labels = teststream.label
    if labels is None:
        # The graph still needs a label tensor to run inference; feed dummies.
        labels = np.zeros(teststream.sent.shape[0])

    # build model
    mdl = build_model(params)
    test_feed_dict = {
        mdl.sent: teststream.sent,
        mdl.label: labels,
        mdl.ent1_dist: teststream.ent1_dist,
        mdl.ent2_dist: teststream.ent2_dist,
        mdl.dropout_keep_proba: 1.0,  # no dropout at test time
        mdl.batch_size: teststream.sent.shape[0]  # whole test set in one batch
    }

    # run the graph
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
    with tf.Session() as sess:
        sess.run(init_op)
        # restore() overwrites the freshly initialized variables with the
        # trained checkpoint values.
        saver.restore(sess, os.path.join(model_dir, test_model))
        print("Restored session from %s"%test_model)
        pred_probas, preds = sess.run([mdl.pred_probas, mdl.preds], test_feed_dict)

    # print scores, if test_labels known
    # BUG FIX: the original guard was `if is_test_labels is not None:`, which
    # is always True for a bool, so evaluation ran (and crashed on a None
    # label array) even for unlabeled test sets. Test truthiness instead.
    if is_test_labels:
        gold = teststream.label
        pred = preds
        class_int_labels = list(range(len(le.classes_)))
        target_names = le.classes_
        eval_score = (f1_score(gold, pred, average='micro'),
                      f1_score(gold, pred, average='macro')
                      )
        print("EVAL f1_micro {:g} f1_macro {:g}"
              .format(eval_score[0], eval_score[1]), flush=True)
        print("Classification Report: \n%s"%
              classification_report(gold, pred,
                                    labels=class_int_labels,
                                    target_names=target_names,
                                    ), flush=True)
if __name__ == '__main__':
    # CLI entry point: collect the checkpoint prefix and config path into the
    # module-level FLAGS, then hand control to TF's app runner with any
    # unrecognized arguments passed through untouched.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--model_name', type=str, default=None,
        help='Checkpoint Prefix of the model to be tested')
    arg_parser.add_argument(
        '--config_file', type=str, default=None,
        help='Full path of the config file')
    FLAGS, leftover_args = arg_parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + leftover_args)