train_helper.py
import os

from tensorflow.keras.optimizers import SGD, Adam, Adagrad, Adadelta, RMSprop
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, TensorBoard, ModelCheckpoint


class TrainHelper:
    """Static helpers for building optimizers and Keras training callbacks."""

    @staticmethod
    def get_optimizer(optimizer):
        """Return a configured Keras optimizer for the given name."""
        if optimizer == "sgd":
            # `decay` is only accepted by the legacy Keras optimizers (tf.keras in TF <= 2.10).
            return SGD(learning_rate=0.01, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)
        if optimizer == "rmsprop":
            return RMSprop(learning_rate=0.01)
        if optimizer == "adam":
            return Adam(learning_rate=0.01)
        if optimizer == "adagrad":
            return Adagrad(learning_rate=0.01)
        if optimizer == "adadelta":
            return Adadelta(learning_rate=1.0)
        raise ValueError(f"Unknown optimizer: {optimizer!r}")
    @staticmethod
    def get_callbacks(output_dir, model_name, optimizer, model_weights_path):
        """Build the callback list: early stopping, checkpointing, TensorBoard, and LR scheduling."""
        logdir = os.path.join(output_dir, optimizer, 'logs')
        # Per-epoch checkpoint filename template (not passed to ModelCheckpoint below,
        # which writes only the best weights to `model_weights_path`).
        chkpt_filepath = model_name + '--{epoch:02d}--{loss:.3f}--{val_loss:.3f}.h5'
        callbacks = [
            EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=4, verbose=1),
            ModelCheckpoint(filepath=model_weights_path, monitor='val_loss', save_best_only=True,
                            save_weights_only=True, verbose=1),
            TensorBoard(log_dir=logdir),
        ]
        # SGD and RMSprop are configured with fixed learning rates, so add a plateau-based LR schedule.
        if optimizer in ["sgd", "rmsprop"]:
            callbacks.append(
                ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, verbose=1, mode='min',
                                  min_delta=0.01, cooldown=0, min_lr=0))
        return callbacks
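

# Usage sketch (not part of the original module): shows one way the helpers might
# be wired into a tf.keras training run. The toy model, random data, optimizer
# choice, and paths below are illustrative assumptions; the sketch targets the
# TF 2.x tf.keras API that the module itself uses.
if __name__ == "__main__":
    import numpy as np
    from tensorflow.keras import layers, models

    optimizer_name = "adam"
    output_dir = "output"
    weights_path = os.path.join(output_dir, optimizer_name, "best_weights.h5")
    os.makedirs(os.path.dirname(weights_path), exist_ok=True)

    # Tiny binary classifier on random data, purely for demonstration.
    model = models.Sequential([
        layers.Input(shape=(8,)),
        layers.Dense(16, activation="relu"),
        layers.Dense(1, activation="sigmoid"),
    ])
    x = np.random.rand(128, 8)
    y = np.random.randint(0, 2, size=(128, 1))

    model.compile(optimizer=TrainHelper.get_optimizer(optimizer_name),
                  loss="binary_crossentropy", metrics=["accuracy"])
    model.fit(x, y, validation_split=0.2, epochs=5,
              callbacks=TrainHelper.get_callbacks(output_dir, "toy_model",
                                                  optimizer_name, weights_path))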