Skip to content

Commit

Permalink
Add Viterbi decoding to prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
ThanhChinhBK committed Sep 26, 2017
1 parent 97561d3 commit b6e7482
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
10 changes: 7 additions & 3 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _add_model(self):
self.prediction = tf.reshape(predict,
[-1, self.config["sentence_length"], self.config["num_class"]])
#self.loss = self._cost()
log_likehood, _ = tf.contrib.crf.crf_log_likelihood(
log_likehood, self.transition_params = tf.contrib.crf.crf_log_likelihood(
self.prediction, self.labels, self.sent_len)
self.loss = tf.reduce_mean(-log_likehood, name="loss")
optimizer = tf.train.AdamOptimizer(0.003)
Expand Down Expand Up @@ -185,9 +185,13 @@ def calc_total_cost(self, sentence, word_list,sent_len, labels ):
}
return self.sess.run(self.loss, feed_dict = feed_dict)

def transform(self, sentence, word_list, sent_len):
    """Predict a label sequence for each sentence in the batch via CRF Viterbi decoding.

    Args:
        sentence: batch of word-index sequences fed to the `self.sentence` placeholder.
        word_list: batch of character-index sequences fed to `self.word_list`.
        sent_len: iterable of true (unpadded) sentence lengths, fed to `self.sent_len`.

    Returns:
        A list with one decoded label-id sequence per input sentence; each
        sequence has length equal to that sentence's entry in `sent_len`.
    """
    feed_dict = {self.sentence: sentence,
                 self.word_list: word_list,
                 self.sent_len: sent_len
                 }
    # Fetch unary scores and the CRF transition matrix learned during training
    # (self.transition_params is bound by crf_log_likelihood in _add_model).
    logits, transition_params = self.sess.run(
        [self.prediction, self.transition_params], feed_dict=feed_dict)
    # viterbi_decode operates on a single sentence's [seq_len, num_tags]
    # score matrix, so decode each sentence separately and trim padding
    # to its true length; decoding the whole batched 3-D tensor at once
    # would treat the batch axis as time and produce garbage.
    viterbi_sequences = []
    for logit, length in zip(logits, sent_len):
        viterbi_sequence, _ = tf.contrib.crf.viterbi_decode(
            logit[:length], transition_params)
        viterbi_sequences.append(viterbi_sequence)
    return viterbi_sequences
5 changes: 5 additions & 0 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ def _f1(config, predicts, labels, sent_length, f1_type="micro"):
ner = RNN_CNNs(config, embedd_table, char_embedd_table)
logger.info("Model Created")
f1_s = open("f1.txt", "w")
dev_prediction = ner.transform(word_index_sentences_dev_pad, char_index_dev_pad, dev_seq_length)
f1 = _f1(config, dev_prediction, label_index_sentences_dev_pad, dev_sent_len, "micro")
print("\nEvaluate:\n")
print("f1 score after {} epoch:{}\n".format(e, f1))
f1_s.write(str(f1) + "\n")
for e in range(FLAGS.epochs):
for step, (token_ids_batch, sent_len_batch,\
char_ids_batch, target_batch) in enumerate(
Expand Down

0 comments on commit b6e7482

Please sign in to comment.