-
Notifications
You must be signed in to change notification settings - Fork 1
/
predict.py
49 lines (42 loc) · 1.51 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import pandas as pd
import numpy as np
from absl import app
from util import han2Jamo
from flags import create_flags, FLAGS, CONST
from word_embedding import WordEmbedding
from classifier import Classifier
from document import Document
def main(_):
    """Classify every text in ``FLAGS.predict`` and collect the predictions.

    Loads the word-embedding model selected by ``FLAGS.we_model`` and the
    classifier named by ``FLAGS.cf_model``, then preprocesses, embeds and
    classifies each input text in order.

    Args:
        _: Positional argument supplied by the caller (``app.run``); unused.

    Returns:
        A list with one dict per input text, holding the raw input under
        ``'text'`` and the classifier output under ``'predict'``. Returns
        ``None`` as soon as one input yields no embedding vectors.

    Raises:
        ValueError: If ``FLAGS.we_model`` is neither ``'devblog'`` nor
            ``'wiki'``.
    """
    # init
    we = WordEmbedding()
    dc = Document()
    cf = Classifier()

    # load word embedding model
    if FLAGS.we_model == 'devblog':
        we_model = we.loadDevblogModel(embedding_dim=FLAGS.we_dim,
                                       epochs=FLAGS.we_epoch,
                                       window=FLAGS.we_window,
                                       min_count=FLAGS.we_min_count)
    elif FLAGS.we_model == 'wiki':
        we_model = we.loadWikiModel()
    else:
        # Fail fast: without this branch we_model stayed unbound and the
        # error only surfaced later, inside the per-text embedding lambda.
        raise ValueError(
            f"unsupported we_model {FLAGS.we_model!r}; "
            "expected 'devblog' or 'wiki'"
        )

    # load classifier model
    cf_model = cf.loadModel(FLAGS.cf_model)

    is_devblog = FLAGS.we_model == 'devblog'
    results = [{'text': r} for r in FLAGS.predict]

    for result in results:
        raw = result['text']

        # preprocessing: only the devblog path converts the text with han2Jamo
        text = han2Jamo(raw) if is_devblog else raw

        # word embedding: one vector per row of the preprocessed frame
        df = dc.preprocessing(text, devblog=is_devblog)
        vectors = df.text.apply(
            lambda x: we.embedding(we_model, x, FLAGS.we_dim)
        ).tolist()

        if not vectors:
            print('🐈 text is not valid :', raw)
            # NOTE(review): this aborts the whole run and drops predictions
            # already computed for earlier inputs — confirm that is intended.
            return None

        # predict
        result['predict'] = cf.predict(cf_model, np.array(vectors),
                                       FLAGS.criterion)

    # NOTE(review): when launched through absl's app.run, main's return value
    # is expected to be handed to sys.exit — confirm that returning the list
    # is the intended way to surface the results.
    return results
if __name__ == '__main__':
    # Script entry point: run only when executed directly, not on import.
    # create_flags(True) is called before handing control to absl; presumably
    # it defines the flags that main reads through FLAGS — confirm in flags.py.
    create_flags(True)
    # Hand main to absl's application runner.
    app.run(main)