Skip to content

Commit

Permalink
Update ModelCheckpoint support ".h5" support (#20561)
Browse files Browse the repository at this point in the history
* Update ModelCheckpoint support ".h5" support

* ModelCheckpoint support ".h5" and ".keras" both filetype
  • Loading branch information
mehtamansi29 authored Nov 28, 2024
1 parent e0f61ee commit c8d7e0d
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions keras/src/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,14 @@ class ModelCheckpoint(Callback):
which will be filled the value of `epoch` and keys in `logs`
(passed in `on_epoch_end`).
The `filepath` name needs to end with `".weights.h5"` when
`save_weights_only=True` or should end with `".keras"` when
checkpoint saving the whole model (default).
`save_weights_only=True` or should end with `".keras"` or `".h5"`
when checkpoint saving the whole model (default).
For example:
if `filepath` is `"{epoch:02d}-{val_loss:.2f}.keras"`, then the
model checkpoints will be saved with the epoch number and the
validation loss in the filename. The directory of the filepath
should not be reused by any other callbacks to avoid conflicts.
if `filepath` is `"{epoch:02d}-{val_loss:.2f}.keras"` or
"{epoch:02d}-{val_loss:.2f}.h5"`, then the model checkpoints
will be saved with the epoch number and the validation loss in
the filename. The directory of the filepath should not be reused
by any other callbacks to avoid conflicts.
monitor: The metric name to monitor. Typically the metrics are set by
the `Model.compile` method. Note:
* Prefix the name with `"val_"` to monitor validation metrics.
Expand Down Expand Up @@ -187,10 +188,12 @@ def __init__(
f"filepath={self.filepath}"
)
else:
if not self.filepath.endswith(".keras"):
if not any(
self.filepath.endswith(ext) for ext in (".keras", ".h5")
):
raise ValueError(
"The filepath provided must end in `.keras` "
"(Keras model format). Received: "
"(Keras model format) or `.h5` (HDF5 format). Received: "
f"filepath={self.filepath}"
)

Expand Down

0 comments on commit c8d7e0d

Please sign in to comment.