diff --git a/tiledb/ml/models/tensorflow_keras.py b/tiledb/ml/models/tensorflow_keras.py index d5546310..31b50b04 100644 --- a/tiledb/ml/models/tensorflow_keras.py +++ b/tiledb/ml/models/tensorflow_keras.py @@ -279,6 +279,9 @@ def _serialize_optimizer_weights( assert self.artifact optimizer = self.artifact.optimizer if optimizer and not isinstance(optimizer, TFOptimizer): - optimizer_weights = tf.keras.backend.batch_get_value(optimizer.weights) + if hasattr(optimizer, "weights"): + optimizer_weights = tf.keras.backend.batch_get_value(optimizer.weights) + else: + optimizer_weights = [var.numpy() for var in optimizer.variables()] return pickle.dumps(optimizer_weights, protocol=4) return b""