diff --git a/keras_core/legacy/saving/legacy_h5_format_test.py b/keras_core/legacy/saving/legacy_h5_format_test.py index 9a0e39d41..02808a4ac 100644 --- a/keras_core/legacy/saving/legacy_h5_format_test.py +++ b/keras_core/legacy/saving/legacy_h5_format_test.py @@ -107,7 +107,7 @@ def test_sequential_model(self): ref_input = np.random.random((2, 3)) self._check_reloading_model(ref_input, model) - def test_functional_model_weights(self): + def test_functional_model(self): model = get_functional_model(keras_core) ref_input = np.random.random((2, 3)) self._check_reloading_model(ref_input, model) @@ -121,3 +121,45 @@ def test_compiled_model_with_various_layers(self): model.compile(optimizer="rmsprop", loss="mse") ref_input = np.random.random((1, 3)) self._check_reloading_model(ref_input, model) + + +@pytest.mark.requires_trainable_backend +class LegacyH5BackwardsCompatTest(testing.TestCase): + def _check_reloading_model(self, ref_input, model, tf_keras_model): + # Whole model file + ref_output = tf_keras_model(ref_input) + temp_filepath = os.path.join(self.get_temp_dir(), "model.h5") + tf_keras_model.save(temp_filepath) + loaded = legacy_h5_format.load_model_from_hdf5(temp_filepath) + output = loaded(ref_input) + self.assertAllClose(ref_output, output, atol=1e-5) + + def test_sequential_model(self): + model = get_sequential_model(keras_core) + tf_keras_model = get_sequential_model(tf.keras) + ref_input = np.random.random((2, 3)) + self._check_reloading_model(ref_input, model, tf_keras_model) + + def test_functional_model(self): + tf_keras_model = get_functional_model(tf.keras) + model = get_functional_model(keras_core) + ref_input = np.random.random((2, 3)) + self._check_reloading_model(ref_input, model, tf_keras_model) + + def test_compiled_model_with_various_layers(self): + model = models.Sequential() + model.add(layers.Dense(2, input_shape=(3,))) + model.add(layers.RepeatVector(3)) + model.add(layers.TimeDistributed(layers.Dense(3))) + 
model.compile(optimizer="rmsprop", loss="mse") + + tf_keras_model = tf.keras.Sequential() + tf_keras_model.add(tf.keras.layers.Dense(2, input_shape=(3,))) + tf_keras_model.add(tf.keras.layers.RepeatVector(3)) + tf_keras_model.add( + tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3)) + ) + tf_keras_model.compile(optimizer="rmsprop", loss="mse") + + ref_input = np.random.random((1, 3)) + self._check_reloading_model(ref_input, model, tf_keras_model) diff --git a/keras_core/legacy/saving/saving_utils.py b/keras_core/legacy/saving/saving_utils.py index faa486486..8aac9977a 100644 --- a/keras_core/legacy/saving/saving_utils.py +++ b/keras_core/legacy/saving/saving_utils.py @@ -1,3 +1,4 @@ +import json import threading import tree @@ -6,7 +7,7 @@ from keras_core import backend from keras_core import layers from keras_core import losses -from keras_core import metrics +from keras_core import metrics as metrics_module from keras_core import models from keras_core import optimizers from keras_core.legacy.saving import serialization @@ -46,6 +47,21 @@ def model_from_config(config, custom_objects=None): MODULE_OBJECTS.ALL_OBJECTS["Model"] = models.Model MODULE_OBJECTS.ALL_OBJECTS["Sequential"] = models.Sequential + batch_input_shape = config["config"].pop("batch_input_shape", None) + if batch_input_shape is not None: + if config["class_name"] == "InputLayer": + config["config"]["batch_shape"] = batch_input_shape + else: + config["config"]["input_shape"] = batch_input_shape + + axis = config["config"].pop("axis", None) + if axis is not None and isinstance(axis, list) and len(axis) == 1: + config["config"]["axis"] = int(axis[0]) + + # TODO(nkovela): Swap find and replace args during Keras 3.0 release + # Replace keras refs with keras_core + config = _find_replace_nested_dict(config, "keras.", "keras_core.") + return serialization.deserialize_keras_object( config, module_objects=MODULE_OBJECTS.ALL_OBJECTS, @@ -95,12 +111,18 @@ def 
compile_args_from_training_config(training_config, custom_objects=None): with object_registration.CustomObjectScope(custom_objects): optimizer_config = training_config["optimizer_config"] optimizer = optimizers.deserialize(optimizer_config) + # Ensure backwards compatibility for optimizers in legacy H5 files + optimizer = _resolve_compile_arguments_compat( + optimizer, optimizer_config, optimizers + ) # Recover losses. loss = None loss_config = training_config.get("loss", None) if loss_config is not None: loss = _deserialize_nested_config(losses.deserialize, loss_config) + # Ensure backwards compatibility for losses in legacy H5 files + loss = _resolve_compile_arguments_compat(loss, loss_config, losses) # Recover metrics. metrics = None @@ -109,6 +131,10 @@ def compile_args_from_training_config(training_config, custom_objects=None): metrics = _deserialize_nested_config( _deserialize_metric, metrics_config ) + # Ensure backwards compatibility for metrics in legacy H5 files + metrics = _resolve_compile_arguments_compat( + metrics, metrics_config, metrics_module + ) # Recover weighted metrics. weighted_metrics = None @@ -177,7 +203,27 @@ def _deserialize_metric(metric_config): # special case handling for these in compile, based on model output # shape. return metric_config - return metrics.deserialize(metric_config) + return metrics_module.deserialize(metric_config) + + +def _find_replace_nested_dict(config, find, replace): + dict_str = json.dumps(config) + dict_str = dict_str.replace(find, replace) + config = json.loads(dict_str) + return config + + +def _resolve_compile_arguments_compat(obj, obj_config, module): + """Resolves backwards compatibility issues with training config arguments. + + This helper function accepts built-in Keras modules such as optimizers, + losses, and metrics to ensure an object being deserialized is compatible + with Keras Core built-ins. For legacy H5 files saved within Keras Core, + this does nothing. 
+ """ + if isinstance(obj, str) and obj not in module.ALL_OBJECTS_DICT: + obj = module.get(obj_config["config"]["name"]) + return obj def try_build_compiled_arguments(model): diff --git a/keras_core/legacy/saving/serialization.py b/keras_core/legacy/saving/serialization.py index f19c84893..2da6f7f93 100644 --- a/keras_core/legacy/saving/serialization.py +++ b/keras_core/legacy/saving/serialization.py @@ -1,7 +1,7 @@ """Legacy serialization logic for Keras models.""" - import contextlib import inspect +import json import threading import weakref @@ -484,6 +484,12 @@ def deserialize(config, custom_objects=None): arg_spec = inspect.getfullargspec(cls.from_config) custom_objects = custom_objects or {} + # TODO(nkovela): Swap find and replace args during Keras 3.0 release + # Replace keras refs with keras_core + cls_config = _find_replace_nested_dict( + cls_config, "keras.", "keras_core." + ) + if "custom_objects" in arg_spec.args: deserialized_obj = cls.from_config( cls_config, @@ -558,3 +564,10 @@ def validate_config(config): def is_default(method): """Check if a method is decorated with the `default` wrapper.""" return getattr(method, "_is_default", False) + + +def _find_replace_nested_dict(config, find, replace): + dict_str = json.dumps(config) + dict_str = dict_str.replace(find, replace) + config = json.loads(dict_str) + return config