Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure backwards compatibility for legacy H5 saving format #682

Merged
merged 27 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
2cbb158
Add saved_model_test
nkovela1 Jul 17, 2023
4a7a84e
Sync with main repo
nkovela1 Jul 20, 2023
f524725
Add extra saved model tests
nkovela1 Jul 21, 2023
8920368
Fix formatting
nkovela1 Jul 21, 2023
31013f5
Merge branch 'keras-team:main' into savedmodel
nkovela1 Jul 21, 2023
9842a2b
Merge remote-tracking branch 'refs/remotes/origin/h5' into h5
nkovela1 Jul 24, 2023
22a1d82
Add JSON utils for legacy saving
nkovela1 Jul 24, 2023
59e0426
Implement H5 saving for Keras core
nkovela1 Jul 25, 2023
e91f539
Change saving API routing
nkovela1 Jul 25, 2023
374fd33
Fix h5 format basic tests
nkovela1 Jul 25, 2023
3f60730
Ensure compile reload works with H5 format
nkovela1 Jul 25, 2023
7b49cfa
Remove useless comments
nkovela1 Jul 25, 2023
12eb4dc
Fix imports and formatting
nkovela1 Jul 25, 2023
690dc63
Move json_utils out of saved_model folder
nkovela1 Jul 25, 2023
a9929b1
Add test for set_weights in optimizer
nkovela1 Jul 25, 2023
ee19187
Fix formatting
nkovela1 Jul 25, 2023
e877c39
Add keras options scope to replace use_legacy_config attribute
nkovela1 Jul 25, 2023
2f809bc
Add options scope for legacy serialization, remove circular deps
nkovela1 Jul 25, 2023
1a61315
Add comments for legacy serialization code routing
nkovela1 Jul 25, 2023
3c670a7
Move saving/legacy to legacy/saving
nkovela1 Jul 26, 2023
c0cbe0a
Change keras options scope to use global state attr
nkovela1 Jul 26, 2023
0366717
Sync with main
nkovela1 Aug 7, 2023
079fa1e
Merge branch 'keras-team:main' into h5
nkovela1 Aug 8, 2023
a7f74b5
Add backwards compatibility for H5 saving
nkovela1 Aug 8, 2023
d416797
Fix formatting and add comments for helpers
nkovela1 Aug 8, 2023
46a1fab
Update helper comments
nkovela1 Aug 9, 2023
6f16fa7
Update helper comments
nkovela1 Aug 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion keras_core/legacy/saving/legacy_h5_format_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_sequential_model(self):
ref_input = np.random.random((2, 3))
self._check_reloading_model(ref_input, model)

def test_functional_model_weights(self):
def test_functional_model(self):
    """Round-trip a functional-API model through the legacy H5 format."""
    ref_input = np.random.random((2, 3))
    model = get_functional_model(keras_core)
    self._check_reloading_model(ref_input, model)
Expand All @@ -121,3 +121,45 @@ def test_compiled_model_with_various_layers(self):
model.compile(optimizer="rmsprop", loss="mse")
ref_input = np.random.random((1, 3))
self._check_reloading_model(ref_input, model)


@pytest.mark.requires_trainable_backend
class LegacyH5BackwardsCompatTest(testing.TestCase):
    """Checks that whole-model H5 files written by tf.keras reload correctly
    through the Keras Core legacy H5 loader."""

    def _check_reloading_model(self, ref_input, model, tf_keras_model):
        # Save the tf.keras model as a whole-model H5 file, reload it via
        # the legacy loader, and verify both produce matching outputs.
        expected_output = tf_keras_model(ref_input)
        filepath = os.path.join(self.get_temp_dir(), "model.h5")
        tf_keras_model.save(filepath)
        reloaded = legacy_h5_format.load_model_from_hdf5(filepath)
        actual_output = reloaded(ref_input)
        self.assertAllClose(expected_output, actual_output, atol=1e-5)

    def test_sequential_model(self):
        model = get_sequential_model(keras_core)
        tf_keras_model = get_sequential_model(tf.keras)
        ref_input = np.random.random((2, 3))
        self._check_reloading_model(ref_input, model, tf_keras_model)

    def test_functional_model(self):
        tf_keras_model = get_functional_model(tf.keras)
        model = get_functional_model(keras_core)
        ref_input = np.random.random((2, 3))
        self._check_reloading_model(ref_input, model, tf_keras_model)

    def test_compiled_model_with_various_layers(self):
        model = models.Sequential()
        model.add(layers.Dense(2, input_shape=(3,)))
        model.add(layers.RepeatVector(3))
        model.add(layers.TimeDistributed(layers.Dense(3)))
        model.compile(optimizer="rmsprop", loss="mse")

        # Mirror the exact same architecture in tf.keras so the H5
        # round-trip exercises equivalent configs.
        tf_keras_model = tf.keras.Sequential()
        tf_keras_model.add(tf.keras.layers.Dense(2, input_shape=(3,)))
        tf_keras_model.add(tf.keras.layers.RepeatVector(3))
        tf_keras_model.add(
            tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3))
        )
        tf_keras_model.compile(optimizer="rmsprop", loss="mse")

        ref_input = np.random.random((1, 3))
        self._check_reloading_model(ref_input, model, tf_keras_model)
50 changes: 48 additions & 2 deletions keras_core/legacy/saving/saving_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import threading

import tree
Expand All @@ -6,7 +7,7 @@
from keras_core import backend
from keras_core import layers
from keras_core import losses
from keras_core import metrics
from keras_core import metrics as metrics_module
from keras_core import models
from keras_core import optimizers
from keras_core.legacy.saving import serialization
Expand Down Expand Up @@ -46,6 +47,21 @@ def model_from_config(config, custom_objects=None):
MODULE_OBJECTS.ALL_OBJECTS["Model"] = models.Model
MODULE_OBJECTS.ALL_OBJECTS["Sequential"] = models.Sequential

batch_input_shape = config["config"].pop("batch_input_shape", None)
if batch_input_shape is not None:
if config["class_name"] == "InputLayer":
config["config"]["batch_shape"] = batch_input_shape
else:
config["config"]["input_shape"] = batch_input_shape

axis = config["config"].pop("axis", None)
if axis is not None and isinstance(axis, list) and len(axis) == 1:
config["config"]["axis"] = int(axis[0])

# TODO(nkovela): Swap find and replace args during Keras 3.0 release
# Replace keras refs with keras_core
config = _find_replace_nested_dict(config, "keras.", "keras_core.")

return serialization.deserialize_keras_object(
config,
module_objects=MODULE_OBJECTS.ALL_OBJECTS,
Expand Down Expand Up @@ -95,12 +111,18 @@ def compile_args_from_training_config(training_config, custom_objects=None):
with object_registration.CustomObjectScope(custom_objects):
optimizer_config = training_config["optimizer_config"]
optimizer = optimizers.deserialize(optimizer_config)
# Ensure backwards compatibility for optimizers in legacy H5 files
optimizer = _resolve_compile_arguments_compat(
optimizer, optimizer_config, optimizers
)

# Recover losses.
loss = None
loss_config = training_config.get("loss", None)
if loss_config is not None:
loss = _deserialize_nested_config(losses.deserialize, loss_config)
# Ensure backwards compatibility for losses in legacy H5 files
loss = _resolve_compile_arguments_compat(loss, loss_config, losses)

# Recover metrics.
metrics = None
Expand All @@ -109,6 +131,10 @@ def compile_args_from_training_config(training_config, custom_objects=None):
metrics = _deserialize_nested_config(
_deserialize_metric, metrics_config
)
# Ensure backwards compatibility for metrics in legacy H5 files
metrics = _resolve_compile_arguments_compat(
metrics, metrics_config, metrics_module
)

# Recover weighted metrics.
weighted_metrics = None
Expand Down Expand Up @@ -177,7 +203,27 @@ def _deserialize_metric(metric_config):
# special case handling for these in compile, based on model output
# shape.
return metric_config
return metrics.deserialize(metric_config)
return metrics_module.deserialize(metric_config)


def _find_replace_nested_dict(config, find, replace):
dict_str = json.dumps(config)
dict_str = dict_str.replace(find, replace)
config = json.loads(dict_str)
return config


def _resolve_compile_arguments_compat(obj, obj_config, module):
"""Resolves backwards compatiblity issues with training config arguments.

This helper function accepts built-in Keras modules such as optimizers,
losses, and metrics to ensure an object being deserialized is compatible
with Keras Core built-ins. For legacy H5 files saved within Keras Core,
this does nothing.
"""
if isinstance(obj, str) and obj not in module.ALL_OBJECTS_DICT:
obj = module.get(obj_config["config"]["name"])
return obj


def try_build_compiled_arguments(model):
Expand Down
15 changes: 14 additions & 1 deletion keras_core/legacy/saving/serialization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Legacy serialization logic for Keras models."""

import contextlib
import inspect
import json
import threading
import weakref

Expand Down Expand Up @@ -484,6 +484,12 @@ def deserialize(config, custom_objects=None):
arg_spec = inspect.getfullargspec(cls.from_config)
custom_objects = custom_objects or {}

# TODO(nkovela): Swap find and replace args during Keras 3.0 release
# Replace keras refs with keras_core
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After the swap, the replacement should go the other way

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, fixed the TODO comment.

cls_config = _find_replace_nested_dict(
cls_config, "keras.", "keras_core."
)

if "custom_objects" in arg_spec.args:
deserialized_obj = cls.from_config(
cls_config,
Expand Down Expand Up @@ -558,3 +564,10 @@ def validate_config(config):
def is_default(method):
"""Check if a method is decorated with the `default` wrapper."""
return getattr(method, "_is_default", False)


def _find_replace_nested_dict(config, find, replace):
dict_str = json.dumps(config)
dict_str = dict_str.replace(find, replace)
config = json.loads(dict_str)
return config