Small code simplification in training_epoch_loop.py #10146

Merged
8 commits merged on Oct 26, 2021
pytorch_lightning/loops/epoch/training_epoch_loop.py (6 changes: 3 additions & 3 deletions)
@@ -476,7 +476,7 @@ def _update_learning_rates(
             if update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]:
                 continue

-            current_idx = self.trainer.fit_loop.batch_idx if interval == "step" else self.trainer.current_epoch
+            current_idx = self.batch_idx if interval == "step" else self.trainer.current_epoch
             current_idx += 1  # account for both batch and epoch starts from 0
             # Take step if call to update_learning_rates matches the interval key and
             # the current step modulo the schedulers frequency is zero
@@ -502,15 +502,15 @@ def _update_learning_rates(
                         )
                         continue

-                self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_ready()
+                self.scheduler_progress.increment_ready()

                 # update LR
                 if lr_scheduler["reduce_on_plateau"]:
                     lr_scheduler["scheduler"].step(monitor_val)
                 else:
                     lr_scheduler["scheduler"].step()

-                self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_completed()
+                self.scheduler_progress.increment_completed()

     def _get_monitor_value(self, key: str) -> Any:
         # this is a separate method to aid in testing
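
For context: the simplification works because `_update_learning_rates` is a method of `TrainingEpochLoop` itself, so inside it `self.trainer.fit_loop.epoch_loop` resolves back to `self` (and `fit_loop.batch_idx` presumably just forwards to the epoch loop's counter), making the shorter attribute access equivalent. A minimal standalone sketch of that wiring, using mock classes rather than Lightning's real ones:

```python
# Minimal sketch (mock classes, not Lightning's actual implementation) showing why
# the long attribute chains in the diff collapse to plain `self.` access inside
# TrainingEpochLoop methods.


class SchedulerProgress:
    """Tiny stand-in for the progress tracker referenced in the diff."""

    def __init__(self) -> None:
        self.ready = 0
        self.completed = 0

    def increment_ready(self) -> None:
        self.ready += 1

    def increment_completed(self) -> None:
        self.completed += 1


class TrainingEpochLoop:
    def __init__(self) -> None:
        self.trainer = None  # wired up by the Trainer below
        self.batch_idx = 0
        self.scheduler_progress = SchedulerProgress()

    def _update_learning_rates(self) -> None:
        # Inside an epoch-loop method, the chain walks right back to `self`,
        # so both spellings name the same objects.
        assert self.trainer.fit_loop.epoch_loop is self
        assert self.trainer.fit_loop.epoch_loop.scheduler_progress is self.scheduler_progress
        self.scheduler_progress.increment_ready()
        self.scheduler_progress.increment_completed()


class FitLoop:
    def __init__(self) -> None:
        self.epoch_loop = TrainingEpochLoop()

    @property
    def batch_idx(self) -> int:
        # In this sketch (and, as the diff suggests, in Lightning at the time),
        # the fit loop simply forwards batch_idx to its epoch loop.
        return self.epoch_loop.batch_idx


class Trainer:
    def __init__(self) -> None:
        self.fit_loop = FitLoop()
        self.fit_loop.epoch_loop.trainer = self


trainer = Trainer()
trainer.fit_loop.epoch_loop._update_learning_rates()  # runs without raising
```

Either spelling reaches the same objects; the PR keeps the shorter one because the method already runs on the epoch-loop instance.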