Remove optimizer_connector.py (Lightning-AI#10120)
daniellepintz authored and ninginthecloud committed Oct 27, 2021
1 parent 0d2df30 commit 2051277
Showing 5 changed files with 79 additions and 104 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -554,6 +554,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed automatic patching of `{train,val,test,predict}_dataloader()` on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))


- Removed `pytorch_lightning.trainer.connectors.OptimizerConnector` ([#10120](https://github.com/PyTorchLightning/pytorch-lightning/pull/10120))


### Fixed


69 changes: 68 additions & 1 deletion pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -23,6 +23,7 @@
from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached, _update_dataloader_iter
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher
@@ -443,12 +444,78 @@ def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -
active_optimizers = _get_active_optimizers(
self.trainer.optimizers, self.trainer.optimizer_frequencies, self.total_batch_idx
)
self.trainer.optimizer_connector.update_learning_rates(
self._update_learning_rates(
interval=interval,
update_plateau_schedulers=update_plateau_schedulers,
opt_indices=[opt_idx for opt_idx, _ in active_optimizers],
)

def _update_learning_rates(
self, interval: str, update_plateau_schedulers: bool, opt_indices: Optional[List[int]] = None
) -> None:
"""Update learning rates.
Args:
interval: either 'epoch' or 'step'.
update_plateau_schedulers: control whether ``ReduceLROnPlateau`` or non-plateau schedulers get updated.
This is used so non-plateau schedulers can be updated before running validation. Checkpoints are
commonly saved during validation, however, on-plateau schedulers might monitor a validation metric
so they have to be updated separately.
opt_indices: indices of the optimizers to update.
"""
if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization:
return

if opt_indices is None:
opt_indices = []

for lr_scheduler in self.trainer.lr_schedulers:
if isinstance(lr_scheduler["opt_idx"], int) and lr_scheduler["opt_idx"] not in opt_indices:
continue

if update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]:
continue

current_idx = self.trainer.fit_loop.batch_idx if interval == "step" else self.trainer.current_epoch
current_idx += 1  # account for both batch and epoch indices starting from 0
# Take a step if the call to _update_learning_rates matches the interval key and
# the current step modulo the scheduler's frequency is zero
if lr_scheduler["interval"] == interval and current_idx % lr_scheduler["frequency"] == 0:
monitor_val = None
if lr_scheduler["reduce_on_plateau"]:
# If instance of ReduceLROnPlateau, we need a monitor
monitor_key = lr_scheduler["monitor"]
monitor_val = self._get_monitor_value(monitor_key)
if monitor_val is None:
if lr_scheduler.get("strict", True):
avail_metrics = list(self.trainer.callback_metrics)
raise MisconfigurationException(
f"ReduceLROnPlateau conditioned on metric {monitor_key}"
f" which is not available. Available metrics are: {avail_metrics}."
" Condition can be set using `monitor` key in lr scheduler dict"
)
rank_zero_warn(
f"ReduceLROnPlateau conditioned on metric {monitor_key}"
" which is not available but strict is set to `False`."
" Skipping learning rate update.",
RuntimeWarning,
)
continue

self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_ready()

# update LR
if lr_scheduler["reduce_on_plateau"]:
lr_scheduler["scheduler"].step(monitor_val)
else:
lr_scheduler["scheduler"].step()

self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_completed()

def _get_monitor_value(self, key: str) -> Any:
# this is a separate method to aid in testing
return self.trainer.callback_metrics.get(key)

def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
"""Decide if we should run validation."""
if not self.trainer.enable_validation:
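
For context on the scheduler dict keys that `_update_learning_rates` reads above (`monitor`, `interval`, `frequency`, `strict`), the following minimal sketch shows how a user-side `configure_optimizers` would set up a `ReduceLROnPlateau` scheduler that this loop steps. It is not part of this commit; the model and the `val_loss` metric name are illustrative.

import torch
from pytorch_lightning import LightningModule


class PlateauModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        # logged metrics land in trainer.callback_metrics, where _get_monitor_value looks them up
        self.log("val_loss", self.layer(batch).sum())

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",  # resolved via _get_monitor_value when the scheduler steps
                "interval": "epoch",    # step once per epoch rather than per optimizer step
                "frequency": 1,         # step whenever current_idx % frequency == 0
                "strict": True,         # raise MisconfigurationException if "val_loss" is missing
            },
        }
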
95 changes: 0 additions & 95 deletions pytorch_lightning/trainer/connectors/optimizer_connector.py

This file was deleted.

6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -58,7 +58,6 @@
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
@@ -430,7 +429,6 @@ def __init__(

# init connectors
self._data_connector = DataConnector(self, multiple_trainloader_mode)
self.optimizer_connector = OptimizerConnector(self)

self._accelerator_connector = AcceleratorConnector(
num_processes,
@@ -517,7 +515,9 @@ def __init__(
self.on_init_start()

# init optimizer + lr scheduler related flags
self.optimizer_connector.on_trainer_init()
self.lr_schedulers = []
self.optimizers = []
self.optimizer_frequencies = []

# init data flags
self._data_connector.on_trainer_init(
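
A small hedged sketch, assuming only what the hunk above shows: the three flags formerly set by `OptimizerConnector.on_trainer_init` are now assigned directly in `Trainer.__init__`, so a freshly constructed trainer should still expose them as empty lists before `fit` is called.

from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=1)
# the attributes previously initialized by OptimizerConnector.on_trainer_init
assert trainer.optimizers == []
assert trainer.lr_schedulers == []
assert trainer.optimizer_frequencies == []
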
10 changes: 5 additions & 5 deletions tests/checkpointing/test_model_checkpoint.py
@@ -63,17 +63,17 @@ def validation_epoch_end(self, outputs):
self.log("val_acc", outs)


def mock_optimizer_connector(trainer):
def mock_training_epoch_loop(trainer):
# do not use `unittest.mock.Mock` because we need to store the return value
calls = {}
old_get_monitor_value = trainer.optimizer_connector._get_monitor_value
old_get_monitor_value = trainer.fit_loop.epoch_loop._get_monitor_value

def mock(key):
value = old_get_monitor_value(key)
calls[trainer.current_epoch] = {key: value}
return value

trainer.optimizer_connector._get_monitor_value = mock
trainer.fit_loop.epoch_loop._get_monitor_value = mock
return calls


@@ -150,7 +150,7 @@ def on_validation_epoch_end(self):
max_epochs=max_epochs,
enable_progress_bar=False,
)
calls = mock_optimizer_connector(trainer)
calls = mock_training_epoch_loop(trainer)
trainer.fit(model)

ckpt_files = list(Path(tmpdir).glob("*.ckpt"))
@@ -248,7 +248,7 @@ def configure_optimizers(self):
enable_progress_bar=False,
num_sanity_val_steps=0,
)
calls = mock_optimizer_connector(trainer)
calls = mock_training_epoch_loop(trainer)
trainer.fit(model)

def _make_assertions(epoch, ix):
