-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
70 lines (48 loc) · 1.78 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from sklearn.ensemble import RandomForestClassifier
from sklearn.externals import joblib
import xgboost as xgb
import pandas as pd
import time
import sys
from features import extractFeatures, FEATURES
# File name that trainRandomForrest passes to joblib.dump for the fitted classifier.
RANDOM_FORREST_CLASSIFIER = 'randomforrest.joblib'
# File name that trainXGBoost passes to model.save_model for the trained model.
XGBOOST_MODEL = 'xgboost.model'
def trainRandomForrest(train_file, feat_selection=None):
    """Train a random forest on a feature CSV and persist it with joblib.

    The CSV is read with pandas. Its ``truthClass`` column supplies the
    labels ("no-clickbait" is replaced by 0 and "clickbait" by 1); the
    ``id`` and ``truthClass`` columns are removed from the feature matrix.

    Args:
        train_file: Anything ``pandas.read_csv`` accepts (path or buffer).
        feat_selection: Optional column selection applied to the feature
            frame via ``train[feat_selection]``. ``None`` or an empty
            selection keeps every remaining column.

    Returns:
        The fitted ``RandomForestClassifier`` (default hyper-parameters).
        As a side effect the classifier is dumped to
        ``RANDOM_FORREST_CLASSIFIER`` and progress/timing is printed.
    """
    print("Starting Training RandomForrest")
    start_time = time.time()

    train = pd.read_csv(train_file)
    train.replace({"truthClass": {"no-clickbait": 0, "clickbait": 1}}, inplace=True)
    truth = train['truthClass']
    # `columns=` already names the axis; an extra `axis=1` is ignored by pandas.
    train.drop(columns=['id', 'truthClass'], inplace=True)

    # Compare by value: `len(x) is not 0` tests object identity and only
    # works by accident of CPython's small-int cache.
    if feat_selection is not None and len(feat_selection) != 0:
        train = train[feat_selection]

    clf = RandomForestClassifier()
    clf.fit(train, truth)
    joblib.dump(clf, RANDOM_FORREST_CLASSIFIER)

    print("Random Forrest training {}".format(time.time() - start_time))
    return clf
def trainXGBoost(train_file, feat_selection=None):
    """Train an XGBoost model on a feature CSV and save it to disk.

    The CSV is read with pandas. Its ``truthClass`` column supplies the
    labels ("no-clickbait" is replaced by 0 and "clickbait" by 1); the
    ``id`` and ``truthClass`` columns are removed from the feature matrix.

    Args:
        train_file: Anything ``pandas.read_csv`` accepts (path or buffer).
        feat_selection: Optional column selection applied to the feature
            frame via ``train[feat_selection]``. ``None`` or an empty
            selection keeps every remaining column.

    Returns:
        The model returned by ``xgb.train`` using the ``binary:hinge``
        objective and otherwise default settings. As a side effect the
        model is saved to ``XGBOOST_MODEL`` and progress/timing is printed.
    """
    print("Starting Training XGBoost")
    start_time = time.time()

    train = pd.read_csv(train_file)
    train.replace({"truthClass": {"no-clickbait": 0, "clickbait": 1}}, inplace=True)
    truth = train['truthClass']
    # `columns=` already names the axis; an extra `axis=1` is ignored by pandas.
    train.drop(columns=['id', 'truthClass'], inplace=True)

    params = {
        'objective': 'binary:hinge',
    }

    # Compare by value: `len(x) is not 0` tests object identity and only
    # works by accident of CPython's small-int cache.
    if feat_selection is not None and len(feat_selection) != 0:
        train = train[feat_selection]

    dtrain = xgb.DMatrix(train, label=truth)
    model = xgb.train(params, dtrain)
    model.save_model(XGBOOST_MODEL)

    print("XGBoost training took {}".format(time.time() - start_time))
    return model
def trainClassifiers():
    """Run feature extraction on the command-line input, then train both models.

    The first command-line argument is handed to ``extractFeatures``.
    Afterwards ``trainXGBoost`` and ``trainRandomForrest`` are each called
    with ``FEATURES`` as their training file.
    NOTE(review): presumably ``extractFeatures`` writes its output to the
    location named by ``FEATURES`` -- confirm in the ``features`` module.

    Raises:
        SystemExit: If no command-line argument was supplied.
    """
    argv = sys.argv[1:]
    if not argv:
        # Fail with a readable message instead of a bare IndexError on argv[0].
        sys.exit("usage: train.py <input>")

    extractFeatures(argv[0])
    trainXGBoost(FEATURES)
    print('\n\n')
    trainRandomForrest(FEATURES)


if __name__ == "__main__":
    trainClassifiers()