forked from cryingjin/AMIOK
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
51 lines (43 loc) · 1.37 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import prep.preprocess as prep
import prep.data as dt
import prep.tokenizer as tk
import trpkg.textrank as tr
import sim.similarity as sim
import post.postprocess as post
import argparse
import model.inference as seq
from model.total_inference import *
# Inputs whose len() after TextRank extraction is at most this value are sent
# to the seq2seq model; longer inputs go to mT5. (Threshold kept from the
# original script.)
SEQ2SEQ_MAX_LEN = 200


def parse_input_sentence(argv=None):
    """Parse the command line and return the input sentence.

    ``-s`` accepts one or more tokens (``nargs='+'``). They are joined with
    single spaces, so an unquoted multi-word sentence is no longer truncated
    to its first word. A single quoted argument is returned unchanged.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The input sentence as one string.
    """
    parser = argparse.ArgumentParser(
        description='Answer an input sentence: preprocess, TextRank '
                    'extraction, tokenization, then a similarity lookup '
                    'with a seq2seq / mT5 fallback.')
    parser.add_argument('-s', help="input sentence (unquoted words are joined "
                                   "with spaces)",
                        nargs='+', required=True)
    args = parser.parse_args(argv)
    return ' '.join(args.s)


def main(argv=None):
    """Run the answer pipeline on the ``-s`` sentence and print every stage.

    Stages: preprocess -> TextRank sentence extraction -> MeCab tokenization
    -> similarity lookup. If the similarity module returns "No Result", fall
    back to a deep-learning model: seq2seq when the TextRank output has
    ``len() <= SEQ2SEQ_MAX_LEN``, mT5 otherwise. Finally, postprocess and
    print the answer together with its source type ('sim' or 'dl').

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    sentence = parse_input_sentence(argv)
    print("\n[Input sentence]", sentence)

    # preprocess
    sentence = prep.preprocess(sentence)
    print("\n[modified sentence]", sentence)

    # textrank
    sentence = tr.sentence_extraction(sentence)
    print("\n[textrank result]", sentence)

    # tokenizing: the tokenizer returns two outputs; only the second is used
    _npo, ypo = tk.mecab_tokenizer(sentence)

    # Which model to use? Try the similarity lookup first.
    result_type = 'sim'
    answer = sim.output(ypo)
    if answer == "No Result":
        # Constructed only on this path, so --help, argument errors and
        # similarity hits skip it (ModelInference() presumably loads the
        # DL models -- unverified from here).
        inferencer = ModelInference()
        if len(sentence) <= SEQ2SEQ_MAX_LEN:
            answer = inferencer.inference_seq2seq(ypo)
        else:
            answer = inferencer.inference_mt5(sentence)
        result_type = 'dl'

    # postprocess
    answer = post.postprocess(answer, result_type)
    print()
    print("=====" * 20)
    print("\n[", result_type, "answer ]", answer)
    print()
    print("=====" * 20)
    print()


if __name__ == '__main__':
    main()