From 409375d8efe9f8a7e183e1213b80eac1a3cab385 Mon Sep 17 00:00:00 2001
From: Francois Chollet
Date: Wed, 12 Jul 2023 13:43:52 -0700
Subject: [PATCH] Minor performance optimizations for eager.

---
 keras_core/backend/torch/trainer.py     |  8 ++---
 keras_core/optimizers/base_optimizer.py | 44 ++++++++++++++-----------
 2 files changed, 27 insertions(+), 25 deletions(-)

diff --git a/keras_core/backend/torch/trainer.py b/keras_core/backend/torch/trainer.py
index 4a214c97e..e46d41378 100644
--- a/keras_core/backend/torch/trainer.py
+++ b/keras_core/backend/torch/trainer.py
@@ -43,20 +43,16 @@ def train_step(self, data):
 
         # Compute gradients
         if self.trainable_weights:
-            # Backpropagation
-            trainable_weights = [v for v in self.trainable_weights]
-
             # Call torch.Tensor.backward() on the loss to compute gradients
             # for the weights.
             loss.backward()
 
+            trainable_weights = self.trainable_weights[:]
             gradients = [v.value.grad for v in trainable_weights]
 
             # Update weights
             with torch.no_grad():
-                self.optimizer.apply_gradients(
-                    zip(gradients, trainable_weights)
-                )
+                self.optimizer.apply(gradients, trainable_weights)
         else:
             warnings.warn("The model does not have any trainable weights.")
 
diff --git a/keras_core/optimizers/base_optimizer.py b/keras_core/optimizers/base_optimizer.py
index d8f107ec1..192356eb8 100644
--- a/keras_core/optimizers/base_optimizer.py
+++ b/keras_core/optimizers/base_optimizer.py
@@ -193,7 +193,6 @@ def apply(self, grads, trainable_variables=None):
 
         `variables` can be provided on the first call to build the optimizer.
         """
-        grads = list(grads)
         if len(grads) == 0:
             # It is possible that the grad is empty. In this case,
             # `apply_gradients` is a no-op.
@@ -224,16 +223,15 @@
                 self.built = True
             self._check_variables_are_known(trainable_variables)
 
-        grads_and_vars = list(zip(grads, self._trainable_variables))
-
         with ops.name_scope(self.name):
             # Filter empty gradients.
-            grads_and_vars = self._filter_empty_gradients(grads_and_vars)
-            if len(list(grads_and_vars)) == 0:
+            grads, trainable_variables = self._filter_empty_gradients(
+                grads, trainable_variables
+            )
+            if len(list(grads)) == 0:
                 return
 
             # Apply clipping and weight decay.
-            grads, trainable_variables = zip(*grads_and_vars)
             grads = self._clip_gradients(grads)
             self._apply_weight_decay(trainable_variables)
 
@@ -363,19 +361,27 @@ def _get_current_learning_rate(self):
             return self._learning_rate(self.iterations)
         return self._learning_rate
 
-    def _filter_empty_gradients(self, grads_and_vars):
-        filtered = [(g, v) for g, v in grads_and_vars if g is not None]
-        if not filtered:
-            raise ValueError("No gradients provided for any variable.")
-        if len(filtered) < len(grads_and_vars):
-            missing_grad_vars = [v for g, v in grads_and_vars if g is None]
-            warnings.warn(
-                "Gradients do not exist for variables "
-                f"{[v.name for v in missing_grad_vars]} when minimizing the "
-                "loss. If you're using `model.compile()`, did you forget to "
-                "provide a `loss` argument?"
-            )
-        return filtered
+    def _filter_empty_gradients(self, grads, vars):
+        for grad in grads:
+            if grad is None:
+                # Filtering is required.
+                filtered = [
+                    (g, v) for g, v in zip(grads, vars) if g is not None
+                ]
+                if not filtered:
+                    raise ValueError("No gradients provided for any variable.")
+                if len(filtered) < len(grads):
+                    missing_grad_vars = [
+                        v for g, v in zip(grads, vars) if g is None
+                    ]
+                    warnings.warn(
+                        "Gradients do not exist for variables "
+                        f"{[v.name for v in missing_grad_vars]} when "
+                        "minimizing the loss. If using `model.compile()`, "
+                        "did you forget to provide a `loss` argument?"
+                    )
+                return zip(*filtered)
+        return grads, vars
 
     def _clip_gradients(self, grads):
         if self.clipnorm and self.clipnorm > 0:
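
A minimal standalone sketch of the fast path introduced by the new `_filter_empty_gradients`: when no gradient is None, the inputs are returned untouched, so the common eager step no longer materializes an intermediate list of (grad, var) pairs. The helper and variable names below are illustrative stand-ins, not keras_core API.

import warnings


def filter_empty_gradients(grads, variables):
    # Mirrors the patched behavior: first scan for a missing gradient.
    for grad in grads:
        if grad is None:
            # Slow path: filtering is required.
            filtered = [
                (g, v) for g, v in zip(grads, variables) if g is not None
            ]
            if not filtered:
                raise ValueError("No gradients provided for any variable.")
            if len(filtered) < len(grads):
                missing = [v for g, v in zip(grads, variables) if g is None]
                warnings.warn(f"Gradients do not exist for variables {missing}.")
            return zip(*filtered)
    # Fast path: nothing to filter, return the inputs as-is.
    return grads, variables


grads, variables = filter_empty_gradients([0.1, None, 0.3], ["w0", "w1", "w2"])
print(list(grads), list(variables))  # [0.1, 0.3] ['w0', 'w2']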