Commit

address reviews
rohitgr7 committed Dec 21, 2021
1 parent a1d1bc7 commit d2488d4
Showing 5 changed files with 23 additions and 25 deletions.
24 changes: 11 additions & 13 deletions docs/source/common/optimizers.rst
@@ -254,25 +254,23 @@ If you want to call schedulers that require a metric value after each epoch, con

Bring your own Custom Learning Rate Schedulers
----------------------------------------------
Lightning allows custom learning rate schedulers which are not present in
`PyTorch natively <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_.
One good example is `Timm Schedulers <https://github.com/rwightman/pytorch-image-models/blob/master/timm/scheduler/scheduler.py>`_.
You can configure how your learning rate will be updated based on your custom implementation
and lightning will handle when they should be updated based on the scheduler config provided inside
:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers`. For your custom
implementation you must override :meth:`~pytorch_lightning.core.lightning.LightningModule.lr_scheduler_step`
if necessary. If you are using native PyTorch schedulers, there is no need to override this hook since
Lightning will handle it optimally by default.
Lightning allows using custom learning rate schedulers that aren't available in `PyTorch natively <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_.
One good example is `Timm Schedulers <https://github.com/rwightman/pytorch-image-models/blob/master/timm/scheduler/scheduler.py>`_. When using custom learning rate schedulers
relying on a different API from Native PyTorch ones, you should override the :meth:`~pytorch_lightning.core.lightning.LightningModule.lr_scheduler_step` with your desired logic.
If you are using native PyTorch schedulers, there is no need to override this hook since Lightning will handle it optimally by default.

.. code-block:: python
from timm.scheduler import TanhLRScheduler
.. testcode:: python
def configure_optimizers(self):
optimizer = ...
scheduler = ...
return [optimizer], [scheduler]
scheduler = TanhLRScheduler(optimizer, ...)
return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]
def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val=None):
def lr_scheduler_step(self, scheduler, optimizer_idx, metrics=None):
scheduler.step(epoch=self.current_epoch) # timm's scheduler needs the epoch value
-----
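For readers following along, the updated docs snippet can be fleshed out into a complete module. The following is a minimal sketch, not part of the commit: the model body, learning rate, and `t_initial` value are illustrative assumptions; only the `configure_optimizers` / `lr_scheduler_step` wiring mirrors the documented pattern.

    import torch
    from timm.scheduler import TanhLRScheduler
    from pytorch_lightning import LightningModule


    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
            # t_initial (the number of epochs the schedule runs for) is an assumed value
            scheduler = TanhLRScheduler(optimizer, t_initial=10)
            return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]

        def lr_scheduler_step(self, scheduler, optimizer_idx, metrics=None):
            # timm schedulers take the epoch index instead of a bare .step()
            scheduler.step(epoch=self.current_epoch)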
14 changes: 7 additions & 7 deletions pytorch_lightning/core/lightning.py
@@ -1177,7 +1177,7 @@ def configure_optimizers(self):
# `scheduler.step()`. 1 corresponds to updating the learning
# rate after every epoch/step.
"frequency": 1,
# Metric to monitor for schedulers like ``ReduceLROnPlateau``
# Metric to monitor for schedulers like `ReduceLROnPlateau`
"monitor": "val_loss",
# If set to `True`, will enforce that the value specified 'monitor'
# is available when the scheduler is updated, thus stopping
@@ -1497,8 +1497,8 @@ def lr_scheduler_step(
self,
scheduler: Any,
optimizer_idx: Optional[int] = None,
monitor_val: Optional[Union[float, torch.Tensor]] = None,
):
metrics: Optional[Union[float, torch.Tensor]] = None,
) -> None:
r"""
Override this method to adjust the default way the
:class:`~pytorch_lightning.trainer.trainer.Trainer` calls each scheduler.
@@ -1513,21 +1513,21 @@ def lr_scheduler_step(
Examples::
# DEFAULT
def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val):
def lr_scheduler_step(self, scheduler, optimizer_idx, metrics):
if monitor_val is None:
scheduler.step()
else:
scheduler.step(monitor_val)
# Alternative way to update schedulers if it requires an epoch value
def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val):
def lr_scheduler_step(self, scheduler, optimizer_idx, metrics):
scheduler.step(epoch=self.current_epoch)
"""
if monitor_val is None:
if metrics is None:
scheduler.step()
else:
scheduler.step(metrics=monitor_val)
scheduler.step(metrics=metrics)

def optimizer_step(
self,
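As a usage note on the default implementation above: it is exactly what lets a monitored scheduler such as ``ReduceLROnPlateau`` work without overriding the hook, because Lightning fetches the monitored value and passes it through the renamed `metrics` argument. A minimal sketch of such a configuration (the optimizer choice, learning rate, and the `val_loss` key are assumptions for illustration):

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
        return {
            "optimizer": optimizer,
            # "monitor" names the callback metric Lightning should fetch; its value
            # reaches lr_scheduler_step as `metrics` and is forwarded via
            # scheduler.step(metrics=...)
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
        }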
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -509,11 +509,11 @@ def _update_learning_rates(
"lr_scheduler_step",
lr_scheduler["scheduler"],
optimizer_idx=lr_scheduler["opt_idx"],
monitor_val=monitor_val,
metrics=monitor_val,
)
self.scheduler_progress.increment_completed()

def _get_monitor_value(self, key: str) -> Optional[Union[float, torch.Tensor]]:
def _get_monitor_value(self, key: str) -> Any:
# this is a separate method to aid in testing
return self.trainer.callback_metrics.get(key)

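Taken together, the loop-side rename means the monitored value is still read from `trainer.callback_metrics`, but it is now handed to the hook under the `metrics` keyword. A simplified sketch of that flow, not the actual loop code (`lr_scheduler` stands for the scheduler configuration dict used in the hunk above):

    monitor_key = lr_scheduler.get("monitor")
    monitor_val = self._get_monitor_value(monitor_key) if monitor_key else None
    # the hook call from the hunk, spelled out against the LightningModule
    self.trainer.lightning_module.lr_scheduler_step(
        lr_scheduler["scheduler"],
        optimizer_idx=lr_scheduler["opt_idx"],
        metrics=monitor_val,  # previously passed as `monitor_val`
    )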
2 changes: 1 addition & 1 deletion tests/models/test_hooks.py
@@ -331,7 +331,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
dict(
name="lr_scheduler_step",
args=(ANY,),
kwargs=dict(optimizer_idx=None, monitor_val=None),
kwargs=dict(optimizer_idx=None, metrics=None),
)
]
if i == (trainer.num_training_batches - 1)
4 changes: 2 additions & 2 deletions tests/trainer/optimization/test_optimizers.py
@@ -693,10 +693,10 @@ def training_step(self, batch, batch_idx, optimizer_idx):
def training_epoch_end(self, *args, **kwargs):
pass

def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val=None):
def lr_scheduler_step(self, scheduler, optimizer_idx, metrics=None):
# step-level
if optimizer_idx == 0:
super().lr_scheduler_step(scheduler, optimizer_idx, monitor_val)
super().lr_scheduler_step(scheduler, optimizer_idx, metrics)
# epoch-level
elif optimizer_idx == 1:
scheduler.step(epoch=self.current_epoch)
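The hunk above only shows the overridden hook; for orientation, a plausible shape of the matching `configure_optimizers` in that test model (purely illustrative, since the real test body is not part of this diff) pairs a native step-level scheduler with optimizer 0 and a timm-style epoch-level scheduler with optimizer 1:

    def configure_optimizers(self):
        opt0 = torch.optim.SGD(self.parameters(), lr=0.1)
        opt1 = torch.optim.Adam(self.parameters(), lr=1e-3)
        # stepped through super().lr_scheduler_step() when optimizer_idx == 0
        sch0 = torch.optim.lr_scheduler.StepLR(opt0, step_size=1)
        # custom scheduler stepped with epoch=... when optimizer_idx == 1
        # (TanhLRScheduler from timm, import assumed as in the docs example)
        sch1 = TanhLRScheduler(opt1, t_initial=10)
        return [opt0, opt1], [
            {"scheduler": sch0, "interval": "step"},
            {"scheduler": sch1, "interval": "epoch"},
        ]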
