-
Notifications
You must be signed in to change notification settings - Fork 7
/
train_model.py
111 lines (92 loc) · 2.64 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
Copyright (c) 2021, FireEye, Inc.
Copyright (c) 2021 Giorgio Severi
This script allows the user to train the models used in the paper.
To train the LightGBM and EmberNN models on EMBER:
`python train_model.py -m lightgbm -d ember`
`python train_model.py -m embernn -d ember`
To train the Random Forest model on Contagio PDFs:
`python train_model.py -m pdfrf -d ogcontagio`
To train the Linear SVM classifier on Drebin:
`python train_model.py -m linearsvm -d drebin`
"""
import random
import argparse
import numpy as np
import tensorflow as tf
from mw_backdoor import constants
from mw_backdoor import data_utils
from mw_backdoor import model_utils
def train(args):
    """Train one model on one dataset, save it to disk, then evaluate it.

    Args:
        args: mapping of options, as built by ``vars(parser.parse_args())``
            in this script's ``__main__`` block. Keys read, in order:
            'model', 'dataset', 'seed', 'save_dir', 'save_file'.
            - An empty 'save_dir' falls back to ``constants.SAVE_MODEL_DIR``.
            - An empty 'save_file' falls back to '<dataset>_<model>'.
    """
    # Pull every option out up front (same key order as the CLI definition).
    model_id, dataset, seed, save_dir, save_file = (
        args[key]
        for key in ('model', 'dataset', 'seed', 'save_dir', 'save_file')
    )

    # Empty strings from the CLI mean "use the default location / name".
    save_dir = save_dir or constants.SAVE_MODEL_DIR
    save_file = save_file or dataset + '_' + model_id

    # Seed the Python, NumPy and TensorFlow RNGs so runs are repeatable.
    for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
        seed_fn(seed)

    x_train, y_train, x_test, y_test = data_utils.load_dataset(dataset=dataset)
    shape_report = (
        'Dataset shapes:\n'
        '\tTrain x: {}\n'
        '\tTrain y: {}\n'
        '\tTest x: {}\n'
        '\tTest y: {}\n'
    ).format(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
    print(shape_report)

    # Fit on the training split only; the test split is held out for evaluation.
    model = model_utils.train_model(
        model_id=model_id, x_train=x_train, y_train=y_train
    )
    model_utils.save_model(
        model_id=model_id, model=model, save_path=save_dir, file_name=save_file
    )

    print('Evaluation of model: {} on dataset: {}'.format(model_id, dataset))
    model_utils.evaluate_model(model, x_test, y_test)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model",
default="lightgbm",
choices=["lightgbm", "embernn", "pdfrf", "linearsvm"],
help="model type"
)
parser.add_argument(
"-d",
"--dataset",
default="ember",
choices=["ember", "pdf", "ogcontagio", "drebin"],
help="model type"
)
parser.add_argument(
"--save_file",
default='',
type=str,
help="file name of the saved model"
)
parser.add_argument(
"--save_dir",
default='',
type=str,
help="directory containing saved models"
)
parser.add_argument("--seed", type=int, default=42, help="random seed")
arguments = vars(parser.parse_args())
train(arguments)