-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
executable file
·85 lines (75 loc) · 2.39 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#!/usr/bin/env python3
import argparse
import sys
from keras.models import Sequential
from keras.layers import Dense, Flatten
from ncaa_predict.data_loader import load_data_multiyear, N_PLAYERS, N_FEATURES
from ncaa_predict.util import list_arg
# Default value of the --batch_size command-line option.
DEFAULT_BATCH_SIZE = 10000
# Default value of the --steps command-line option. sys.maxsize makes the
# limit effectively unbounded, so by default training runs until it is
# interrupted (the --steps help text points users at ctrl+c).
DEFAULT_STEPS = sys.maxsize
def _parse_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Argument strings to parse. ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``batch_size``,
        ``model_out``, ``steps`` and ``train_years``.

    Exits with a usage error (via ``parser.error``) when the batch size is
    not a positive integer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        # "--batch_size" stays first so the dest remains ``batch_size`` and
        # existing invocations keep working; "--batch-size" is an alias that
        # matches the hyphenated spelling of the other long options.
        "--batch_size",
        "--batch-size",
        "-b",
        default=DEFAULT_BATCH_SIZE,
        type=int,
        help="The training batch size. Smaller numbers will train faster but "
        "may not converge. (default: %(default)s)",
    )
    parser.add_argument(
        "--model-out",
        "-o",
        default=None,
        help="File to save the model to. This will be an entire Keras model, "
        "which can be loaded and used without needing to keep track of the "
        "architecture. Warning: Keras will overwrite existing models. "
        "(default: don't save)",
    )
    parser.add_argument(
        "--steps",
        "-s",
        default=DEFAULT_STEPS,
        type=int,
        help="The maximum number of training steps. Note that you can stop "
        "training at any time and save the output with ctrl+c. (default: "
        "%(default)s)",
    )
    parser.add_argument(
        "--train-years",
        "-y",
        # NOTE(review): the default is a list, while values given on the
        # command line are presumably returned in the ``container`` type
        # (frozenset). Confirm load_data_multiyear accepts both.
        default=list(range(2002, 2017)),
        type=list_arg(type=int, container=frozenset),
        help="A comma-separated list of years to train on.",
    )
    args = parser.parse_args(argv)

    # The batch size is used as a divisor when deriving the epoch count, so
    # reject zero/negative values here instead of failing later with a bare
    # ZeroDivisionError.
    if args.batch_size <= 0:
        parser.error("--batch_size must be a positive integer")
    return args


def _build_model():
    """Create and compile the two-class softmax classifier."""
    model = Sequential(
        [
            Flatten(),
            Dense(16, activation="relu", kernel_regularizer="L1L2"),
            Dense(2, activation="softmax"),
        ]
    )
    model.compile(
        loss="categorical_crossentropy",
        optimizer="rmsprop",
        metrics=["accuracy", "AUC", "Precision", "Recall"],
    )
    return model


def main(argv=None):
    """Train the model on the requested years and optionally save it.

    Training can be interrupted with ctrl+c; the model is still saved
    afterwards when ``--model-out`` was given.

    Args:
        argv: Argument strings to parse. ``None`` lets argparse read
            ``sys.argv[1:]``.
    """
    args = _parse_args(argv)
    model = _build_model()
    features, labels = load_data_multiyear(args.train_years)

    # NOTE(review): the step budget is converted to an epoch count by integer
    # division with the batch size; confirm this is the intended meaning of
    # "steps".
    epochs = args.steps // args.batch_size
    if epochs <= 0:
        # Previously this case silently trained nothing; make it visible.
        print(
            "Warning: --steps (%d) is smaller than --batch_size (%d), so no "
            "training epochs will run" % (args.steps, args.batch_size),
            file=sys.stderr,
        )

    try:
        model.fit(
            x=features,
            y=labels,
            batch_size=args.batch_size,
            epochs=epochs,
            shuffle=True,
            validation_split=0.1,
        )
    except KeyboardInterrupt:
        # Deliberate best-effort: fall through so the partially trained
        # model can still be saved below.
        print("Stopped training due to keyboard interrupt")

    if args.model_out is not None:
        model.save(args.model_out)

    # Workaround for TensorFlow bug:
    # https://github.com/tensorflow/tensorflow/issues/3388
    import gc

    gc.collect()


if __name__ == "__main__":
    main()