Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Remove dependency on private _get_default_scheduler_config (#1099)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Jan 5, 2022
1 parent b931fc1 commit 0b8d27d
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.finetuning import BaseFinetuning
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -939,12 +938,22 @@ def _get_lr_scheduler_class_from_registry(self, lr_scheduler_key: str) -> Dict[s
return deepcopy(lr_scheduler_fn)

def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]:
default_scheduler_config = {
"scheduler": None,
"name": None,
"interval": "epoch",
"frequency": 1,
"reduce_on_plateau": False,
"monitor": None,
"strict": True,
"opt_idx": None,
}
if isinstance(self.lr_scheduler, str):
lr_scheduler_data: Dict[str, Any] = self._get_lr_scheduler_class_from_registry(self.lr_scheduler)
lr_scheduler_fn = lr_scheduler_data.pop("fn")
lr_scheduler_metadata: Dict[str, Any] = lr_scheduler_data.pop("metadata", None)
lr_scheduler_kwargs: Dict[str, Any] = {}
lr_scheduler_config = _get_default_scheduler_config()
lr_scheduler_config = default_scheduler_config
for key, value in lr_scheduler_config.items():
lr_scheduler_config[key] = lr_scheduler_metadata.pop(key, None) or value

Expand All @@ -953,7 +962,7 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]:
lr_scheduler_fn = self.lr_scheduler
lr_scheduler_metadata: Dict[str, Any] = None
lr_scheduler_kwargs: Dict[str, Any] = {}
lr_scheduler_config = _get_default_scheduler_config()
lr_scheduler_config = default_scheduler_config

elif isinstance(self.lr_scheduler, Tuple):
if len(self.lr_scheduler) not in [2, 3]:
Expand All @@ -964,7 +973,7 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]:
f"2) Of length 3 with the first index containing a str from {self.available_lr_schedulers()} and"
f" the second index containing the required keyword arguments to initialize the LR Scheduler and"
f" the third index containing a Lightning scheduler configuration dictionary of the format"
f" {_get_default_scheduler_config()}. NOTE: Do not set the `scheduler` key in the"
f" {default_scheduler_config}. NOTE: Do not set the `scheduler` key in the"
f" lr_scheduler_config, it will overridden with an instance of the provided scheduler key."
)

Expand All @@ -990,7 +999,7 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]:
lr_scheduler_fn = lr_scheduler_data.pop("fn")
lr_scheduler_metadata: Dict[str, Any] = lr_scheduler_data.pop("metadata", None)
lr_scheduler_kwargs: Dict[str, Any] = self.lr_scheduler[1]
lr_scheduler_config = _get_default_scheduler_config()
lr_scheduler_config = default_scheduler_config
for key, value in lr_scheduler_config.items():
lr_scheduler_config[key] = lr_scheduler_metadata.pop(key, None) or value
if len(self.lr_scheduler) == 3:
Expand Down Expand Up @@ -1023,11 +1032,11 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]:
if not isinstance(lr_scheduler, (_LRScheduler, Dict)):
raise MisconfigurationException(
f"Please make sure that your custom configuration outputs either an LR Scheduler or a scheduler"
f" configuration with keys belonging to {list(_get_default_scheduler_config().keys())}."
f" configuration with keys belonging to {list(default_scheduler_config.keys())}."
)

if isinstance(lr_scheduler, Dict):
dummy_config = _get_default_scheduler_config()
dummy_config = default_scheduler_config
if not all(config_key in dummy_config.keys() for config_key in lr_scheduler.keys()):
raise MisconfigurationException(
f"Please make sure that your custom configuration outputs either an LR Scheduler or a scheduler"
Expand Down

0 comments on commit 0b8d27d

Please sign in to comment.