diff --git a/flash/core/model.py b/flash/core/model.py
index ccaebcffb6..02dc367932 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -27,7 +27,7 @@
 from torch.optim.optimizer import Optimizer
 
 from flash.core.registry import FlashRegistry
-from flash.core.schedulers import _SCHEDULER_REGISTRY
+from flash.core.schedulers import _SCHEDULERS_REGISTRY
 from flash.core.utils import get_callable_dict
 from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess
@@ -68,7 +68,7 @@ class Task(LightningModule):
         postprocess: :class:`~flash.data.process.Postprocess` to use as the default for this task.
     """
 
-    schedulers = _SCHEDULER_REGISTRY
+    schedulers: FlashRegistry = _SCHEDULERS_REGISTRY
 
     def __init__(
         self,
@@ -181,7 +181,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
         batch = torch.stack(batch)
         return self(batch)
 
-    def configure_optimizers(self):
+    def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]:
         optimizer = self.optimizer
         if not isinstance(self.optimizer, Optimizer):
             self.optimizer_kwargs["lr"] = self.learning_rate
@@ -352,6 +352,8 @@ def available_schedulers(cls) -> List[str]:
 
     def get_num_training_steps(self) -> int:
         """Total training steps inferred from datamodule and devices."""
+        if not getattr(self, "trainer", None):
+            raise MisconfigurationException("The LightningModule isn't attached to the trainer yet.")
         if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
             dataset_size = self.trainer.limit_train_batches
         elif isinstance(self.trainer.limit_train_batches, float):
@@ -373,10 +375,14 @@ def get_num_training_steps(self) -> int:
         return max_estimated_steps
 
     def _compute_warmup(self, num_training_steps: int, num_warmup_steps: Union[int, float]) -> int:
+        if not isinstance(num_warmup_steps, float) or (num_warmup_steps > 1 or num_warmup_steps < 0):
+            raise MisconfigurationException(
+                "`num_warmup_steps` should be provided as float between 0 and 1 in `scheduler_kwargs`"
+            )
         if isinstance(num_warmup_steps, float):
             # Convert float values to percentage of training steps to use as warmup
             num_warmup_steps *= num_training_steps
-        return int(num_warmup_steps)
+        return round(num_warmup_steps)
 
     def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler:
         scheduler = self.scheduler
@@ -384,22 +390,12 @@ def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler:
             return scheduler
         if isinstance(scheduler, str):
             scheduler_fn = self.schedulers.get(self.scheduler)
-            if "num_warmup_steps" in self.scheduler_kwargs:
-                num_warmup_steps = self.scheduler_kwargs.get("num_warmup_steps")
-                if not isinstance(num_warmup_steps, float) or (num_warmup_steps > 1 or num_warmup_steps < 0):
-                    raise MisconfigurationException(
-                        "`num_warmup_steps` should be provided as float between 0 and 1 in `scheduler_kwargs`"
-                    )
-                num_training_steps = self.get_num_training_steps()
-                num_warmup_steps = self._compute_warmup(
-                    num_training_steps=num_training_steps,
-                    num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"),
-                )
-                return scheduler_fn(optimizer, num_warmup_steps, num_training_steps)
-            else:
-                raise MisconfigurationException(
-                    "`num_warmup_steps` should be provided as float between 0 and 1 in `scheduler_kwargs`"
-                )
+            num_training_steps: int = self.get_num_training_steps()
+            num_warmup_steps: int = self._compute_warmup(
+                num_training_steps=num_training_steps,
+                num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"),
+            )
+            return scheduler_fn(optimizer, num_warmup_steps, num_training_steps)
         elif issubclass(scheduler, _LRScheduler):
             return scheduler(optimizer, **self.scheduler_kwargs)
         raise MisconfigurationException(
diff --git a/flash/core/schedulers.py b/flash/core/schedulers.py
index bff5f4d138..eee60cc8f8 100644
--- a/flash/core/schedulers.py
+++ b/flash/core/schedulers.py
@@ -1,10 +1,14 @@
+from typing import Callable, List
+
 from flash.core.registry import FlashRegistry
 from flash.utils.imports import _TRANSFORMERS_AVAILABLE
 
-_SCHEDULER_REGISTRY = FlashRegistry("scheduler")
+_SCHEDULERS_REGISTRY = FlashRegistry("scheduler")
 
 if _TRANSFORMERS_AVAILABLE:
     from transformers import optimization
-    functions = [getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != 'get_scheduler')]
+    functions: List[Callable] = [
+        getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != 'get_scheduler')
+    ]
     for fn in functions:
-        _SCHEDULER_REGISTRY(fn, name=fn.__name__[4:])
+        _SCHEDULERS_REGISTRY(fn, name=fn.__name__[4:])
diff --git a/flash/utils/imports.py b/flash/utils/imports.py
index a381141b45..5252a3e3d5 100644
--- a/flash/utils/imports.py
+++ b/flash/utils/imports.py
@@ -5,8 +5,4 @@
 _COCO_AVAILABLE = _module_available("pycocotools")
 _TIMM_AVAILABLE = _module_available("timm")
 _TORCHVISION_AVAILABLE = _module_available("torchvision")
-try:
-    import transformers
-    _TRANSFORMERS_AVAILABLE = True
-except ModuleNotFoundError:
-    _TRANSFORMERS_AVAILABLE = False
+_TRANSFORMERS_AVAILABLE = _module_available("transformers")
diff --git a/tests/core/test_model.py b/tests/core/test_model.py
index 2426cff78f..450b662dbd 100644
--- a/tests/core/test_model.py
+++ b/tests/core/test_model.py
@@ -206,9 +206,7 @@ def test_optimization(tmpdir):
     ]
 
     optim = torch.optim.Adadelta(model.parameters())
-    with pytest.raises(
-        MisconfigurationException, match="`num_warmup_steps` should be provided as float between 0 and 1"
-    ):
+    with pytest.raises(MisconfigurationException, match="The LightningModule isn't attached to the trainer yet."):
         task = ClassificationTask(model, optimizer=optim, scheduler="linear_schedule_with_warmup")
         optimizer, scheduler = task.configure_optimizers()
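
To make the new warmup contract concrete, here is a minimal standalone sketch (not the Flash source) of what `Task._compute_warmup` enforces after this diff: `num_warmup_steps` must be a float between 0 and 1 and is converted to an absolute step count with `round` instead of `int`. The helper name `compute_warmup` and the example numbers are illustrative assumptions.

from pytorch_lightning.utilities.exceptions import MisconfigurationException


def compute_warmup(num_training_steps: int, num_warmup_steps: float) -> int:
    # Sketch of the validation + conversion performed by `Task._compute_warmup`.
    if not isinstance(num_warmup_steps, float) or not (0 <= num_warmup_steps <= 1):
        raise MisconfigurationException(
            "`num_warmup_steps` should be provided as float between 0 and 1 in `scheduler_kwargs`"
        )
    # A float is interpreted as a fraction of the total number of training steps.
    return round(num_warmup_steps * num_training_steps)


# 10% warmup over 468 training steps -> 47 steps; `int()` would truncate to 46,
# which is why the diff switches to `round()`.
assert compute_warmup(num_training_steps=468, num_warmup_steps=0.1) == 47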
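
Since the registry strips the `get_` prefix from the `transformers` factory functions (for example `get_linear_schedule_with_warmup` is registered as `linear_schedule_with_warmup`), a task can now request a warmup scheduler by name and let the trainer supply the step counts. The sketch below is a hypothetical usage example: `backbone`, `train_dl`, the top-level import path, and the `Trainer` arguments mirror the test module and are assumptions, not part of this diff.

import torch
from flash import ClassificationTask, Trainer  # import path assumed, as in the tests

# `backbone` and `train_dl` stand in for a real nn.Module and DataLoader.
task = ClassificationTask(
    backbone,
    optimizer=torch.optim.Adadelta(backbone.parameters()),
    scheduler="linear_schedule_with_warmup",     # resolved through the schedulers registry
    scheduler_kwargs={"num_warmup_steps": 0.1},  # fraction of total training steps
)

# configure_optimizers() now calls get_num_training_steps(), which requires an attached
# trainer; calling it on a detached task raises MisconfigurationException, which is
# exactly what the updated test asserts. Letting the trainer drive it is the supported path:
trainer = Trainer(max_epochs=1, limit_train_batches=10)
trainer.fit(task, train_dataloader=train_dl)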