
FIX BUG in load_weights_from_hdf5_group_by_name in legacy_h5_format.py #20537

Merged: 2 commits into keras-team:master on Nov 30, 2024

Conversation

@edwardyehuang (Contributor) commented Nov 22, 2024

FIX #20536

@codecov-commenter commented Nov 22, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.22%. Comparing base (5d36ee1) to head (972c7ca).
Report is 21 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20537      +/-   ##
==========================================
+ Coverage   82.15%   82.22%   +0.07%     
==========================================
  Files         515      515              
  Lines       47859    48166     +307     
  Branches     7494     7527      +33     
==========================================
+ Hits        39317    39604     +287     
- Misses       6730     6744      +14     
- Partials     1812     1818       +6     
Flag              Coverage Δ
keras             82.07% <ø> (+0.07%) ⬆️
keras-jax         65.12% <ø> (+0.02%) ⬆️
keras-numpy       60.12% <ø> (+0.04%) ⬆️
keras-tensorflow  66.09% <ø> (+0.01%) ⬆️
keras-torch       65.06% <ø> (+0.03%) ⬆️

Flags with carried forward coverage won't be shown.

@fchollet (Collaborator) left a comment

Looks good, thanks for the fix. Please add a unit test so we avoid breaking this in the future.

@edwardyehuang (Contributor, Author)

> Looks good, thanks for the fix. Please add a unit test so we avoid breaking this in the future.

It looks like I need to create a mock H5 data structure to test this function. Is there any other option, e.g. splitting the function?

@fchollet (Collaborator)

Can you add a unit test that targets the user-facing case in which the issue appeared?

@edwardyehuang (Contributor, Author) commented Nov 24, 2024

> Can you add a unit test that targets the user-facing case in which the issue appeared?

Yes, I can. I will do this next weekend.

def load_weights_from_hdf5_group_by_name(f, model, skip_mismatch=False):
    """Implements name-based weight loading (instead of topological loading).

    Layers that have no matching name are skipped.

    Args:
        f: A pointer to a HDF5 group.
        model: Model instance.
        skip_mismatch: Boolean, whether to skip loading of layers
            where there is a mismatch in the number of weights,
            or a mismatch in the shape of the weights.

    Raises:
        ValueError: in case of mismatch between provided layers
            and weights file and skip_mismatch=False.
    """
    if "keras_version" in f.attrs:
        original_keras_version = f.attrs["keras_version"]

The issue appeared in the function load_weights_from_hdf5_group_by_name above, which takes two required parameters: f (an HDF5 pointer) and model. So I need to mock both the HDF5 pointer and the model.

A simple way to mock them is to create a mock model, save the weights in temporary files or memory, and then read them. However, I am not sure whether this is good practice for unit testing.

A more complex approach is to mock the model and the HDF5 data structure directly.

Another option is to split the symbolic_weights handling (line 522), where the issue appeared, into a new function (sketched after the snippet below).

if "top_level_model_weights" in f:
symbolic_weights = model.trainable_weights + model.non_trainable_weights
weight_values = load_subset_weights_from_hdf5_group(
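
For reference, a minimal sketch of that third option: pulling the top-level-weights branch out into a helper that could be unit tested on its own. The helper name, the warning text, and the mismatch handling here are illustrative assumptions, not the PR's actual diff; load_subset_weights_from_hdf5_group and the "top_level_model_weights" group come from the snippet above.

import warnings

# Hypothetical helper extracted from load_weights_from_hdf5_group_by_name.
# Reuses load_subset_weights_from_hdf5_group from legacy_h5_format.py.
def _load_top_level_weights(f, model, skip_mismatch=False):
    symbolic_weights = model.trainable_weights + model.non_trainable_weights
    weight_values = load_subset_weights_from_hdf5_group(
        f["top_level_model_weights"]
    )
    if len(weight_values) != len(symbolic_weights):
        if skip_mismatch:
            # Mirror the skip_mismatch behavior documented for the
            # layer-loading path: warn and skip instead of raising.
            warnings.warn(
                "Skipping loading of top-level model weights due to a "
                "mismatch in the number of weights.",
                stacklevel=2,
            )
            return
        raise ValueError(
            "Weight count mismatch for top-level weights of the model: "
            f"expected {len(symbolic_weights)}, got {len(weight_values)}."
        )
    for variable, value in zip(symbolic_weights, weight_values):
        variable.assign(value)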

Which one do you prefer? @fchollet

@fchollet (Collaborator)

> A simple way to mock them is to create a mock model, save the weights in temporary files or memory, and then read them. However, I am not sure whether this is good practice for unit testing.

This is fine. Create a very simple model that has the issue, save it to a temporary folder (see unit tests in keras/src/saving/saving_lib_test.py for examples), and read it.
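
For illustration, a minimal sketch of such a round-trip test under the approach suggested above. The model class, test name, and assertions are assumptions made for this sketch, and save_weights_to_hdf5_group is assumed to take (f, model) as in legacy_h5_format.py; the test actually added in this PR lives in keras/src/legacy/saving/legacy_h5_format_test.py.

import os
import tempfile

import h5py
import numpy as np

import keras
from keras.src.legacy.saving import legacy_h5_format


class ModelWithTopLevelWeight(keras.Model):
    """Subclassed model owning a weight directly, not via a layer."""

    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(4, name="dense")
        # The kind of top-level model weight that triggered issue #20536.
        self.extra = self.add_weight(
            shape=(2,), initializer="ones", name="extra"
        )

    def call(self, x):
        return self.dense(x) + keras.ops.sum(self.extra)


def test_by_name_loading_restores_top_level_weights():
    ref = ModelWithTopLevelWeight()
    ref(np.zeros((1, 3)))  # Build the variables.
    ref.extra.assign(np.array([3.0, 7.0]))  # Make the weight distinctive.

    loaded = ModelWithTopLevelWeight()
    loaded(np.zeros((1, 3)))

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "weights.h5")
        with h5py.File(path, "w") as f:
            legacy_h5_format.save_weights_to_hdf5_group(f, ref)
        with h5py.File(path, "r") as f:
            legacy_h5_format.load_weights_from_hdf5_group_by_name(f, loaded)

    # Every weight, including the top-level one, should round-trip.
    for w_ref, w_loaded in zip(ref.weights, loaded.weights):
        np.testing.assert_allclose(
            keras.ops.convert_to_numpy(w_ref),
            keras.ops.convert_to_numpy(w_loaded),
        )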

@edwardyehuang (Contributor, Author)

> A simple way to mock them is to create a mock model, save the weights in temporary files or memory, and then read them. However, I am not sure whether this is good practice for unit testing.

> This is fine. Create a very simple model that has the issue, save it to a temporary folder (see unit tests in keras/src/saving/saving_lib_test.py for examples), and read it.

I noticed that keras/src/legacy/saving/legacy_h5_format_test.py already has a mock subclassed model and a corresponding test, DISABLED_test_subclassed_model_weights. I just added a top-level weight to this model.

However, the test DISABLED_test_subclassed_model_weights is marked as DISABLED. Does that prefix need to be removed in this PR? @fchollet

@fchollet (Collaborator)

> However, the test DISABLED_test_subclassed_model_weights is marked as DISABLED. Does that prefix need to be removed in this PR?

I just re-enabled the tests; check it out.

@google-ml-butler bot added the kokoro:force-run and "ready to pull" (Ready to be merged into the codebase) labels on Nov 30, 2024.
@fchollet merged commit 1412598 into keras-team:master on Nov 30, 2024 (5 of 6 checks passed).
@google-ml-butler bot removed the "ready to pull" and kokoro:force-run labels on Nov 30, 2024.
Closes: BUG in load_weights_from_hdf5_group_by_name (#20536)