From d2488d48ab3b785f961c0e5b173b30ff8b89a176 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Tue, 21 Dec 2021 00:48:08 +0530
Subject: [PATCH] address reviews

---
 docs/source/common/optimizers.rst             | 24 +++++++++----------
 pytorch_lightning/core/lightning.py           | 18 +++++++-------
 .../loops/epoch/training_epoch_loop.py        |  4 ++--
 tests/models/test_hooks.py                    |  2 +-
 tests/trainer/optimization/test_optimizers.py |  4 ++--
 5 files changed, 25 insertions(+), 27 deletions(-)

diff --git a/docs/source/common/optimizers.rst b/docs/source/common/optimizers.rst
index 25245433b1eee..8fd26a22dadc1 100644
--- a/docs/source/common/optimizers.rst
+++ b/docs/source/common/optimizers.rst
@@ -254,25 +254,23 @@ If you want to call schedulers that require a metric value after each epoch, con
 
 Bring your own Custom Learning Rate Schedulers
 ----------------------------------------------
-Lightning allows custom learning rate schedulers which are not present in
-`PyTorch natively `_.
-One good example is `Timm Schedulers `_.
-You can configure how your learning rate will be updated based on your custom implementation
-and lightning will handle when they should be updated based on the scheduler config provided inside
-:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers`. For you custom
-implementation you must override :meth:`~pytorch_lightning.core.lightning.LightningModule.lr_scheduler_step`
-if necessary. If you are using native PyTorch schedulers, there is no need to override this hook since
-Lightning will handle it optimally by default.
+Lightning allows using custom learning rate schedulers that aren't available in `PyTorch natively `_.
+One good example is `Timm Schedulers `_. When using custom learning rate schedulers
+relying on a different API from Native PyTorch ones, you should override the :meth:`~pytorch_lightning.core.lightning.LightningModule.lr_scheduler_step` with your desired logic.
+If you are using native PyTorch schedulers, there is no need to override this hook since Lightning will handle it optimally by default.
+
+.. code-block:: python
+
+    from timm.scheduler import TanhLRScheduler
 
-.. testcode:: python
 
     def configure_optimizers(self):
         optimizer = ...
-        scheduler = ...
-        return [optimizer], [scheduler]
+        scheduler = TanhLRScheduler(optimizer, ...)
+        return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]
 
 
-    def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val=None):
+    def lr_scheduler_step(self, scheduler, optimizer_idx, metrics=None):
         scheduler.step(epoch=self.current_epoch)  # timm's scheduler need the epoch value
 
 -----
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 3245706b13846..1b63780d827fa 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1177,7 +1177,7 @@ def configure_optimizers(self):
                 # `scheduler.step()`. 1 corresponds to updating the learning
                 # rate after every epoch/step.
"frequency": 1, - # Metric to to monitor for schedulers like ``ReduceLROnPlateau`` + # Metric to to monitor for schedulers like `ReduceLROnPlateau` "monitor": "val_loss", # If set to `True`, will enforce that the value specified 'monitor' # is available when the scheduler is updated, thus stopping @@ -1497,8 +1497,8 @@ def lr_scheduler_step( self, scheduler: Any, optimizer_idx: Optional[int] = None, - monitor_val: Optional[Union[float, torch.Tensor]] = None, - ): + metrics: Optional[Union[float, torch.Tensor]] = None, + ) -> None: r""" Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls each scheduler. @@ -1513,21 +1513,21 @@ def lr_scheduler_step( Examples:: # DEFAULT - def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val): + def lr_scheduler_step(self, scheduler, optimizer_idx, metrics): if monitor_val is None: scheduler.step() else: scheduler.step(monitor_val) # Alternative way to update schedulers if it requires an epoch value - def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val): + def lr_scheduler_step(self, scheduler, optimizer_idx, metrics): scheduler.step(epoch=self.current_epoch) """ - if monitor_val is None: + if metrics is None: scheduler.step() else: - scheduler.step(metrics=monitor_val) + scheduler.step(metrics=metrics) def optimizer_step( self, diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 751eb64d1a8cb..6ded01c90f8e4 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -509,11 +509,11 @@ def _update_learning_rates( "lr_scheduler_step", lr_scheduler["scheduler"], optimizer_idx=lr_scheduler["opt_idx"], - monitor_val=monitor_val, + metrics=monitor_val, ) self.scheduler_progress.increment_completed() - def _get_monitor_value(self, key: str) -> Optional[Union[float, torch.Tensor]]: + def _get_monitor_value(self, key: str) -> Any: # this is a separate method to aid in testing return self.trainer.callback_metrics.get(key) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 0b439c482fa01..89a30c841b20c 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -331,7 +331,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre dict( name="lr_scheduler_step", args=(ANY,), - kwargs=dict(optimizer_idx=None, monitor_val=None), + kwargs=dict(optimizer_idx=None, metrics=None), ) ] if i == (trainer.num_training_batches - 1) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index bc2ae6bf045fd..91a37243dc67d 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -693,10 +693,10 @@ def training_step(self, batch, batch_idx, optimizer_idx): def training_epoch_end(self, *args, **kwargs): pass - def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val=None): + def lr_scheduler_step(self, scheduler, optimizer_idx, metrics=None): # step-level if optimizer_idx == 0: - super().lr_scheduler_step(scheduler, optimizer_idx, monitor_val) + super().lr_scheduler_step(scheduler, optimizer_idx, metrics) # epoch-level elif optimizer_idx == 1: scheduler.step(epoch=self.current_epoch)