Skip to content

Commit

Permalink
Update keras.saving for legacy in v2.11
Browse files Browse the repository at this point in the history
Update keras.saving for the packages having moved to legacy
  • Loading branch information
Shelnutt2 committed Jul 22, 2023
1 parent 10c8509 commit adc689a
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions tiledb/ml/models/tensorflow_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,23 @@
from ._base import Meta, TileDBArtifact, Timestamp

FunctionalOrSequential = (keras.models.Functional, keras.models.Sequential)
TFOptimizer = keras.optimizers.TFOptimizer
get_json_type = keras.saving.saved_model.json_utils.get_json_type
preprocess_weights_for_loading = keras.saving.hdf5_format.preprocess_weights_for_loading
saving_utils = keras.saving.saving_utils
keras_major, keras_minor, keras_patch = keras.__version__.split(".")
# Handle keras <=v2.10
if int(keras_major) <= 2 and int(keras_minor) <= 10:
TFOptimizer = keras.optimizers.TFOptimizer
get_json_type = keras.saving.saved_model.json_utils.get_json_type
preprocess_weights_for_loading = (
keras.saving.hdf5_format.preprocess_weights_for_loading
)
saving_utils = keras.saving.saving_utils
# Handle keras >=v2.11
else:
TFOptimizer = tf.keras.optimizers.legacy.TFOptimizer
get_json_type = keras.saving.legacy.saved_model.json_utils.get_json_type
preprocess_weights_for_loading = (
keras.saving.legacy.hdf5_format.preprocess_weights_for_loading
)
saving_utils = keras.saving.legacy.saving_utils


class TensorflowKerasTileDBModel(TileDBArtifact[tf.keras.Model]):
Expand Down

0 comments on commit adc689a

Please sign in to comment.