Add LightningModule.lr_scheduler_step #10249

Merged
merged 40 commits into from
Jan 12, 2022
Changes from 28 commits
Commits
3e10985
add LightningModule.scheduler_step
rohitgr7 Oct 29, 2021
9b07fa2
add tests
rohitgr7 Oct 30, 2021
a718c84
update types
rohitgr7 Oct 30, 2021
39059e5
docs
rohitgr7 Oct 30, 2021
3c66768
update .gitignore
rohitgr7 Oct 30, 2021
e437242
chlog
rohitgr7 Oct 30, 2021
fc8bc16
mypy
rohitgr7 Oct 30, 2021
b4dd1d8
remove step
rohitgr7 Dec 18, 2021
18e6bb4
add protocol api
rohitgr7 Dec 18, 2021
d7bdd0e
update
rohitgr7 Dec 18, 2021
ec2aa5d
add more test
rohitgr7 Dec 18, 2021
555c49f
use extensions
rohitgr7 Dec 18, 2021
5e8d371
register_hook
rohitgr7 Dec 18, 2021
f6b3e10
address reviews
rohitgr7 Dec 20, 2021
3c095ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2021
b01d2cf
fix and rebase
rohitgr7 Dec 28, 2021
78ebd31
mypy
rohitgr7 Dec 28, 2021
ff11e76
try fix mypy
rohitgr7 Jan 3, 2022
f8de4d0
try fix mypy
rohitgr7 Jan 3, 2022
404ba6b
try fix mypy
rohitgr7 Jan 3, 2022
013f9ce
use existing state_dict protocol
rohitgr7 Jan 3, 2022
78c8133
Merge branch 'master' into enhance/scheduler_step
rohitgr7 Jan 3, 2022
65bda5f
update import
rohitgr7 Jan 3, 2022
26182db
small updates
rohitgr7 Jan 4, 2022
54e2af9
Merge branch 'master' into enhance/scheduler_step
rohitgr7 Jan 4, 2022
b4fb944
add edge case check
rohitgr7 Jan 4, 2022
af8b1c3
rebase
rohitgr7 Jan 4, 2022
4c8ada6
avoid protocol
rohitgr7 Jan 5, 2022
08be795
Merge branch 'master' into enhance/scheduler_step
rohitgr7 Jan 7, 2022
c497bdf
move to types
rohitgr7 Jan 7, 2022
f1553ee
Inherit from the state dict protocol
carmocca Jan 8, 2022
99c92a5
All positional, optimizer index always int
carmocca Jan 8, 2022
ae8ae09
Simplify tests
carmocca Jan 8, 2022
236b55d
Minor test changes
carmocca Jan 8, 2022
7e82d1d
simplify test
rohitgr7 Jan 8, 2022
43532fd
one line
rohitgr7 Jan 8, 2022
4ec0e5c
Reduce further, test calls
carmocca Jan 8, 2022
b55504f
use typeerror
rohitgr7 Jan 10, 2022
281b0ef
Merge remote-tracking branch 'origin/master' into enhance/scheduler_step
rohitgr7 Jan 11, 2022
16797dd
Merge branch 'master' into enhance/scheduler_step
carmocca Jan 12, 2022
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,11 @@ ENV/
.data/
Datasets/
mnist/
MNIST/
legacy/checkpoints/
*.gz
*ubyte


# pl tests
ml-runs/
mlruns/
Expand Down
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/issues/10990))


- Added `LightningModule.lr_scheduler_step` ([#10249](https://github.com/PyTorchLightning/pytorch-lightning/pull/10249))


- Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/issues/11247))


- Added a `MisconfigurationException` if user provided `opt_idx` in scheduler config doesn't match with actual optimizer index of its respective optimizer ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/issues/11247))



### Changed

- Raised exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))
Expand Down
23 changes: 23 additions & 0 deletions docs/source/common/optimizers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,29 @@ If you want to call schedulers that require a metric value after each epoch, con

-----

Bring your own Custom Learning Rate Schedulers
----------------------------------------------
Lightning allows using custom learning rate schedulers that aren't available in `PyTorch natively <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_.
One good example is `Timm Schedulers <https://github.com/rwightman/pytorch-image-models/blob/master/timm/scheduler/scheduler.py>`_. When using custom learning rate schedulers
relying on a different API from Native PyTorch ones, you should override the :meth:`~pytorch_lightning.core.lightning.LightningModule.lr_scheduler_step` with your desired logic.
If you are using native PyTorch schedulers, there is no need to override this hook since Lightning will handle it automatically by default.

.. code-block:: python

    from timm.scheduler import TanhLRScheduler


    def configure_optimizers(self):
        optimizer = ...
        scheduler = TanhLRScheduler(optimizer, ...)
        return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]


    def lr_scheduler_step(self, scheduler, optimizer_idx, metric=None):
        scheduler.step(epoch=self.current_epoch)  # timm's scheduler need the epoch value

-----

Use closure for LBFGS-like optimizers
-------------------------------------
It is a good practice to provide the optimizer with a closure function that performs a ``forward``, ``zero_grad`` and
Expand Down
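The documentation added above can be illustrated without installing timm or Lightning. The following is a minimal sketch under stated assumptions: `EpochScheduler` and `ToyModule` are hypothetical stand-ins (not timm or Lightning classes) that only mimic the relevant shape, namely a scheduler whose `step` requires an explicit `epoch` and a module that overrides `lr_scheduler_step` to supply it.

```python
from typing import Any, Dict


class EpochScheduler:
    """Hypothetical scheduler whose ``step`` needs an explicit epoch (timm-like API)."""

    def __init__(self, base_lr: float, decay: float) -> None:
        self.base_lr = base_lr
        self.decay = decay
        self.lr = base_lr

    def step(self, epoch: int) -> None:
        # Unlike native PyTorch schedulers, the epoch must be passed in by the caller.
        self.lr = self.base_lr * (self.decay ** epoch)

    def state_dict(self) -> Dict[str, Any]:
        return dict(self.__dict__)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.__dict__.update(state_dict)


class ToyModule:
    """Hypothetical stand-in for a LightningModule; only what the hook needs."""

    def __init__(self) -> None:
        self.current_epoch = 0

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric=None):
        # Same override as in the documentation: forward the current epoch.
        scheduler.step(epoch=self.current_epoch)


module = ToyModule()
scheduler = EpochScheduler(base_lr=0.1, decay=0.5)
for epoch in range(3):
    module.current_epoch = epoch
    module.lr_scheduler_step(scheduler, optimizer_idx=0)
```

In a real project the stand-ins would be replaced by an actual `LightningModule` and the third-party scheduler; the override itself stays a one-liner.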
38 changes: 37 additions & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
from pytorch_lightning.utilities.parsing import collect_init_args
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, LRSchedulerTypeUnion, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()
Expand Down Expand Up @@ -1493,6 +1493,42 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va
            optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
        )

    def lr_scheduler_step(
        self,
        scheduler: LRSchedulerTypeUnion,
        optimizer_idx: Optional[int] = None,
        metric: Any = None,
    ) -> None:
        r"""
        Override this method to adjust the default way the
        :class:`~pytorch_lightning.trainer.trainer.Trainer` calls each scheduler.
        By default, Lightning calls ``step()`` and as shown in the example
        for each scheduler based on its ``interval``.

        Args:
            scheduler: Learning rate scheduler.
            optimizer_idx: Index of the optimizer associated with this scheduler.
            metric: Value of the monitor used for schedulers like ``ReduceLROnPlateau``.

        Examples::

            # DEFAULT
            def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
                if metric is None:
                    scheduler.step()
                else:
                    scheduler.step(metric)

            # Alternative way to update schedulers if it requires an epoch value
            def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
                scheduler.step(epoch=self.current_epoch)

        """
        if metric is None:
            scheduler.step()
        else:
            scheduler.step(metric)

    def optimizer_step(
        self,
        epoch: int,
Expand Down
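To make the default behaviour of the new hook concrete, here is a small sketch that mirrors the body of `lr_scheduler_step` as a free function. `RecordingScheduler` and `default_lr_scheduler_step` are hypothetical names introduced only for this illustration; the branch logic is copied from the hunk above.

```python
from typing import Any, List, Optional, Tuple


class RecordingScheduler:
    """Hypothetical stand-in that records the arguments passed to ``step``."""

    def __init__(self) -> None:
        self.calls: List[Tuple[Any, ...]] = []

    def step(self, *args: Any) -> None:
        self.calls.append(args)


def default_lr_scheduler_step(
    scheduler: RecordingScheduler, optimizer_idx: Optional[int] = None, metric: Any = None
) -> None:
    # Mirrors the default hook body: the metric is forwarded only when one is given.
    if metric is None:
        scheduler.step()
    else:
        scheduler.step(metric)


plain = RecordingScheduler()
default_lr_scheduler_step(plain, 0)

plateau = RecordingScheduler()
default_lr_scheduler_step(plateau, 0, metric=0.25)

# The check is ``is None`` rather than truthiness, so a metric of 0.0 is still forwarded.
zero_metric = RecordingScheduler()
default_lr_scheduler_step(zero_metric, 0, metric=0.0)
```

Note that `optimizer_idx` is unused by the default logic; it is only there so that overrides can branch on it.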
34 changes: 29 additions & 5 deletions pytorch_lightning/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import pytorch_lightning as pl
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import LRSchedulerTypeTuple


def do_nothing_closure() -> None:
Expand Down Expand Up @@ -168,7 +170,9 @@ def closure_dis():
trainer.strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)


def _init_optimizers_and_lr_schedulers(model: "pl.LightningModule") -> Tuple[List, List, List]:
def _init_optimizers_and_lr_schedulers(
model: "pl.LightningModule",
) -> Tuple[List[Optimizer], List[Dict[str, Any]], List[int]]:
"""Calls `LightningModule.configure_optimizers` and parses and validates the output."""
model.trainer._lightning_optimizers = None
optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)
Expand All @@ -185,6 +189,7 @@ def _init_optimizers_and_lr_schedulers(model: "pl.LightningModule") -> Tuple[Lis
)
lr_schedulers = _configure_schedulers(lr_schedulers, monitor)
_set_scheduler_opt_idx(optimizers, lr_schedulers)
_validate_scheduler_api(lr_schedulers, model)
return optimizers, lr_schedulers, optimizer_frequencies


Expand Down Expand Up @@ -298,10 +303,9 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]
lr_schedulers.append(
{**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor}
)
elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
lr_schedulers.append({**default_config, "scheduler": scheduler})
else:
raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
lr_schedulers.append({**default_config, "scheduler": scheduler})

return lr_schedulers


Expand All @@ -325,9 +329,29 @@ def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) -
lr_schedulers.append({**default_config, **scheduler})
else:
lr_schedulers.append({**default_config, "scheduler": scheduler})

return lr_schedulers


def _validate_scheduler_api(lr_schedulers: List[Dict[str, Any]], model: "pl.LightningModule") -> None:
    from pytorch_lightning.trainer.connectors.checkpoint_connector import _SupportsStateDict

    for scheduler_config in lr_schedulers:
        scheduler = scheduler_config["scheduler"]
        if not isinstance(scheduler, _SupportsStateDict):
            raise ValueError(
                f"The provided lr scheduler `{scheduler.__class__.__name__}` is invalid."
                " It should have `state_dict` and `load_state_dict` methods defined."
            )

        if not isinstance(scheduler, LRSchedulerTypeTuple) and not is_overridden("lr_scheduler_step", model):
            raise MisconfigurationException(
                f"The provided lr scheduler `{scheduler.__class__.__name__}` doesn't follow the PyTorch LR Scheduler"
                " Protocol. You should override the `LightningModule.lr_scheduler_step` hook with your own logic if"
                " you are using a custom LR scheduler."
            )


def _get_default_scheduler_config() -> Dict[str, Any]:
return {
"scheduler": None,
Expand All @@ -341,7 +365,7 @@ def _get_default_scheduler_config() -> Dict[str, Any]:
}


def _set_scheduler_opt_idx(optimizers: List[Any], lr_schedulers: List[Any]) -> None:
def _set_scheduler_opt_idx(optimizers: List[Optimizer], lr_schedulers: List[Dict[str, Any]]) -> None:
for sch in lr_schedulers:

for opt_idx, opt in enumerate(optimizers):
Expand Down
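The two-stage check in `_validate_scheduler_api` can be sketched in plain Python. Everything below is an assumption made for illustration: `NativeScheduler` stands in for the types in `LRSchedulerTypeTuple`, `MisconfigurationError` for Lightning's `MisconfigurationException`, and `hook_is_overridden` for `is_overridden`. It also uses `typing.Protocol` (Python 3.8+) where the PR imports from `typing_extensions`.

```python
from typing import Any, Dict, List, Protocol, runtime_checkable


@runtime_checkable
class SupportsStateDict(Protocol):
    def state_dict(self) -> Dict[str, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...


class MisconfigurationError(Exception):
    """Hypothetical stand-in for Lightning's MisconfigurationException."""


class NativeScheduler:
    """Hypothetical stand-in for a native PyTorch scheduler type."""

    def step(self) -> None:
        ...

    def state_dict(self) -> Dict[str, Any]:
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...


class CustomScheduler:
    """Stateful, but not a native type: requires an overridden hook."""

    def step(self, epoch: int) -> None:
        ...

    def state_dict(self) -> Dict[str, Any]:
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...


class BaseModule:
    def lr_scheduler_step(self, scheduler: Any, optimizer_idx: int, metric: Any = None) -> None:
        scheduler.step()


class OverridingModule(BaseModule):
    def lr_scheduler_step(self, scheduler: Any, optimizer_idx: int, metric: Any = None) -> None:
        scheduler.step(epoch=0)


def hook_is_overridden(module: BaseModule) -> bool:
    # Simplified override detection: compare the class attribute with the base one.
    return type(module).lr_scheduler_step is not BaseModule.lr_scheduler_step


def validate_scheduler_api(schedulers: List[Any], module: BaseModule) -> None:
    for scheduler in schedulers:
        # Stage 1: every scheduler must be checkpointable.
        if not isinstance(scheduler, SupportsStateDict):
            raise ValueError(f"`{type(scheduler).__name__}` lacks `state_dict`/`load_state_dict`.")
        # Stage 2: non-native schedulers need a user-defined stepping hook.
        if not isinstance(scheduler, NativeScheduler) and not hook_is_overridden(module):
            raise MisconfigurationError(f"Override `lr_scheduler_step` to use `{type(scheduler).__name__}`.")
```

The ordering matters in this sketch: a scheduler that cannot be checkpointed is rejected first, regardless of whether the hook is overridden.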
11 changes: 6 additions & 5 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,11 +506,12 @@ def _update_learning_rates(
self.scheduler_progress.increment_ready()

# update LR
if lr_scheduler["reduce_on_plateau"]:
lr_scheduler["scheduler"].step(monitor_val)
else:
lr_scheduler["scheduler"].step()

self.trainer._call_lightning_module_hook(
"lr_scheduler_step",
lr_scheduler["scheduler"],
optimizer_idx=lr_scheduler["opt_idx"],
metric=monitor_val,
)
self.scheduler_progress.increment_completed()

def _get_monitor_value(self, key: str) -> Any:
Expand Down
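The change to `_update_learning_rates` replaces the direct `step` calls with a call to the module's hook. A rough sketch of that delegation follows. `RecordingModule`, `update_learning_rates`, and the plain dictionaries are hypothetical, and the assumption that the monitor value is looked up only for `reduce_on_plateau` configs is not shown in this hunk.

```python
from typing import Any, Dict, List, Optional, Tuple


class RecordingModule:
    """Hypothetical module that records how its hook is invoked."""

    def __init__(self) -> None:
        self.calls: List[Tuple[Any, int, Optional[float]]] = []

    def lr_scheduler_step(self, scheduler: Any, optimizer_idx: int, metric: Any = None) -> None:
        self.calls.append((scheduler, optimizer_idx, metric))


def update_learning_rates(
    module: RecordingModule,
    scheduler_configs: List[Dict[str, Any]],
    logged_metrics: Dict[str, float],
) -> None:
    for config in scheduler_configs:
        # Assumption: only plateau-style schedulers receive a monitored value.
        monitor_val = None
        if config["reduce_on_plateau"]:
            monitor_val = logged_metrics[config["monitor"]]
        # The loop no longer calls ``scheduler.step`` itself; it delegates to the hook.
        hook = getattr(module, "lr_scheduler_step")
        hook(config["scheduler"], optimizer_idx=config["opt_idx"], metric=monitor_val)


configs = [
    {"scheduler": "step_lr", "opt_idx": 0, "reduce_on_plateau": False, "monitor": None},
    {"scheduler": "plateau", "opt_idx": 1, "reduce_on_plateau": True, "monitor": "val_loss"},
]
module = RecordingModule()
update_learning_rates(module, configs, {"val_loss": 0.42})
```

Strings are used in place of scheduler objects here purely to keep the recorded calls easy to inspect.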
7 changes: 3 additions & 4 deletions pytorch_lightning/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import torch
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers
Expand All @@ -41,7 +40,7 @@
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple, STEP_OUTPUT
from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache

warning_cache = WarningCache()
Expand Down Expand Up @@ -399,7 +398,7 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]
return self.model, [optimizer]

def _setup_model_and_optimizer(
self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None
self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None
):
"""Initialize one model and one optimizer with an optional learning rate scheduler.

Expand Down Expand Up @@ -445,7 +444,7 @@ def init_deepspeed(self):
else:
self._initialize_deepspeed_inference(model)

def _init_optimizers(self) -> Tuple[Optimizer, Optional[Union[LRSchedulerTypeTuple]], Optional[int]]:
def _init_optimizers(self) -> Tuple[Optimizer, Optional[List[LRSchedulerConfig]], Optional[int]]:
optimizers, schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module)
if len(optimizers) > 1 or len(schedulers) > 1:
raise MisconfigurationException(
Expand Down
4 changes: 1 addition & 3 deletions pytorch_lightning/strategies/horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
Expand Down Expand Up @@ -105,8 +104,7 @@ def _unpack_lightning_optimizer(opt):
lr_schedulers = self.lightning_module.trainer.lr_schedulers
for scheduler in lr_schedulers:
scheduler = scheduler["scheduler"]
if isinstance(scheduler, _LRScheduler):
scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]
scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]

# Horovod: broadcast parameters & optimizer state to ensure consistent initialization
hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0)
Expand Down
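As read from this hunk, the Horovod strategy now scales `base_lrs` for every configured scheduler instead of only for `_LRScheduler` instances. The sketch below shows that scaling in isolation; `ToyScheduler` and `scale_base_lrs` are hypothetical names, and the observation that a scheduler without a `base_lrs` attribute would now fail with an `AttributeError` is an inference from the diff, not something stated in the PR.

```python
from typing import Any, Dict, List


class ToyScheduler:
    """Hypothetical scheduler exposing ``base_lrs`` like native PyTorch schedulers."""

    def __init__(self, base_lrs: List[float]) -> None:
        self.base_lrs = list(base_lrs)


def scale_base_lrs(scheduler_configs: List[Dict[str, Any]], world_size: int) -> None:
    for config in scheduler_configs:
        scheduler = config["scheduler"]
        # No isinstance guard: every scheduler is expected to provide ``base_lrs``.
        scheduler.base_lrs = [lr * world_size for lr in scheduler.base_lrs]


toy = ToyScheduler([0.5, 0.25])
scale_base_lrs([{"scheduler": toy}], world_size=4)
```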
12 changes: 12 additions & 0 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import torch
from torchmetrics import Metric
from typing_extensions import Protocol, runtime_checkable

import pytorch_lightning as pl
from pytorch_lightning.loops.utilities import _is_max_limit_reached
Expand Down Expand Up @@ -468,3 +469,14 @@ def hpc_save_path(folderpath: _PATH) -> str:
ckpt_number = (max_suffix if max_suffix is not None else 0) + 1
filepath = os.path.join(folderpath, f"hpc_ckpt_{ckpt_number}.ckpt")
return filepath


@runtime_checkable
class _SupportsStateDict(Protocol):
    """This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`."""

    def state_dict(self) -> Dict[str, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...
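A short sketch of what a `runtime_checkable` protocol like `_SupportsStateDict` does and does not verify. The classes below are hypothetical, and `typing.Protocol` (Python 3.8+) is used in place of the `typing_extensions` import from the PR. The point of the example is that `isinstance` only checks that the named methods exist, not that their signatures match.

```python
from typing import Any, Dict, Protocol, runtime_checkable


@runtime_checkable
class SupportsStateDict(Protocol):
    def state_dict(self) -> Dict[str, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...


class Stateful:
    """Defines both methods with the expected signatures."""

    def state_dict(self) -> Dict[str, Any]:
        return {"step": 3}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        pass


class OnlySaves:
    """Missing ``load_state_dict``, so the structural check fails."""

    def state_dict(self) -> Dict[str, Any]:
        return {}


class OddSignature:
    """Both names exist, so the check passes even though the signatures differ."""

    def state_dict(self, prefix: str) -> str:
        return prefix

    def load_state_dict(self) -> None:
        pass


checks = {cls.__name__: isinstance(cls(), SupportsStateDict) for cls in (Stateful, OnlySaves, OddSignature)}
```

Because of this, the protocol check works as a cheap early filter rather than a guarantee that saving and restoring will succeed.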
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class _LogOptions(TypedDict):
"optimizer_step": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"lr_scheduler_step": None,
"on_before_zero_grad": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
Expand Down
16 changes: 4 additions & 12 deletions pytorch_lightning/utilities/auto_restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
DataLoader,
IterableDataset,
)
from typing_extensions import Protocol, runtime_checkable

import pytorch_lightning as pl
from pytorch_lightning.utilities.apply_func import apply_to_collection
Expand Down Expand Up @@ -577,6 +576,8 @@ def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dic
# In manual mode, we don't wrap the user objects with `CaptureMapDataset` or `CaptureIterableDataset`
# therefore, we need to reload the states manually.

from pytorch_lightning.trainer.connectors.checkpoint_connector import _SupportsStateDict

latest_worker_id = state_dict["latest_worker_id"]
num_workers = state_dict["state"][latest_worker_id]["num_workers"]
sampler_state = state_dict["state"][latest_worker_id].get("sampler_state", None)
Expand Down Expand Up @@ -635,17 +636,6 @@ def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_wor
return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state}


@runtime_checkable
class _SupportsStateDict(Protocol):
"""This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`."""

def state_dict(self) -> Dict[str, Any]:
...

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
...


class _StatefulDataLoaderIter:
"""This mixin is used to make PyTorch DataLoaderIter stateful."""

Expand All @@ -656,6 +646,8 @@ def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None:

def _store_sampler_state(self) -> None:
"""This function is used to extract the sampler states if any."""
from pytorch_lightning.trainer.connectors.checkpoint_connector import _SupportsStateDict

sampler_state = {
k: v.state_dict()
for k, v in self._loader.__dict__.items()
Expand Down
14 changes: 2 additions & 12 deletions pytorch_lightning/utilities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,11 @@
EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]


# Copied from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
# Inferred from `torch.optim.lr_scheduler.pyi`
class _LRScheduler:
optimizer: Optimizer

def __init__(self, optimizer: Optimizer, last_epoch: int = ...) -> None:
def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
...

def state_dict(self) -> dict:
Expand All @@ -60,15 +59,6 @@ def state_dict(self) -> dict:
def load_state_dict(self, state_dict: dict) -> None:
...

def get_last_lr(self) -> List[float]:
...

def get_lr(self) -> float:
...

def step(self, epoch: Optional[int] = ...) -> None:
...


# Copied from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
Expand Down
11 changes: 11 additions & 0 deletions tests/models/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,17 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
args=(current_epoch, i, ANY, 0, ANY),
kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=using_native_amp),
),
*(
[
dict(
name="lr_scheduler_step",
args=(ANY,),
kwargs=dict(optimizer_idx=0, metric=None),
)
]
if i == (trainer.num_training_batches - 1)
else []
),
dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i)),
dict(name="on_train_batch_end", args=(dict(loss=ANY), ANY, i)),
dict(name="Callback.on_batch_end", args=(trainer, model)),
Expand Down
1 change: 1 addition & 0 deletions tests/trainer/logging_/test_logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def test_fx_validator_integration(tmpdir):
"configure_callbacks": "You can't",
"on_validation_model_eval": "You can't",
"on_validation_model_train": "You can't",
"lr_scheduler_step": "You can't",
"summarize": "not managed by the `Trainer",
}
model = HookedModel(not_supported)
Expand Down