Skip to content

Commit

Permalink
Support keras 2.11 changes to model weights
Browse files Browse the repository at this point in the history
  • Loading branch information
Shelnutt2 committed Jul 22, 2023
1 parent 23651ed commit 5b4f8e5
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tiledb/ml/models/tensorflow_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,9 @@ def _serialize_optimizer_weights(
assert self.artifact
optimizer = self.artifact.optimizer
if optimizer and not isinstance(optimizer, TFOptimizer):
optimizer_weights = tf.keras.backend.batch_get_value(optimizer.weights)
if hasattr(optimizer, "weights"):
optimizer_weights = tf.keras.backend.batch_get_value(optimizer.weights)
else:
optimizer_weights = [var.numpy() for var in optimizer.variables()]
return pickle.dumps(optimizer_weights, protocol=4)
return b""

0 comments on commit 5b4f8e5

Please sign in to comment.