-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
36 lines (29 loc) · 1.42 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import librosa.display
from keras.optimizers import Adam
from keras.optimizers.schedules import ExponentialDecay
from keras.losses import SparseCategoricalCrossentropy
from keras.metrics import SparseCategoricalAccuracy
from keras.callbacks import TensorBoard, ModelCheckpoint
from model import create_model
from data import make_dataset,split_data
model = create_model()
lr_schedule = ExponentialDecay(initial_learning_rate=1e-3,
decay_steps= 4000, decay_rate = 0.5, staircase=False)
model.compile(optimizer = Adam(lr_schedule),
loss = SparseCategoricalCrossentropy(from_logits = True),
metrics = [ SparseCategoricalAccuracy()])
# Creating callbacks
logdir = 'tb_logs/model'
tensorboard_callback = TensorBoard(log_dir=logdir)
checkpoint_filepath = 'model_checkpoint/model'
model_checkpoint_callback = ModelCheckpoint(
filepath=checkpoint_filepath,
save_weights_only=False,
monitor='val_sparse_categorical_accuracy',
mode='max',
save_best_only=True)
bird_filepaths_train, bird_filepaths_val, bird_labels_train, bird_labels_val = split_data()
dataset_train = make_dataset(bird_labels_train, bird_filepaths_train, shuffle=True)
dataset_val = make_dataset(bird_labels_val, bird_filepaths_val, shuffle=False)
training = model.fit(dataset_train, validation_data = dataset_val, epochs=120,
callbacks=[tensorboard_callback,model_checkpoint_callback])