From 82c8875f33addb0becd7761c95e9674ccc98c7ee Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Wed, 12 Jan 2022 09:23:49 +0530 Subject: [PATCH] Add `LightningModule.lr_scheduler_step` (#10249) Co-authored-by: Carlos Mocholi --- .gitignore | 2 +- CHANGELOG.md | 4 +- docs/source/common/optimizers.rst | 26 +++- pytorch_lightning/core/lightning.py | 38 +++++- pytorch_lightning/core/optimizer.py | 32 ++++- .../loops/epoch/training_epoch_loop.py | 11 +- pytorch_lightning/strategies/deepspeed.py | 7 +- pytorch_lightning/strategies/horovod.py | 4 +- .../logger_connector/fx_validator.py | 1 + pytorch_lightning/utilities/auto_restart.py | 14 +- pytorch_lightning/utilities/types.py | 39 ++---- tests/models/test_hooks.py | 5 + .../trainer/logging_/test_logger_connector.py | 1 + tests/trainer/optimization/test_optimizers.py | 122 ++++++++++++++++++ tests/utilities/test_auto_restart.py | 2 +- 15 files changed, 246 insertions(+), 62 deletions(-) diff --git a/.gitignore b/.gitignore index e1d165dd5dbb1..923c2a1829c22 100644 --- a/.gitignore +++ b/.gitignore @@ -139,11 +139,11 @@ ENV/ .data/ Datasets/ mnist/ +MNIST/ legacy/checkpoints/ *.gz *ubyte - # pl tests ml-runs/ mlruns/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 8583a8365ed9c..26392c4689649 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,13 +63,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/issues/10990)) +- Added `LightningModule.lr_scheduler_step` ([#10249](https://github.com/PyTorchLightning/pytorch-lightning/pull/10249)) + + - Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/issues/11247)) - Added a `MisconfigurationException` if user provided `opt_idx` in scheduler config doesn't match with actual optimizer index of its respective optimizer ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/issues/11247)) - ### Changed - Raised exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418)) diff --git a/docs/source/common/optimizers.rst b/docs/source/common/optimizers.rst index c48224e3a2f52..4ed54aed66410 100644 --- a/docs/source/common/optimizers.rst +++ b/docs/source/common/optimizers.rst @@ -518,9 +518,31 @@ to perform a step, Lightning won't be able to support accelerators, precision an optimizer.step(closure=optimizer_closure) -*************************** +Bring your own Custom Learning Rate Schedulers +============================================== + +Lightning allows using custom learning rate schedulers that aren't available in `PyTorch natively <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_. +One good example is `Timm Schedulers <https://github.com/rwightman/pytorch-image-models/blob/master/timm/scheduler/scheduler.py>`_. When using custom learning rate schedulers +relying on a different API from the native PyTorch ones, you should override the :meth:`~pytorch_lightning.core.lightning.LightningModule.lr_scheduler_step` hook with your desired logic. +If you are using native PyTorch schedulers, there is no need to override this hook since Lightning will handle it automatically by default. + +.. code-block:: python + + from timm.scheduler import TanhLRScheduler + + + def configure_optimizers(self): + optimizer = ... + scheduler = TanhLRScheduler(optimizer, ...)
+ return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}] + + + def lr_scheduler_step(self, scheduler, optimizer_idx, metric): + scheduler.step(epoch=self.current_epoch) # timm's scheduler needs the epoch value + + Configure Gradient Clipping -*************************** +=========================== To configure custom gradient clipping, consider overriding the :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_gradient_clipping` method. diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ee42cceff7b43..25472d5295cbe 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -53,7 +53,7 @@ from pytorch_lightning.utilities.model_summary import ModelSummary, summarize from pytorch_lightning.utilities.parsing import collect_init_args from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT +from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, LRSchedulerTypeUnion, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache warning_cache = WarningCache() @@ -1493,6 +1493,42 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm ) + def lr_scheduler_step( + self, + scheduler: LRSchedulerTypeUnion, + optimizer_idx: int, + metric: Optional[Any], + ) -> None: + r""" + Override this method to adjust the default way the + :class:`~pytorch_lightning.trainer.trainer.Trainer` calls each scheduler. + By default, Lightning calls ``step()`` as shown in the example + for each scheduler based on its ``interval``. + + Args: + scheduler: Learning rate scheduler. + optimizer_idx: Index of the optimizer associated with this scheduler.
+ metric: Value of the monitor used for schedulers like ``ReduceLROnPlateau``. + + Examples:: + + # DEFAULT + def lr_scheduler_step(self, scheduler, optimizer_idx, metric): + if metric is None: + scheduler.step() + else: + scheduler.step(metric) + + # Alternative way to update schedulers if it requires an epoch value + def lr_scheduler_step(self, scheduler, optimizer_idx, metric): + scheduler.step(epoch=self.current_epoch) + + """ + if metric is None: + scheduler.step() + else: + scheduler.step(metric) + def optimizer_step( self, epoch: int, diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 419799b8db051..3b0cdffff497e 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -23,6 +23,8 @@ import pytorch_lightning as pl from pytorch_lightning.utilities import AMPType, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.types import _SupportsStateDict, LRSchedulerTypeTuple def do_nothing_closure() -> None: @@ -168,7 +170,9 @@ def closure_dis(): trainer.strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) -def _init_optimizers_and_lr_schedulers(model: "pl.LightningModule") -> Tuple[List, List, List]: +def _init_optimizers_and_lr_schedulers( + model: "pl.LightningModule", +) -> Tuple[List[Optimizer], List[Dict[str, Any]], List[int]]: """Calls `LightningModule.configure_optimizers` and parses and validates the output.""" model.trainer._lightning_optimizers = None optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model) @@ -185,6 +189,7 @@ def _init_optimizers_and_lr_schedulers(model: "pl.LightningModule") -> Tuple[Lis ) lr_schedulers = _configure_schedulers(lr_schedulers, monitor) _set_scheduler_opt_idx(optimizers, lr_schedulers) + _validate_scheduler_api(lr_schedulers, model) return 
optimizers, lr_schedulers, optimizer_frequencies @@ -298,10 +303,9 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str] lr_schedulers.append( {**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor} ) - elif isinstance(scheduler, optim.lr_scheduler._LRScheduler): - lr_schedulers.append({**default_config, "scheduler": scheduler}) else: - raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid') + lr_schedulers.append({**default_config, "scheduler": scheduler}) + return lr_schedulers @@ -325,9 +329,27 @@ def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) - lr_schedulers.append({**default_config, **scheduler}) else: lr_schedulers.append({**default_config, "scheduler": scheduler}) + return lr_schedulers +def _validate_scheduler_api(lr_schedulers: List[Dict[str, Any]], model: "pl.LightningModule") -> None: + for scheduler_config in lr_schedulers: + scheduler = scheduler_config["scheduler"] + if not isinstance(scheduler, _SupportsStateDict): + raise TypeError( + f"The provided lr scheduler `{scheduler.__class__.__name__}` is invalid." + " It should have `state_dict` and `load_state_dict` methods defined." + ) + + if not isinstance(scheduler, LRSchedulerTypeTuple) and not is_overridden("lr_scheduler_step", model): + raise MisconfigurationException( + f"The provided lr scheduler `{scheduler.__class__.__name__}` doesn't follow PyTorch's LRScheduler" + " API. You should override the `LightningModule.lr_scheduler_step` hook with your own logic if" + " you are using a custom LR scheduler." 
+ ) + + def _get_default_scheduler_config() -> Dict[str, Any]: return { "scheduler": None, @@ -341,7 +363,7 @@ def _get_default_scheduler_config() -> Dict[str, Any]: } -def _set_scheduler_opt_idx(optimizers: List[Any], lr_schedulers: List[Any]) -> None: +def _set_scheduler_opt_idx(optimizers: List[Optimizer], lr_schedulers: List[Dict[str, Any]]) -> None: for sch in lr_schedulers: for opt_idx, opt in enumerate(optimizers): diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 48aa75fe70648..69432ee07dd0b 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -454,11 +454,12 @@ def _update_learning_rates( self.scheduler_progress.increment_ready() # update LR - if lr_scheduler["reduce_on_plateau"]: - lr_scheduler["scheduler"].step(monitor_val) - else: - lr_scheduler["scheduler"].step() - + self.trainer._call_lightning_module_hook( + "lr_scheduler_step", + lr_scheduler["scheduler"], + lr_scheduler["opt_idx"], + monitor_val, + ) self.scheduler_progress.increment_completed() def _get_monitor_value(self, key: str) -> Any: diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py index 04ab2ccb787c3..7504eb37500b2 100644 --- a/pytorch_lightning/strategies/deepspeed.py +++ b/pytorch_lightning/strategies/deepspeed.py @@ -24,7 +24,6 @@ import torch from torch.nn import Module from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler import pytorch_lightning as pl from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers @@ -41,7 +40,7 @@ from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple, STEP_OUTPUT 
+from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache warning_cache = WarningCache() @@ -399,7 +398,7 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer] return self.model, [optimizer] def _setup_model_and_optimizer( - self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None + self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None ): """Initialize one model and one optimizer with an optional learning rate scheduler. @@ -445,7 +444,7 @@ def init_deepspeed(self): else: self._initialize_deepspeed_inference(model) - def _init_optimizers(self) -> Tuple[Optimizer, Optional[Union[LRSchedulerTypeTuple]], Optional[int]]: + def _init_optimizers(self) -> Tuple[Optimizer, Optional[List[LRSchedulerConfig]], Optional[int]]: optimizers, schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module) if len(optimizers) > 1 or len(schedulers) > 1: raise MisconfigurationException( diff --git a/pytorch_lightning/strategies/horovod.py b/pytorch_lightning/strategies/horovod.py index 7d7ecbe6a7b2c..a1c34fa87b8d5 100644 --- a/pytorch_lightning/strategies/horovod.py +++ b/pytorch_lightning/strategies/horovod.py @@ -17,7 +17,6 @@ import torch import torch.nn as nn from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer @@ -105,8 +104,7 @@ def _unpack_lightning_optimizer(opt): lr_schedulers = self.lightning_module.trainer.lr_schedulers for scheduler in lr_schedulers: scheduler = scheduler["scheduler"] - if isinstance(scheduler, _LRScheduler): - scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs] + scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs] # Horovod: broadcast 
parameters & optimizer state to ensure consistent initialization hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index e73bf54825269..c33320185d76a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -43,6 +43,7 @@ class _LogOptions(TypedDict): "optimizer_step": _LogOptions( allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False ), + "lr_scheduler_step": None, "on_before_zero_grad": _LogOptions( allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False ), diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 9d26f4a6e0736..ec630f795d8cc 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -36,13 +36,13 @@ DataLoader, IterableDataset, ) -from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import _SupportsStateDict class FastForwardSampler(Sampler): @@ -576,7 +576,6 @@ def _reload_dataloader_state_dict_automatic(dataloader: DataLoader, state_dict: def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None: # In manual mode, we don't wrap the user objects with `CaptureMapDataset` or `CaptureIterableDataset` # therefore, we need to reload the 
states manually. - latest_worker_id = state_dict["latest_worker_id"] num_workers = state_dict["state"][latest_worker_id]["num_workers"] sampler_state = state_dict["state"][latest_worker_id].get("sampler_state", None) @@ -635,17 +634,6 @@ def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_wor return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state} -@runtime_checkable -class _SupportsStateDict(Protocol): - """This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`.""" - - def state_dict(self) -> Dict[str, Any]: - ... - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - ... - - class _StatefulDataLoaderIter: """This mixin is used to make PyTorch DataLoaderIter stateful.""" diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 44a3b88d530d6..1d5cd272267d5 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -23,7 +23,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader from torchmetrics import Metric -from typing_extensions import TypedDict +from typing_extensions import Protocol, runtime_checkable, TypedDict _NUMBER = Union[int, float] _METRIC = Union[Metric, torch.Tensor, _NUMBER] @@ -46,33 +46,29 @@ EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] -# Copied from `torch.optim.lr_scheduler.pyi` -# Missing attributes were added to improve typing -class _LRScheduler: - optimizer: Optimizer +@runtime_checkable +class _SupportsStateDict(Protocol): + """This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`.""" - def __init__(self, optimizer: Optimizer, last_epoch: int = ...) -> None: + def state_dict(self) -> Dict[str, Any]: ... - def state_dict(self) -> dict: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ... - def load_state_dict(self, state_dict: dict) -> None: - ... 
- def get_last_lr(self) -> List[float]: - ... - - def get_lr(self) -> float: - ... +# Inferred from `torch.optim.lr_scheduler.pyi` +# Missing attributes were added to improve typing +class _LRScheduler(_SupportsStateDict): + optimizer: Optimizer - def step(self, epoch: Optional[int] = ...) -> None: + def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None: ... -# Copied from `torch.optim.lr_scheduler.pyi` +# Inferred from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing -class ReduceLROnPlateau: +class ReduceLROnPlateau(_SupportsStateDict): in_cooldown: bool optimizer: Optimizer @@ -91,15 +87,6 @@ def __init__( ) -> None: ... - def step(self, metrics: Any, epoch: Optional[int] = ...) -> None: - ... - - def state_dict(self) -> dict: - ... - - def load_state_dict(self, state_dict: dict) -> None: - ... - # todo: improve LRSchedulerType naming/typing LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index b2c75f142427a..5f20d7bb4115a 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -326,6 +326,11 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre args=(current_epoch, i, ANY, 0, ANY), kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=using_native_amp), ), + *( + [dict(name="lr_scheduler_step", args=(ANY, 0, None))] + if i == (trainer.num_training_batches - 1) + else [] + ), dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i)), dict(name="on_train_batch_end", args=(dict(loss=ANY), ANY, i)), dict(name="Callback.on_batch_end", args=(trainer, model)), diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index b965478684b87..8d6b3551e8579 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ 
b/tests/trainer/logging_/test_logger_connector.py @@ -233,6 +233,7 @@ def test_fx_validator_integration(tmpdir): "configure_callbacks": "You can't", "on_validation_model_eval": "You can't", "on_validation_model_train": "You can't", + "lr_scheduler_step": "You can't", "summarize": "not managed by the `Trainer", } model = HookedModel(not_supported) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 4d646231027ed..e960eabcb9b62 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from unittest import mock +from unittest.mock import call, patch import pytest import torch @@ -679,3 +680,124 @@ def on_save_checkpoint(self, checkpoint): model.training_epoch_end = None trainer.fit(model) assert model.on_save_checkpoint_called + + +def test_lr_scheduler_step_hook(tmpdir): + """Test that custom lr scheduler works and `lr_scheduler_step` is called at appropriate time.""" + + class CustomEpochScheduler: + def __init__(self, optimizer): + self.optimizer = optimizer + + def step(self, epoch): + ... + + def state_dict(self): + ... + + def load_state_dict(self, state_dict): + ... 
+ + class CustomBoringModel(BoringModel): + def training_step(self, batch, batch_idx, optimizer_idx=0): + return super().training_step(batch, batch_idx) + + def lr_scheduler_step(self, scheduler, optimizer_idx, metric): + # step-level + if optimizer_idx == 0: + super().lr_scheduler_step(scheduler, optimizer_idx, metric) + # epoch-level + elif optimizer_idx == 1: + scheduler.step(epoch=self.current_epoch) + + def configure_optimizers(self): + opt1 = torch.optim.SGD(self.layer.parameters(), lr=1e-2) + lr_scheduler1 = {"scheduler": torch.optim.lr_scheduler.StepLR(opt1, step_size=1), "interval": "step"} + opt2 = torch.optim.SGD(self.layer.parameters(), lr=1e-2) + lr_scheduler2 = CustomEpochScheduler(opt2) + return {"optimizer": opt1, "lr_scheduler": lr_scheduler1}, { + "optimizer": opt2, + "lr_scheduler": lr_scheduler2, + } + + model = CustomBoringModel() + model.training_epoch_end = None + max_epochs = 3 + limit_train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + enable_checkpointing=False, + logger=False, + max_epochs=max_epochs, + limit_train_batches=limit_train_batches, + limit_val_batches=0, + ) + + with patch.object(CustomEpochScheduler, "step") as mock_method_epoch, patch.object( + torch.optim.lr_scheduler.StepLR, "step" + ) as mock_method_step: + trainer.fit(model) + + assert mock_method_epoch.mock_calls == [call(epoch=e) for e in range(max_epochs)] + # first step is called by PyTorch _LRScheduler + assert mock_method_step.call_count == max_epochs * limit_train_batches + 1 + + +def test_invalid_scheduler_missing_state_dict(): + """Test that custom lr scheduler raises an error if it's missing the state dict.""" + + class CustomScheduler: + def __init__(self, optimizer): + self.optimizer = optimizer + + def step(self): + ... 
+ + class CustomBoringModel(BoringModel): + def configure_optimizers(self): + opt = torch.optim.SGD(self.parameters(), lr=1e-2) + lr_scheduler = CustomScheduler(opt) + return {"optimizer": opt, "lr_scheduler": lr_scheduler} + + model = CustomBoringModel() + model.trainer = Trainer() + with pytest.raises(TypeError, match="provided lr scheduler `CustomScheduler` is invalid"): + _init_optimizers_and_lr_schedulers(model) + + +@pytest.mark.parametrize("override", (False, True)) +def test_invalid_lr_scheduler_with_custom_step_method(override): + """Test that custom lr scheduler raises an error if it doesn't follow PyTorch LR Scheduler API.""" + + class CustomScheduler: + def __init__(self, optimizer): + self.optimizer = optimizer + + def step(self, foobar): # breaks the API, forces user to override `lr_scheduler_step` + ... + + def state_dict(self): + ... + + def load_state_dict(self, state_dict): + ... + + class CustomBoringModel(BoringModel): + def configure_optimizers(self): + opt = torch.optim.SGD(self.parameters(), lr=1e-2) + lr_scheduler = CustomScheduler(opt) + return {"optimizer": opt, "lr_scheduler": lr_scheduler} + + model = CustomBoringModel() + model.trainer = Trainer() + if override: + + def lr_scheduler_step(*_): + ... 
+ + # the user did override the hook, no error + model.lr_scheduler_step = lr_scheduler_step + _init_optimizers_and_lr_schedulers(model) + else: + with pytest.raises(MisconfigurationException, match="CustomScheduler` doesn't follow"): + _init_optimizers_and_lr_schedulers(model) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 6e26651ddb078..e467436238f31 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -48,7 +48,6 @@ _reload_dataloader_state_dict, _rotate_worker_indices, _SingleProcessDataLoaderIterStateful, - _SupportsStateDict, _teardown_dataloader_get_iterators, _validate_fault_tolerant_automatic, CaptureIterableDataset, @@ -60,6 +59,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import DataFetcher from pytorch_lightning.utilities.imports import _fault_tolerant_training +from pytorch_lightning.utilities.types import _SupportsStateDict from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf