
Implements H5 legacy saving for Keras Core #605

Merged (21 commits) on Jul 26, 2023
Conversation

nkovela1 (Collaborator)

This PR implements the legacy H5 saving format for Keras Core, including legacy serialization. Additionally:

  • Code routing in saving API file for H5 format
  • JSON utils for legacy saving
  • Metadata utils for legacy saving
  • Add set_weights method for optimizer

Keras SavedModel support relies heavily on this PR and will be the next major step.
This PR contains some basic tests; more tests will be added in subsequent PRs.

@nkovela1 requested a review from @fchollet on July 25, 2023, 21:19
@fchollet (Contributor) left a comment:

Thanks for the PR!

  • Please remove the saved_model folder since we don't have saved_model logic yet and the json utils aren't saved_model related
  • Please move the legacy logic to keras_core/legacy/saving/ so that we consolidate all legacy code into a single location

from keras_core.saving.legacy import serialization
from keras_core.utils.module_utils import tensorflow as tf

if tf.available:
@fchollet (Contributor):

Don't do this at the top of the file, since it would cause a TF import when importing keras_core regardless of whether the user is using it. In-line it where it gets used.
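The pattern being requested can be sketched in a minimal, self-contained form. Here `json` stands in for the heavy dependency (tensorflow), and `save_weights` is an illustrative helper, not actual Keras Core API:

```python
# Sketch of the "in-line the import" fix: defer the heavy import to the
# function that actually needs it, so merely importing the module that
# contains this function never pays the import cost.

def save_weights(weights, filepath):
    # Imported only when this code path actually runs.
    import json as heavy_dep  # stand-in for `tensorflow`

    with open(filepath, "w") as f:
        f.write(heavy_dep.dumps(weights))
```

With this shape, `import keras_core` stays cheap, and the dependency is only loaded when a user actually saves in the H5 format.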

@nkovela1 (Collaborator, Author):

Fixed, in-lined the import.

@@ -329,6 +329,24 @@ def learning_rate(self, learning_rate):
        )
        self._learning_rate.assign(learning_rate)

    def set_weights(self, weights):
        """Set the weights of the optimizer."""
        if not getattr(self, "_built", False):
@fchollet (Contributor):

That should be if not self.built (I don't think we have a _built attribute)

@nkovela1 (Collaborator, Author):

Fixed to self.built, thanks!

    def set_weights(self, weights):
        """Set the weights of the optimizer."""
        if not getattr(self, "_built", False):
            raise ValueError(
@fchollet (Contributor):

Please add tests for the two error messages

@nkovela1 (Collaborator, Author):

Added a test in optimizer_test, fixed.
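The kind of check requested can be sketched as follows. `ToyOptimizer` and its message wording are stand-ins, not the real Keras Core class or its exact error strings:

```python
# Stand-in optimizer exercising both error branches of set_weights():
# calling it before build(), and passing a mismatched number of weights.

class ToyOptimizer:
    def __init__(self):
        self.built = False
        self.variables = []

    def build(self, variables):
        self.variables = list(variables)
        self.built = True

    def set_weights(self, weights):
        """Set the weights of the optimizer."""
        if not self.built:
            raise ValueError(
                "You are calling `set_weights()` on an optimizer that has "
                "not yet been built. Call `optimizer.build(variables)` first."
            )
        if len(weights) != len(self.variables):
            raise ValueError(
                f"Expected {len(self.variables)} weights, "
                f"received {len(weights)}."
            )
        self.variables = list(weights)
```

A test would then assert that each branch raises a `ValueError` whose message matches the expected wording, e.g. via `assertRaisesRegex` in `optimizer_test`.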

        layer_data, custom_objects=custom_objects
    )
    if "module" not in layer_data:
        layer = saving_utils.model_from_config(
@fchollet (Contributor):

Add comment that this is for the legacy case

@nkovela1 (Collaborator, Author):

Comment added, fixed!

@@ -384,7 +386,10 @@ def get_config(self):
        if node_data is not None:
            filtered_inbound_nodes.append(node_data)

        layer_config = serialization_lib.serialize_keras_object(operation)
        serialize_obj_fn = serialization_lib.serialize_keras_object
        if getattr(self, "use_legacy_config", False):
@fchollet (Contributor):

When is this set to True? It seems somewhat unsafe to use a model-level attribute since the attribute will leak. I wonder if there could be a better way? A scope?

@nkovela1 (Collaborator, Author):

Replaced use_legacy_config attr with a scope in new file saving_options.py, which avoids circular imports and will be expanded with SavedModel options utilities.
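The scope-based replacement can be sketched with a thread-local flag plus a context manager. Names mirror the PR's `saving_options.py`, but the snippet is illustrative, not the exact implementation:

```python
import contextlib
import threading

# A flag carried in a thread-local rather than on the model object, so it
# never leaks onto the model and is restored on exit even if saving raises.
_save_options_context = threading.local()


def in_legacy_config_scope():
    return getattr(_save_options_context, "use_legacy_config", False)


@contextlib.contextmanager
def keras_option_scope(use_legacy_config=True):
    prev = getattr(_save_options_context, "use_legacy_config", False)
    _save_options_context.use_legacy_config = use_legacy_config
    try:
        yield
    finally:
        _save_options_context.use_legacy_config = prev
```

`get_config()` would then check `in_legacy_config_scope()` instead of a `use_legacy_config` attribute on the model.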

@@ -258,13 +260,14 @@ def _is_layer_name_unique(self, layer):
            return True

    def get_config(self):
        serialize_fn = serialization_lib.serialize_keras_object
        if getattr(self, "use_legacy_config", None):
@fchollet (Contributor):

Same comments as above

@nkovela1 (Collaborator, Author):

Replaced with scope, fixed. Thanks!

        layer_config,
        custom_objects=custom_objects,
    )
    if "module" not in layer_config:
@fchollet (Contributor):

Add a comment

@nkovela1 (Collaborator, Author):

Added, thanks

@fchollet (Contributor) left a comment:

Thanks for the update! Please move the legacy code to keras_core/legacy/saving/ rather than keras_core/saving/legacy/.


@contextlib.contextmanager
def keras_option_scope(use_legacy_config=True):
    use_legacy_config_prev_value = _save_options_context.use_legacy_config
@fchollet (Contributor):

I think you can use global_state.set_global_attribute() for this instead of a threading.local subclass -- the benefit is that it centralizes global state management, making it easier to track

@nkovela1 (Collaborator, Author):

Fixed! I've changed it to a simpler scope using global state attributes. Thanks!
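The final approach can be sketched by routing the flag through a central registry. The `set_global_attribute` / `get_global_attribute` pair mirrors the suggestion above, but the dict-backed registry here is a stand-in, not the real `global_state` module:

```python
import contextlib

# Minimal stand-in for a centralized global-state registry, so all global
# flags live in one place instead of in ad-hoc threading.local subclasses.
_GLOBAL_ATTRIBUTES = {}


def set_global_attribute(name, value):
    _GLOBAL_ATTRIBUTES[name] = value


def get_global_attribute(name, default=None):
    return _GLOBAL_ATTRIBUTES.get(name, default)


@contextlib.contextmanager
def keras_option_scope(use_legacy_config=True):
    # Save, set, and always restore the flag, even if the body raises.
    prev = get_global_attribute("use_legacy_config", False)
    set_global_attribute("use_legacy_config", use_legacy_config)
    try:
        yield
    finally:
        set_global_attribute("use_legacy_config", prev)
```

Centralizing the state this way makes every global flag discoverable and resettable from one module, which is the tracking benefit the review points out.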

@@ -1,289 +0,0 @@
import warnings
@fchollet (Contributor):

If we remove this file we should be able to also remove the entire folder (which only has this + legacy_h5_format_test.py)

@fchollet (Contributor) left a comment:

LGTM, thank you!

@fchollet merged commit 18a174b into keras-team:main on Jul 26, 2023