Skip to content

Commit

Permalink
changed tests to save to keras format.
Browse files Browse the repository at this point in the history
  • Loading branch information
cpmpercussion committed Jun 20, 2024
1 parent 273f914 commit 7c5bfb9
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions keras_mdn_layer/tests/test_mdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ def test_save_mdn():
model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu'))
model.add(mdn.MDN(1, N_MIXES))
model.compile(loss=mdn.get_mixture_loss_func(1, N_MIXES), optimizer=keras.optimizers.Adam())
model.save('test_save.h5', overwrite=True, save_format="h5")
m_2 = keras.models.load_model('test_save.h5', custom_objects={'MDN': mdn.MDN, 'mdn_loss_func': mdn.get_mixture_loss_func(1, N_MIXES)})
model.save('test_save.keras', overwrite=True)
m_2 = keras.models.load_model('test_save.keras', custom_objects={'MDN': mdn.MDN, 'mdn_loss_func': mdn.get_mixture_loss_func(1, N_MIXES)})
assert isinstance(m_2, keras.Sequential)

0 comments on commit 7c5bfb9

Please sign in to comment.