-
Notifications
You must be signed in to change notification settings - Fork 98
/
training.py
114 lines (89 loc) · 4.43 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python
__author__ = "solivr"
__license__ = "GPL"
# Set the TensorFlow logger threshold before TF is imported/used below,
# so sub-INFO records are suppressed from the start.
import logging
logging.getLogger("tensorflow").setLevel(logging.INFO)
# Project-local modules: hyper-parameter container, model builder,
# CSV/data preprocessing, tf.data pipeline, and custom Keras callbacks.
from tf_crnn.config import Params
from tf_crnn.model import get_model_train
from tf_crnn.preprocessing import data_preprocessing
from tf_crnn.data_handler import dataset_generator
from tf_crnn.callbacks import CustomLoaderCallback, CustomSavingCallback, LRTensorBoard, EPOCH_FILENAME, FOLDER_SAVED_MODEL
import tensorflow as tf
import numpy as np
import os
import json
import pickle
from glob import glob
from sacred import Experiment, SETTINGS
# Sacred locks injected configs by default; unlock so the experiment
# function may treat its `_config` as a plain mutable dict.
SETTINGS.CONFIG.READ_ONLY_CONFIG = False
# Sacred experiment: hyper-parameters are read from config.json and
# injected into the @ex.automain-decorated function below.
ex = Experiment('crnn')
ex.add_config('config.json')
@ex.automain
def training(_config: dict):
    """Train a CRNN model with the configuration injected by sacred.

    Prepares the output directory, preprocesses the CSV data, builds the
    Keras callbacks (TensorBoard + LR logging, ReduceLROnPlateau, early
    stopping, custom checkpoint saving), optionally restores the most
    recent saved model to resume training, and runs ``model.fit`` on
    generator-backed train/eval datasets.

    :param _config: sacred-injected configuration dict (from config.json),
        used to construct a ``Params`` object.
    :raises FileExistsError: output dir already exists and ``restore_model``
        is False.
    :raises FileNotFoundError: ``restore_model`` is True but no saved model
        exists under the saving directory.
    """
    parameters = Params(**_config)

    export_config_filename = os.path.join(parameters.output_model_dir, 'config.json')
    saving_dir = os.path.join(parameters.output_model_dir, FOLDER_SAVED_MODEL)

    if not parameters.restore_model:
        # Refuse to clobber a previous run. An explicit raise (rather than
        # `assert`, which is stripped under `python -O`) always fires and
        # gives callers a precise, catchable exception type.
        if os.path.isdir(parameters.output_model_dir):
            raise FileExistsError(
                '{0} already exists, you cannot use it as output directory. '
                'Set "restore_model=True" to continue training, '
                'or delete dir "rm -r {0}"'.format(parameters.output_model_dir))
        os.makedirs(parameters.output_model_dir)

    # Data and CSV preprocessing: produces the train/eval csv files plus
    # their sample counts (needed below to derive step counts).
    csv_train_file, csv_eval_file, \
        n_samples_train, n_samples_eval = data_preprocessing(parameters)

    # Export the effective config next to the model for reproducibility.
    with open(export_config_filename, 'w') as file:
        json.dump(parameters.to_dict(), file)

    # --- Callbacks ---
    logdir = os.path.join(parameters.output_model_dir, 'logs')
    tb_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir,
                                                 profile_batch=0)
    # LRTensorBoard additionally logs the learning rate to TensorBoard.
    lrtb_callback = LRTensorBoard(log_dir=logdir,
                                  profile_batch=0)
    lr_callback = tf.keras.callbacks.ReduceLROnPlateau(factor=0.5,
                                                       patience=10,
                                                       cooldown=0,
                                                       min_lr=1e-8,
                                                       verbose=1)
    es_callback = tf.keras.callbacks.EarlyStopping(min_delta=0.005,
                                                   patience=20,
                                                   verbose=1)
    sv_callback = CustomSavingCallback(saving_dir,
                                       saving_freq=parameters.save_interval,
                                       save_best_only=True,
                                       keep_max_models=3)

    list_callbacks = [tb_callback, lrtb_callback, lr_callback, es_callback, sv_callback]

    if parameters.restore_model:
        # Saved models live in '<timestamp>...' sub-folders; resume from the
        # most recent one. Fail with a clear message instead of the opaque
        # ValueError `max([])` would raise on an empty directory.
        timestamps = [int(os.path.basename(p).split('-')[0])
                      for p in glob(os.path.join(saving_dir, '*'))]
        if not timestamps:
            raise FileNotFoundError(
                'restore_model=True but no saved model was found in {}'.format(saving_dir))
        loading_dir = os.path.join(saving_dir, str(max(timestamps)))
        ld_callback = CustomLoaderCallback(loading_dir)
        list_callbacks.append(ld_callback)

        # Resume the epoch counter where the previous run stopped, so this
        # run trains for `n_epochs` additional epochs.
        with open(os.path.join(loading_dir, EPOCH_FILENAME), 'rb') as f:
            initial_epoch = pickle.load(f)
        epochs = initial_epoch + parameters.n_epochs
    else:
        initial_epoch = 0
        epochs = parameters.n_epochs

    # Get model
    model = get_model_train(parameters)

    # Get datasets
    dataset_train = dataset_generator([csv_train_file],
                                      parameters,
                                      batch_size=parameters.train_batch_size,
                                      data_augmentation=parameters.data_augmentation,
                                      num_epochs=parameters.n_epochs)
    dataset_eval = dataset_generator([csv_eval_file],
                                     parameters,
                                     batch_size=parameters.eval_batch_size,
                                     data_augmentation=False,
                                     num_epochs=parameters.n_epochs)

    # Keras expects integer step counts; `np.floor` would hand it a float
    # (and 0 steps when a split holds fewer samples than one batch, hence
    # the floor of 1).
    steps_per_epoch = max(1, n_samples_train // parameters.train_batch_size)
    validation_steps = max(1, n_samples_eval // parameters.eval_batch_size)

    # Train model
    model.fit(dataset_train,
              epochs=epochs,
              initial_epoch=initial_epoch,
              steps_per_epoch=steps_per_epoch,
              validation_data=dataset_eval,
              validation_steps=validation_steps,
              callbacks=list_callbacks)