From 5b4f8e582a66bf6e1a49491eaba3c0668eb5f846 Mon Sep 17 00:00:00 2001 From: Seth Shelnutt Date: Sat, 22 Jul 2023 08:20:53 -0400 Subject: [PATCH] Support keras 2.11 changes to model weights --- tiledb/ml/models/tensorflow_keras.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tiledb/ml/models/tensorflow_keras.py b/tiledb/ml/models/tensorflow_keras.py index d5546310..31b50b04 100644 --- a/tiledb/ml/models/tensorflow_keras.py +++ b/tiledb/ml/models/tensorflow_keras.py @@ -279,6 +279,9 @@ def _serialize_optimizer_weights( assert self.artifact optimizer = self.artifact.optimizer if optimizer and not isinstance(optimizer, TFOptimizer): - optimizer_weights = tf.keras.backend.batch_get_value(optimizer.weights) + if hasattr(optimizer, "weights"): + optimizer_weights = tf.keras.backend.batch_get_value(optimizer.weights) + else: + optimizer_weights = [var.numpy() for var in optimizer.variables()] return pickle.dumps(optimizer_weights, protocol=4) return b""