From 0b8d27d46261526ba13064490b6250010c188977 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 5 Jan 2022 20:52:05 +0000 Subject: [PATCH] Remove dependency on private `_get_default_scheduler_config` (#1099) --- flash/core/model.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index b0d1445db4..7c15e4916a 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -27,7 +27,6 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback from pytorch_lightning.callbacks.finetuning import BaseFinetuning -from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -939,12 +938,22 @@ def _get_lr_scheduler_class_from_registry(self, lr_scheduler_key: str) -> Dict[s return deepcopy(lr_scheduler_fn) def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: + default_scheduler_config = { + "scheduler": None, + "name": None, + "interval": "epoch", + "frequency": 1, + "reduce_on_plateau": False, + "monitor": None, + "strict": True, + "opt_idx": None, + } if isinstance(self.lr_scheduler, str): lr_scheduler_data: Dict[str, Any] = self._get_lr_scheduler_class_from_registry(self.lr_scheduler) lr_scheduler_fn = lr_scheduler_data.pop("fn") lr_scheduler_metadata: Dict[str, Any] = lr_scheduler_data.pop("metadata", None) lr_scheduler_kwargs: Dict[str, Any] = {} - lr_scheduler_config = _get_default_scheduler_config() + lr_scheduler_config = default_scheduler_config for key, value in lr_scheduler_config.items(): lr_scheduler_config[key] = lr_scheduler_metadata.pop(key, None) or value @@ -953,7 +962,7 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: lr_scheduler_fn = self.lr_scheduler lr_scheduler_metadata: Dict[str, Any] = None lr_scheduler_kwargs: Dict[str, Any] = {} - lr_scheduler_config = _get_default_scheduler_config() + lr_scheduler_config = default_scheduler_config elif isinstance(self.lr_scheduler, Tuple): if len(self.lr_scheduler) not in [2, 3]: @@ -964,7 +973,7 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: f"2) Of length 3 with the first index containing a str from {self.available_lr_schedulers()} and" f" the second index containing the required keyword arguments to initialize the LR Scheduler and" f" the third index containing a Lightning scheduler configuration dictionary of the format" - f" {_get_default_scheduler_config()}. NOTE: Do not set the `scheduler` key in the" + f" {default_scheduler_config}. NOTE: Do not set the `scheduler` key in the" f" lr_scheduler_config, it will overridden with an instance of the provided scheduler key." ) @@ -990,7 +999,7 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: lr_scheduler_fn = lr_scheduler_data.pop("fn") lr_scheduler_metadata: Dict[str, Any] = lr_scheduler_data.pop("metadata", None) lr_scheduler_kwargs: Dict[str, Any] = self.lr_scheduler[1] - lr_scheduler_config = _get_default_scheduler_config() + lr_scheduler_config = default_scheduler_config for key, value in lr_scheduler_config.items(): lr_scheduler_config[key] = lr_scheduler_metadata.pop(key, None) or value if len(self.lr_scheduler) == 3: @@ -1023,11 +1032,11 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: if not isinstance(lr_scheduler, (_LRScheduler, Dict)): raise MisconfigurationException( f"Please make sure that your custom configuration outputs either an LR Scheduler or a scheduler" - f" configuration with keys belonging to {list(_get_default_scheduler_config().keys())}." + f" configuration with keys belonging to {list(default_scheduler_config.keys())}." ) if isinstance(lr_scheduler, Dict): - dummy_config = _get_default_scheduler_config() + dummy_config = default_scheduler_config if not all(config_key in dummy_config.keys() for config_key in lr_scheduler.keys()): raise MisconfigurationException( f"Please make sure that your custom configuration outputs either an LR Scheduler or a scheduler"