-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
41 lines (34 loc) · 1.13 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import pandas as pd
import numpy as np
from absl import app
from util import han2Jamo
from flags import create_flags, FLAGS, CONST
from word_embedding import WordEmbedding
from classifier import Classifier
from document import Document
def main(_):
    """Train the document classifier on labeled documents and save the model.

    Steps: load the labeled documents, load the word-embedding model selected
    by ``FLAGS.we_model``, embed every document's text, train the classifier
    (using ``./checkpoint`` as the second argument to ``Classifier.train``)
    and save the trained model to ``FLAGS.cf_model``.

    Args:
        _: Unused. absl's ``app.run`` passes the remaining command-line
            arguments here.

    Raises:
        ValueError: If ``FLAGS.we_model`` is neither ``'devblog'`` nor
            ``'wiki'``.
    """
    # init
    we = WordEmbedding()
    dc = Document()
    cf = Classifier()

    # load data
    docs = dc.getDocs(labeled_only=True)

    # load word embedding model
    if FLAGS.we_model == 'devblog':
        we_model = we.loadDevblogModel(embedding_dim=FLAGS.we_dim,
                                       epochs=FLAGS.we_epoch,
                                       window=FLAGS.we_window,
                                       min_count=FLAGS.we_min_count)
        # han2jamo: only the devblog branch converts the text with han2Jamo
        # before embedding.
        docs['text'] = docs['text'].apply(han2Jamo)
    elif FLAGS.we_model == 'wiki':
        we_model = we.loadWikiModel()
    else:
        # Without this branch `we_model` would be unbound below and the run
        # would die with an opaque UnboundLocalError.
        raise ValueError(
            "Unsupported we_model {!r}; expected 'devblog' or 'wiki'".format(
                FLAGS.we_model))

    # word embedding
    # NOTE(review): assumes `docs` is a pandas DataFrame -- confirm against
    # Document.getDocs. Item assignment is used because attribute assignment
    # (`docs.vector = ...`) does not create a new DataFrame column.
    docs['vector'] = docs['text'].apply(lambda x: we.embedding(we_model, x))

    # training
    cf_model = cf.train(docs, './checkpoint')
    cf.saveModel(cf_model, FLAGS.cf_model)
if __name__ == '__main__':
    # Script entry point: call the project's create_flags() first
    # (presumably it defines the FLAGS values main() reads -- confirm in
    # flags.py), then let absl parse the command line and invoke main().
    create_flags()
    app.run(main)