diff --git a/keras_core/optimizers/base_optimizer.py b/keras_core/optimizers/base_optimizer.py index 064d053b1..850ba0e1b 100644 --- a/keras_core/optimizers/base_optimizer.py +++ b/keras_core/optimizers/base_optimizer.py @@ -201,7 +201,9 @@ def update_step(self, gradient, variable, learning_rate): def apply_gradients(self, grads_and_vars): grads, trainable_variables = zip(*grads_and_vars) - return self.apply(grads, trainable_variables) + self.apply(grads, trainable_variables) + # Return iterations for compat with tf.keras. + return self.iterations def apply(self, grads, trainable_variables=None): """ @@ -261,7 +263,6 @@ def apply(self, grads, trainable_variables=None): for variable in trainable_variables: if getattr(variable, "constraint", None) is not None: variable.assign(variable.constraint(variable)) - return self.iterations def _internal_apply_gradients(self, grads_and_vars): learning_rate = self._get_current_learning_rate()