-
Notifications
You must be signed in to change notification settings - Fork 3
/
predict-linear.py
executable file
·62 lines (49 loc) · 1.77 KB
/
predict-linear.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#!/usr/bin/env python3
from sklearn.svm import LinearSVC
from sklearn.multiclass import OneVsOneClassifier, OneVsRestClassifier
from emoji_data import load
from features import doc_to_ngrams, preprocess
from argparse import ArgumentParser
from cmdline import add_args
# Build the command-line parser from the shared option groups and parse argv.
parser = ArgumentParser()
add_args(parser, ('general', 'preproc', 'linear', 'test'))
opt = parser.parse_args()

# Collapse the parsed class_weight option to the two values used further
# down: "balanced" when it is truthy, None otherwise.
opt.class_weight = "balanced" if opt.class_weight else None
# Configure logging at the level requested on the command line; every record
# is prefixed with a timestamp.
from logging import debug, info, basicConfig

basicConfig(format='%(asctime)s %(message)s', level=opt.log_level)
# Load the training and test data sets from their respective prefixes.
data_trn = load(opt.input_prefix)
data_tst = load(opt.test_prefix)

# N-gram range and casing options that must be identical for the training
# and the test documents, so they are gathered once and passed to both calls.
ngram_opts = dict(
    c_ngmin=opt.c_ngmin,
    c_ngmax=opt.c_ngmax,
    w_ngmin=opt.w_ngmin,
    w_ngmax=opt.w_ngmax,
    lowercase=opt.lowercase,
)

# doc_to_ngrams returns the training features, an object exposing
# .transform() (presumably the fitted vectorizer -- confirm in features.py),
# and a third value that this script does not use.
docs_trn, transformer, _ = doc_to_ngrams(
    data_trn.docs,
    min_df=opt.min_df,
    cache_dir=opt.cache_dir,
    dim_reduce=opt.dim_reduce,
    **ngram_opts
)

# Test documents are preprocessed with the same options and then mapped
# through the object returned for the training data.
docs_tst = transformer.transform(preprocess(data_tst.docs, **ngram_opts))
# Pick the base (binary) linear classifier: logistic regression when
# --classifier is 'lr', a linear SVM for any other value.
if opt.classifier == 'lr':
    from sklearn.linear_model import LogisticRegression
    # The dual formulation is only implemented by the liblinear solver.
    # scikit-learn >= 0.22 defaults to lbfgs, which rejects dual=True with a
    # ValueError, so the solver is pinned explicitly. On older releases
    # liblinear was already the default, so behaviour there is unchanged.
    m = LogisticRegression(solver='liblinear', dual=True, C=opt.C,
                           verbose=0, class_weight=opt.class_weight)
else:
    from sklearn.svm import LinearSVC
    m = LinearSVC(dual=True, C=opt.C, verbose=0,
                  class_weight=opt.class_weight)
# Wrap the base classifier in a multiclass strategy: 'ovo' selects
# one-vs-one, any other value falls back to one-vs-rest.
wrapper_cls = (OneVsOneClassifier if opt.mult_class == 'ovo'
               else OneVsRestClassifier)
m = wrapper_cls(m, n_jobs=opt.n_jobs)
# Train on the training features and their labels.
m.fit(docs_trn, data_trn.labels)

# Write one predicted label per line to stdout, in test-document order.
for label in m.predict(docs_tst):
    print(label)