-
Notifications
You must be signed in to change notification settings - Fork 4
/
train.py
executable file
·129 lines (99 loc) · 4.27 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Copyright 2020 Tuan Chien, James Diprose
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Author: Tuan Chien, James Diprose
import datetime
import os
import pathlib
import secrets
from timeit import default_timer as timer
import click
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
from ava_asd.config import get_optimiser, get_model, get_loss_weights, read_config
from ava_asd.generator import AvGenerator, DatasetSubset
from ava_asd.telegrambot import UpdateBot
from ava_asd.utils import set_gpu_memory_growth
def get_callbacks(data_path, sess_id, config, bot_config_file):
    """
    Get a list of callbacks to use for training.

    Args:
        data_path: root data folder; checkpoints are written under
            ``<data_path>/experiments/<sess_id>/`` (created if missing).
        sess_id: session identifier, used in checkpoint file names and as the
            TensorBoard sub-directory name so each run gets its own folders.
        config: experiment config mapping. Always reads ``mode``,
            ``tb_logdir``, ``save_best_only`` and ``use_earlystopping``;
            ``es_patience`` is only read when early stopping is enabled.
        bot_config_file: open file for the Telegram bot config (only its
            ``.name`` is used), or None to skip the bot.

    Returns:
        list: a ModelCheckpoint and a TensorBoard callback, followed by an
        EarlyStopping callback and an UpdateBot callback when enabled.

    Raises:
        KeyError: if a required config key is missing.
    """
    # Get config values
    mode = config['mode']
    tb_logdir = config['tb_logdir']
    save_best_only = config['save_best_only']
    use_earlystopping = config['use_earlystopping']

    callbacks = []

    # Model checkpoint. The same metric name is used both in the file-name
    # placeholder and as the monitored quantity, so keep it in one place.
    checkpoint_metric = 'val_main_out_accuracy'
    # {epoch:02d} and {<metric>:.4f} stay as literal placeholders; Keras
    # fills them in each time it saves a checkpoint.
    model_file_pattern = (sess_id + '-' + mode + '-weights-{epoch:02d}-{'
                          + checkpoint_metric + ':.4f}.hdf5')
    experiment_path = os.path.join(data_path, 'experiments', sess_id)
    pathlib.Path(experiment_path).mkdir(parents=True, exist_ok=True)
    model_path = os.path.join(experiment_path, model_file_pattern)
    callbacks.append(ModelCheckpoint(model_path, monitor=checkpoint_metric, verbose=1,
                                     save_best_only=save_best_only, mode='max'))

    # Tensorboard: one sub-directory per session keeps runs separate.
    tb_session_dir = os.path.join(tb_logdir, sess_id)
    pathlib.Path(tb_session_dir).mkdir(parents=True, exist_ok=True)
    callbacks.append(TensorBoard(log_dir=tb_session_dir, update_freq='batch'))

    # Early stopping; the patience key is only required when it is enabled.
    if use_earlystopping:
        callbacks.append(EarlyStopping(monitor='val_main_out_loss',
                                       patience=config['es_patience']))

    # Telegram reporting bot
    if bot_config_file is not None:
        bot_config = read_config(bot_config_file.name)
        callbacks.append(UpdateBot.from_dict(bot_config, sess_id=sess_id))

    return callbacks
@click.command()
@click.argument('config-file', type=click.File('r'))
@click.argument('data-path', type=click.Path(exists=True, file_okay=False, dir_okay=True))
@click.option('--bot-config', type=click.File(), default=None)
def main(config_file, data_path, bot_config):
    """ Train the audio visual model.
    CONFIG_FILE: the config file with settings for the experiment.
    DATA_PATH: the path to the folder with the data files.
    """
    # Wall-clock start of the whole experiment (setup + training).
    t_begin = timer()

    # Kept first, ahead of model construction.
    # NOTE(review): presumably wraps TF GPU memory-growth config, which TF
    # requires before GPUs are initialised — confirm.
    set_gpu_memory_growth(True)

    settings = read_config(config_file.name)
    model, loss = get_model(settings)

    # Training and validation input pipelines.
    train_gen = AvGenerator.from_dict(data_path, DatasetSubset.train, settings)
    valid_gen = AvGenerator.from_dict(data_path, DatasetSubset.valid, settings)
    for gen in (train_gen, valid_gen):
        print(gen)

    # A short random id names this run's checkpoint and TensorBoard folders.
    session_id = secrets.token_urlsafe(5)
    cb_list = get_callbacks(data_path, session_id, settings, bot_config)
    # The generators are registered as callbacks too.
    # NOTE(review): presumably so they can react to epoch events — confirm in AvGenerator.
    cb_list.extend((train_gen, valid_gen))

    # Optimiser and loss weights are built (in that order) while compiling.
    model.compile(
        loss=loss,
        optimizer=get_optimiser(settings),
        metrics=['accuracy'],
        loss_weights=get_loss_weights(settings),
    )
    model.summary()

    epoch_count = settings['epochs']
    model.fit(train_gen.dataset, epochs=epoch_count,
              validation_data=valid_gen.dataset, callbacks=cb_list)

    elapsed = timer() - t_begin
    print(f"Duration: {datetime.timedelta(seconds=elapsed)}")


if __name__ == "__main__":
    main()