From 7c5bfb9290124c159cea918bb7f696def1a9e905 Mon Sep 17 00:00:00 2001 From: Charles Martin Date: Thu, 20 Jun 2024 13:26:58 +1000 Subject: [PATCH] changed tests to save to keras format. --- keras_mdn_layer/tests/test_mdn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_mdn_layer/tests/test_mdn.py b/keras_mdn_layer/tests/test_mdn.py index 8fdb2b5..262cc5b 100644 --- a/keras_mdn_layer/tests/test_mdn.py +++ b/keras_mdn_layer/tests/test_mdn.py @@ -37,6 +37,6 @@ def test_save_mdn(): model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu')) model.add(mdn.MDN(1, N_MIXES)) model.compile(loss=mdn.get_mixture_loss_func(1, N_MIXES), optimizer=keras.optimizers.Adam()) - model.save('test_save.h5', overwrite=True, save_format="h5") - m_2 = keras.models.load_model('test_save.h5', custom_objects={'MDN': mdn.MDN, 'mdn_loss_func': mdn.get_mixture_loss_func(1, N_MIXES)}) + model.save('test_save.keras', overwrite=True) + m_2 = keras.models.load_model('test_save.keras', custom_objects={'MDN': mdn.MDN, 'mdn_loss_func': mdn.get_mixture_loss_func(1, N_MIXES)}) assert isinstance(m_2, keras.Sequential)