Skip to content

Commit

Permalink
refactored training to use the PredictiveMusicMDRNN.train function
Browse files Browse the repository at this point in the history
  • Loading branch information
cpmpercussion committed Jun 26, 2024
1 parent 71257cd commit 4f2cefe
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 148 deletions.
84 changes: 53 additions & 31 deletions impsy/mdrnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import tensorflow as tf
import keras_mdn_layer as mdn
import time
import datetime

NET_MODE_TRAIN = "train"
NET_MODE_RUN = "run"
MODEL_DIR = "./models/"
LOG_PATH = "./logs/"
SCALE_FACTOR = 10 # scales input and output from the model. Should be the same between training and inference.

Expand Down Expand Up @@ -64,15 +64,14 @@ def __init__(
dimension=2,
n_hidden_units=128,
n_mixtures=5,
batch_size=100,
sequence_length=30,
layers=2,
):
"""Initialise the MDRNN model. Use mode='run' for evaluation graph and
mode='train' for training graph.
Keyword Arguments:
dimension : number of dimensions for the model = number of degrees of freedom + 1 (time)
n_hidden_units : number of LSTM units in each layer
n_mixtures : number of mixture components (5-10 is good)
Expand All @@ -86,9 +85,6 @@ def __init__(
self.n_hidden_units = n_hidden_units
self.n_rnn_layers = layers
self.n_mixtures = n_mixtures # number of mixtures
# Training parameters
self.batch_size = batch_size
self.val_split = 0.10
# Sampling hyperparameters
self.pi_temp = 1.5
self.sigma_temp = 0.01
Expand All @@ -107,15 +103,15 @@ def __init__(
self.run_name = self.get_run_name()
self.reset_lstm_states()


def build(self):
"""Builds the MDRNN model for training or inference.
"""
"""Builds the MDRNN model for training or inference."""
if self.inference:
state_input_output = True
else:
state_input_output = False
data_input = tf.keras.layers.Input(shape=(self.sequence_length, self.dimension), name="inputs")
data_input = tf.keras.layers.Input(
shape=(self.sequence_length, self.dimension), name="inputs"
)
lstm_in = data_input # starter input for lstm
state_inputs = [] # storage for LSTM state inputs
state_outputs = [] # storage for LSTM state outputs
Expand Down Expand Up @@ -156,7 +152,9 @@ def build(self):
# for training we don't need to keep track of state in the model
inputs = data_input
outputs = mdn_out
new_model = tf.keras.models.Model(inputs=inputs, outputs=outputs, name=self.model_name())
new_model = tf.keras.models.Model(
inputs=inputs, outputs=outputs, name=self.model_name()
)

if not self.inference:
# only need loss function and compile when training
Expand All @@ -166,7 +164,6 @@ def build(self):

return new_model


def reset_lstm_states(self):
states = []
for i in range(self.n_rnn_layers):
Expand Down Expand Up @@ -195,9 +192,9 @@ def model_name(self):
+ str(SCALE_FACTOR)
)

def load_model(self, model_file=None):
def load_model(self, model_file=None, model_dir="models"):
if model_file is None:
model_file = MODEL_DIR + self.model_name() + ".h5"
model_file = model_dir + "/" + self.model_name() + ".h5"
try:
self.model.load_weights(model_file)
except OSError as err:
Expand All @@ -210,44 +207,69 @@ def get_run_name(self):
out += time.strftime("%Y%m%d-%H%M%S")
return out

def train(self, X, y, num_epochs=10, saving=True):
"""Train the network for the a number of epochs."""
def train(
self,
X,
y,
batch_size=100,
epochs=10,
checkpointing=False,
early_stopping=True,
save_location="models",
validation_split=0.1,
patience=10,
logging=True,
):
"""Train the network for a number of epochs with a specific dataset."""
# Setup callbacks
filepath = MODEL_DIR + self.model_name() + "-E{epoch:02d}-VL{val_loss:.2f}.hdf5"
checkpoint = tf.keras.callbacks.ModelCheckpoint(
filepath, monitor="val_loss", verbose=1, save_best_only=True, mode="min"
date_string = datetime.datetime.today().strftime("%Y%m%d-%H_%M_%S")
checkpoint_path = save_location + "/" + self.model_name() + "-ckpt.keras"
# checkpoint_path = save_location + "/" + model_name + "-E{epoch:02d}-VL{val_loss:.2f}.keras"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path,
monitor="val_loss",
verbose=1,
save_best_only=True,
mode="min",
)
terminateOnNaN = tf.keras.callbacks.TerminateOnNaN()
tboard = tf.keras.callbacks.TensorBoard(
log_dir=LOG_PATH + self.run_name,
histogram_freq=2,
batch_size=32,
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
monitor="val_loss", mode="min", verbose=1, patience=patience
)
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=save_location + "/" + date_string + self.model_name(),
histogram_freq=0,
write_graph=True,
update_freq="epoch",
)
callbacks = [terminateOnNaN, tboard]
if saving:
callbacks.append(checkpoint)
callbacks = [terminateOnNaN]
if checkpointing:
callbacks.append(checkpoint_callback)
if early_stopping:
callbacks.append(early_stopping_callback)
if logging:
callbacks.append(tensorboard_callback)

# Do the data scaling in here.
X = np.array(X) * SCALE_FACTOR
y = np.array(y) * SCALE_FACTOR
print("Training corpus has shape:")

## print out stats.
print("Number of training examples:")
print("X:", X.shape)
print("y:", y.shape)

# Train
history = self.model.fit(
X,
y,
batch_size=self.batch_size,
epochs=num_epochs,
validation_split=self.val_split,
batch_size=batch_size,
epochs=epochs,
validation_split=validation_split,
callbacks=callbacks,
)
return history


def generate_touch(self, prev_sample):
"""Generate one forward prediction from a previous sample in format
(dt, x_1,...,x_n). Pi and Sigma temperature are adjustable."""
Expand Down
13 changes: 7 additions & 6 deletions impsy/tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from impsy import mdrnn
from impsy import train
from impsy import train
from impsy import utils
import tensorflow as tf

Expand Down Expand Up @@ -34,18 +34,19 @@ def test_training():
dimension=dimension,
n_hidden_units=16,
n_mixtures=5,
batch_size=batch_size,
sequence_length=sequence_length,
layers=2,
)
x_t_log = utils.generate_data(samples=((sequence_length + 1) * batch_size), dimension=dimension)
x_t_log = utils.generate_data(
samples=((sequence_length + 1) * batch_size), dimension=dimension
)
slices = train.slice_sequence_examples(x_t_log, sequence_length + 1, step_size=1)
Xs, ys = train.seq_to_overlapping_format(slices)
history = net.train(Xs, ys, num_epochs=num_epochs, saving=False)
history = net.train(Xs, ys, batch_size=batch_size, epochs=num_epochs, logging=False)
assert isinstance(history, tf.keras.callbacks.History)


def test_model_config():
"""Tests the model config function."""
conf = utils.mdrnn_config('s')
assert(conf["units"] == 64)
conf = utils.mdrnn_config("s")
assert conf["units"] == 64
4 changes: 4 additions & 0 deletions impsy/tests/test_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from impsy import mdrnn
from impsy import train
from impsy import utils
import tensorflow as tf
Loading

0 comments on commit 4f2cefe

Please sign in to comment.