Minor performance optimizations for eager.
fchollet committed Jul 12, 2023
1 parent c95bf43 commit 409375d
Showing 2 changed files with 27 additions and 25 deletions.
8 changes: 2 additions & 6 deletions keras_core/backend/torch/trainer.py
@@ -43,20 +43,16 @@ def train_step(self, data):

# Compute gradients
if self.trainable_weights:
-            # Backpropagation
-            trainable_weights = [v for v in self.trainable_weights]

# Call torch.Tensor.backward() on the loss to compute gradients
# for the weights.
loss.backward()

+            trainable_weights = self.trainable_weights[:]
gradients = [v.value.grad for v in trainable_weights]

# Update weights
with torch.no_grad():
-                self.optimizer.apply_gradients(
-                    zip(gradients, trainable_weights)
-                )
+                self.optimizer.apply(gradients, trainable_weights)
else:
warnings.warn("The model does not have any trainable weights.")

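Note: the torch `train_step` now reads gradients off the wrapped torch tensors and hands two parallel lists straight to `optimizer.apply`, instead of building a list of `(gradient, variable)` pairs for `apply_gradients`. A minimal sketch of the updated call pattern, assuming a compiled keras_core model on the torch backend (the helper name `manual_update` and the surrounding setup are illustrative, not part of this commit):

import torch

def manual_update(model, loss):
    # Backpropagate through the scalar loss so each underlying torch
    # parameter gets its .grad populated.
    loss.backward()

    # Shallow copy of the variable list; gradients are read from the
    # torch tensors wrapped by each Keras variable.
    trainable_weights = model.trainable_weights[:]
    gradients = [v.value.grad for v in trainable_weights]

    # The optimizer takes two parallel sequences, so no intermediate
    # list of (gradient, variable) tuples is built.
    with torch.no_grad():
        model.optimizer.apply(gradients, trainable_weights)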
44 changes: 25 additions & 19 deletions keras_core/optimizers/base_optimizer.py
@@ -193,7 +193,6 @@ def apply(self, grads, trainable_variables=None):
`variables` can be provided on the first call to build the optimizer.
"""
grads = list(grads)
if len(grads) == 0:
# It is possible that the grad is empty. In this case,
# `apply_gradients` is a no-op.
@@ -224,16 +223,15 @@ def apply(self, grads, trainable_variables=None):
self.built = True
self._check_variables_are_known(trainable_variables)

-        grads_and_vars = list(zip(grads, self._trainable_variables))

with ops.name_scope(self.name):
# Filter empty gradients.
-            grads_and_vars = self._filter_empty_gradients(grads_and_vars)
-            if len(list(grads_and_vars)) == 0:
+            grads, trainable_variables = self._filter_empty_gradients(
+                grads, trainable_variables
+            )
+            if len(list(grads)) == 0:
return

# Apply clipping and weight decay.
-            grads, trainable_variables = zip(*grads_and_vars)
grads = self._clip_gradients(grads)
self._apply_weight_decay(trainable_variables)

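The net effect in `apply` is that gradients and variables now travel as two parallel lists end to end, skipping both the up-front `list(zip(...))` and the later `zip(*grads_and_vars)` round trip. A standalone comparison of the two data flows, with placeholder function names rather than the real optimizer internals:

def old_flow(grads, variables):
    # Previous shape of the code: materialize (grad, var) pairs,
    # filter them, then unzip back into two sequences.
    grads_and_vars = list(zip(grads, variables))
    grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
    grads, variables = zip(*grads_and_vars)
    return list(grads), list(variables)

def new_flow(grads, variables):
    # Updated shape: keep the two parallel lists and only pair them
    # up when a gradient is actually missing.
    if any(g is None for g in grads):
        filtered = [
            (g, v) for g, v in zip(grads, variables) if g is not None
        ]
        grads, variables = zip(*filtered)
    return list(grads), list(variables)

In the common case where every gradient is present, the new flow allocates no intermediate pair list at all.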
@@ -363,19 +361,27 @@ def _get_current_learning_rate(self):
return self._learning_rate(self.iterations)
return self._learning_rate

-    def _filter_empty_gradients(self, grads_and_vars):
-        filtered = [(g, v) for g, v in grads_and_vars if g is not None]
-        if not filtered:
-            raise ValueError("No gradients provided for any variable.")
-        if len(filtered) < len(grads_and_vars):
-            missing_grad_vars = [v for g, v in grads_and_vars if g is None]
-            warnings.warn(
-                "Gradients do not exist for variables "
-                f"{[v.name for v in missing_grad_vars]} when minimizing the "
-                "loss. If you're using `model.compile()`, did you forget to "
-                "provide a `loss` argument?"
-            )
-        return filtered
+    def _filter_empty_gradients(self, grads, vars):
+        for grad in grads:
+            if grad is None:
+                # Filtering is required.
+                filtered = [
+                    (g, v) for g, v in zip(grads, vars) if g is not None
+                ]
+                if not filtered:
+                    raise ValueError("No gradients provided for any variable.")
+                if len(filtered) < len(grads):
+                    missing_grad_vars = [
+                        v for g, v in zip(grads, vars) if g is None
+                    ]
+                    warnings.warn(
+                        "Gradients do not exist for variables "
+                        f"{[v.name for v in missing_grad_vars]} when "
+                        "minimizing the loss. If using `model.compile()`, "
+                        "did you forget to provide a `loss` argument?"
+                    )
+                return zip(*filtered)
+        return grads, vars

def _clip_gradients(self, grads):
if self.clipnorm and self.clipnorm > 0:
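For reference, a quick behavioral sketch of the new filtering logic (standalone, with a stand-in function mirroring the diff rather than importing keras_core; variable names and the warning text are illustrative):

import warnings

def filter_empty_gradients(grads, variables):
    # Mirrors the updated method: scan for a missing gradient first and
    # only build the filtered pair list when one is found.
    for grad in grads:
        if grad is None:
            filtered = [
                (g, v) for g, v in zip(grads, variables) if g is not None
            ]
            if not filtered:
                raise ValueError("No gradients provided for any variable.")
            if len(filtered) < len(grads):
                missing = [v for g, v in zip(grads, variables) if g is None]
                warnings.warn(f"Gradients do not exist for variables {missing}.")
            return zip(*filtered)
    return grads, variables

grads = [0.1, None, 0.3]
variables = ["kernel", "bias", "scale"]
kept_grads, kept_vars = filter_empty_gradients(grads, variables)
print(list(kept_grads), list(kept_vars))  # [0.1, 0.3] ['kernel', 'scale']

When no gradient is None, the inputs are returned untouched, so the common path allocates nothing extra and `apply` keeps working with the original lists.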
