diff --git a/keras_mdn_layer/tests/test_mdn.py b/keras_mdn_layer/tests/test_mdn.py index 8fdb2b5..262cc5b 100644 --- a/keras_mdn_layer/tests/test_mdn.py +++ b/keras_mdn_layer/tests/test_mdn.py @@ -37,6 +37,6 @@ def test_save_mdn(): model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu')) model.add(mdn.MDN(1, N_MIXES)) model.compile(loss=mdn.get_mixture_loss_func(1, N_MIXES), optimizer=keras.optimizers.Adam()) - model.save('test_save.h5', overwrite=True, save_format="h5") - m_2 = keras.models.load_model('test_save.h5', custom_objects={'MDN': mdn.MDN, 'mdn_loss_func': mdn.get_mixture_loss_func(1, N_MIXES)}) + model.save('test_save.keras', overwrite=True) + m_2 = keras.models.load_model('test_save.keras', custom_objects={'MDN': mdn.MDN, 'mdn_loss_func': mdn.get_mixture_loss_func(1, N_MIXES)}) assert isinstance(m_2, keras.Sequential)