Implements H5 legacy saving for Keras Core #605
Conversation
Thanks for the PR!
- Please remove the `saved_model` folder, since we don't have saved_model logic yet and the JSON utils aren't saved_model related.
- Please move the legacy logic to `keras_core/legacy/saving/` so that we consolidate all legacy code into a single location.
```python
from keras_core.saving.legacy import serialization
from keras_core.utils.module_utils import tensorflow as tf

if tf.available:
```
Don't do this at the top of the file, since it would cause a TF import when importing keras_core
regardless of whether the user is using it. In-line it where it gets used.
Fixed, in-lined the import.
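The in-line import pattern agreed on above can be sketched like this. The function name is hypothetical, and `json` stands in for the TensorFlow dependency purely to demonstrate the mechanics:

```python
def save_legacy_h5(filepath):
    """Save in the legacy format (illustrative sketch only)."""
    # The heavy dependency is imported inside the function, so importing
    # the package alone never triggers it; only the legacy saving path
    # pays the import cost.
    import json  # hypothetical stand-in for the TensorFlow import

    with open(filepath, "w") as f:
        json.dump({"format": "h5-legacy"}, f)
```

The point is simply that the import statement runs on first call, not at module import time.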
```python
@@ -329,6 +329,24 @@ def learning_rate(self, learning_rate):
        )
        self._learning_rate.assign(learning_rate)

    def set_weights(self, weights):
        """Set the weights of the optimizer."""
        if not getattr(self, "_built", False):
```
That should be `if not self.built` (I don't think we have a `_built` attribute).
Fixed to self.built, thanks!
```python
def set_weights(self, weights):
    """Set the weights of the optimizer."""
    if not getattr(self, "_built", False):
        raise ValueError(
```
Please add tests for the 2 error messages
Added a test in optimizer_test, fixed.
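The guard and the two error paths discussed above might look roughly like this toy version (not keras_core's actual optimizer class; the error wording is illustrative):

```python
class Optimizer:
    """Toy optimizer sketching the `set_weights` guard under review."""

    def __init__(self):
        self.built = False
        self._variables = []

    def build(self, variables):
        self._variables = list(variables)
        self.built = True

    def set_weights(self, weights):
        """Set the weights of the optimizer."""
        if not self.built:
            # First error: the optimizer must be built before loading weights.
            raise ValueError(
                "You are calling `set_weights()` on an optimizer that has "
                "not yet been built. Call `build()` first."
            )
        if len(weights) != len(self._variables):
            # Second error: weight-count mismatch.
            raise ValueError(
                f"Expected {len(self._variables)} weights, "
                f"received {len(weights)}."
            )
        self._variables = list(weights)
```

A test then exercises both `ValueError` branches explicitly, as requested in the review.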
```python
        layer_data, custom_objects=custom_objects
    )
    if "module" not in layer_data:
        layer = saving_utils.model_from_config(
```
Add comment that this is for the legacy case
Comment added, fixed!
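The legacy-detection branch being commented here can be sketched as follows; the helper names are stand-ins, not keras_core's actual functions:

```python
def _legacy_from_config(config, custom_objects=None):
    # Hypothetical stand-in for the legacy deserializer.
    return ("legacy", config["class_name"])

def _modern_from_config(config, custom_objects=None):
    # Hypothetical stand-in for the new deserializer.
    return ("modern", config["class_name"])

def deserialize_layer(layer_data, custom_objects=None):
    # Configs written by the new serialization carry a "module" key;
    # its absence marks a legacy (tf.keras H5) config, which is routed
    # through the legacy code path.
    if "module" not in layer_data:
        return _legacy_from_config(layer_data, custom_objects)
    return _modern_from_config(layer_data, custom_objects)
```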
`keras_core/models/functional.py` (outdated)
```diff
@@ -384,7 +386,10 @@ def get_config(self):
             if node_data is not None:
                 filtered_inbound_nodes.append(node_data)

-        layer_config = serialization_lib.serialize_keras_object(operation)
+        serialize_obj_fn = serialization_lib.serialize_keras_object
+        if getattr(self, "use_legacy_config", False):
```
When is this set to True? It seems somewhat unsafe to use a model-level attribute since the attribute will leak. I wonder if there could be a better way? A scope?
Replaced the `use_legacy_config` attr with a scope in the new file `saving_options.py`, which avoids circular imports and will be expanded with SavedModel options utilities.
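A scope of that shape might look roughly like this minimal sketch; the module-level dict stands in for whatever state `saving_options.py` actually holds:

```python
import contextlib

# Hypothetical module-level state; the flag defaults to the new format.
_save_options = {"use_legacy_config": False}

@contextlib.contextmanager
def keras_option_scope(use_legacy_config=True):
    previous = _save_options["use_legacy_config"]
    _save_options["use_legacy_config"] = use_legacy_config
    try:
        yield
    finally:
        # Restored on exit, so the flag cannot leak the way a
        # model-level attribute would.
        _save_options["use_legacy_config"] = previous

def use_legacy_config():
    return _save_options["use_legacy_config"]
```

`get_config()` then consults `use_legacy_config()` instead of a model attribute, and the setting is automatically cleaned up when the `with` block exits, even on exceptions.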
`keras_core/models/sequential.py` (outdated)
```diff
@@ -258,13 +260,14 @@ def _is_layer_name_unique(self, layer):
         return True

     def get_config(self):
+        serialize_fn = serialization_lib.serialize_keras_object
+        if getattr(self, "use_legacy_config", None):
```
Same comments
Replaced with scope, fixed. Thanks!
```python
            layer_config,
            custom_objects=custom_objects,
        )
        if "module" not in layer_config:
```
Add a comment
Added, thanks
Thanks for the update! Please move the legacy code to `keras_core/legacy/saving/` rather than `keras_core/saving/legacy/`.
```python
@contextlib.contextmanager
def keras_option_scope(use_legacy_config=True):
    use_legacy_config_prev_value = _save_options_context.use_legacy_config
```
I think you can use `global_state.set_global_attribute()` for this instead of a `threading.local` subclass -- the benefit is that it centralizes global state management, making it easier to track.
Fixed! I've changed it to a simpler scope using global state attributes. Thanks!
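A minimal sketch of that final approach, with `set_global_attribute`/`get_global_attribute` reimplemented here for illustration (keras_core's real `global_state` module may differ in detail):

```python
import contextlib

# Centralized global state, standing in for keras_core's global_state module.
_GLOBAL_STATE = {}

def set_global_attribute(name, value):
    _GLOBAL_STATE[name] = value

def get_global_attribute(name, default=None):
    return _GLOBAL_STATE.get(name, default)

@contextlib.contextmanager
def keras_option_scope(use_legacy_config=True):
    # Save the previous value so nested scopes restore correctly.
    previous = get_global_attribute("use_legacy_config", False)
    set_global_attribute("use_legacy_config", use_legacy_config)
    try:
        yield
    finally:
        set_global_attribute("use_legacy_config", previous)
```

Because all reads and writes go through one helper module, the state is easy to audit, and nested scopes compose correctly.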
```diff
@@ -1,289 +0,0 @@
-import warnings
```
If we remove this file we should be able to also remove the entire folder (which only has this + `legacy_h5_format_test.py`).
LGTM, thank you!
This PR implements the legacy H5 saving format for Keras Core, including legacy serialization. Additionally:
- The next major step, Keras SavedModel, relies heavily on this PR.
- This PR contains some basic tests; more tests will be added in subsequent PRs.