From 79f271ed8c37e079fd0284cbf6d67d747147b091 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Wed, 21 Apr 2021 20:40:24 +0100
Subject: [PATCH] [feat] Add support for schedulers (#232)

* add support for schedulers

* update changelog

* resolve typing

* update task

* change for log softmax

* udpate on comments
---
 CHANGELOG.md                |  5 ++-
 flash/core/model.py         | 87 +++++++++++++++++++++++++++++++++++--
 flash/core/schedulers.py    | 14 ++++++
 flash/data/data_pipeline.py |  3 ++
 flash/utils/imports.py      |  1 +
 tests/core/test_model.py    | 65 +++++++++++++++++++++++++++
 6 files changed, 170 insertions(+), 5 deletions(-)
 create mode 100644 flash/core/schedulers.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c50aaab6da..627e907654 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Switch to use `torchmetrics` ([#169](https://github.com/PyTorchLightning/lightning-flash/pull/169))
 
+- Better support for `optimizer` and `schedulers` ([#232](https://github.com/PyTorchLightning/lightning-flash/pull/232))
+
+
 ### Fixed
 
 
@@ -28,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
 - Added `RetinaNet` & `backbones` to `ObjectDetector` Task ([#121](https://github.com/PyTorchLightning/lightning-flash/pull/121))
-- Added .csv image loading utils ([#116](https://github.com/PyTorchLightning/lightning-flash/pull/116), 
+- Added .csv image loading utils ([#116](https://github.com/PyTorchLightning/lightning-flash/pull/116),
     [#117](https://github.com/PyTorchLightning/lightning-flash/pull/117),
     [#118](https://github.com/PyTorchLightning/lightning-flash/pull/118))
diff --git a/flash/core/model.py b/flash/core/model.py
index 9914b4cb61..02dc367932 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -21,9 +21,13 @@
 from pytorch_lightning import LightningModule
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch import nn
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.optimizer import Optimizer
 
 from flash.core.registry import FlashRegistry
+from flash.core.schedulers import _SCHEDULERS_REGISTRY
 from flash.core.utils import get_callable_dict
 from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess
 
@@ -64,11 +68,16 @@ class Task(LightningModule):
         postprocess: :class:`~flash.data.process.Postprocess` to use as the default for this task.
     """
 
+    schedulers: FlashRegistry = _SCHEDULERS_REGISTRY
+
     def __init__(
         self,
         model: Optional[nn.Module] = None,
         loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
-        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+        scheduler_kwargs: Optional[Dict[str, Any]] = None,
         metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
         learning_rate: float = 5e-5,
         preprocess: Preprocess = None,
@@ -78,7 +87,11 @@ def __init__(
         if model is not None:
             self.model = model
         self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn)
-        self.optimizer_cls = optimizer
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.optimizer_kwargs = optimizer_kwargs or {}
+        self.scheduler_kwargs = scheduler_kwargs or {}
+
         self.metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics))
         self.learning_rate = learning_rate
         # TODO: should we save more? Bug on some regarding yaml if we save metrics
@@ -168,8 +181,14 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
             batch = torch.stack(batch)
         return self(batch)
 
-    def configure_optimizers(self) -> torch.optim.Optimizer:
-        return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
+    def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]:
+        optimizer = self.optimizer
+        if not isinstance(self.optimizer, Optimizer):
+            self.optimizer_kwargs["lr"] = self.learning_rate
+            optimizer = optimizer(filter(lambda p: p.requires_grad, self.parameters()), **self.optimizer_kwargs)
+        if self.scheduler:
+            return [optimizer], [self._instantiate_scheduler(optimizer)]
+        return optimizer
 
     def configure_finetune_callback(self) -> List[Callback]:
         return []
@@ -323,3 +342,63 @@ def available_models(cls) -> List[str]:
         if registry is None:
             return []
         return registry.available_keys()
+
+    @classmethod
+    def available_schedulers(cls) -> List[str]:
+        registry: Optional[FlashRegistry] = getattr(cls, "schedulers", None)
+        if registry is None:
+            return []
+        return registry.available_keys()
+
+    def get_num_training_steps(self) -> int:
+        """Total training steps inferred from datamodule and devices."""
+        if not getattr(self, "trainer", None):
+            raise MisconfigurationException("The LightningModule isn't attached to the trainer yet.")
+        if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
+            dataset_size = self.trainer.limit_train_batches
+        elif isinstance(self.trainer.limit_train_batches, float):
+            # limit_train_batches is a percentage of batches
+            dataset_size = len(self.train_dataloader())
+            dataset_size = int(dataset_size * self.trainer.limit_train_batches)
+        else:
+            dataset_size = len(self.train_dataloader())
+
+        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
+        if self.trainer.tpu_cores:
+            num_devices = max(num_devices, self.trainer.tpu_cores)
+
+        effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
+        max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs
+
+        if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
+            return self.trainer.max_steps
+        return max_estimated_steps
+
+    def _compute_warmup(self, num_training_steps: int, num_warmup_steps: Union[int, float]) -> int:
+        if not isinstance(num_warmup_steps, float) or (num_warmup_steps > 1 or num_warmup_steps < 0):
+            raise MisconfigurationException(
+                "`num_warmup_steps` should be provided as float between 0 and 1 in `scheduler_kwargs`"
+            )
+        if isinstance(num_warmup_steps, float):
+            # Convert float values to percentage of training steps to use as warmup
+            num_warmup_steps *= num_training_steps
+        return round(num_warmup_steps)
+
+    def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler:
+        scheduler = self.scheduler
+        if isinstance(scheduler, _LRScheduler):
+            return scheduler
+        if isinstance(scheduler, str):
+            scheduler_fn = self.schedulers.get(self.scheduler)
+            num_training_steps: int = self.get_num_training_steps()
+            num_warmup_steps: int = self._compute_warmup(
+                num_training_steps=num_training_steps,
+                num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"),
+            )
+            return scheduler_fn(optimizer, num_warmup_steps, num_training_steps)
+        elif issubclass(scheduler, _LRScheduler):
+            return scheduler(optimizer, **self.scheduler_kwargs)
+        raise MisconfigurationException(
+            "scheduler can be a scheduler, a scheduler type with `scheduler_kwargs` "
+            f"or a built-in scheduler in {self.available_schedulers()}"
+        )
diff --git a/flash/core/schedulers.py b/flash/core/schedulers.py
new file mode 100644
index 0000000000..eee60cc8f8
--- /dev/null
+++ b/flash/core/schedulers.py
@@ -0,0 +1,14 @@
+from typing import Callable, List
+
+from flash.core.registry import FlashRegistry
+from flash.utils.imports import _TRANSFORMERS_AVAILABLE
+
+_SCHEDULERS_REGISTRY = FlashRegistry("scheduler")
+
+if _TRANSFORMERS_AVAILABLE:
+    from transformers import optimization
+    functions: List[Callable] = [
+        getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != 'get_scheduler')
+    ]
+    for fn in functions:
+        _SCHEDULERS_REGISTRY(fn, name=fn.__name__[4:])
diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py
index 46be3c823c..fe75404f1c 100644
--- a/flash/data/data_pipeline.py
+++ b/flash/data/data_pipeline.py
@@ -16,6 +16,7 @@
 import weakref
 from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Set, Tuple, Type, TYPE_CHECKING, Union
 
+import torch
 from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities import imports
@@ -285,6 +286,8 @@ def _set_loader(model: 'Task', loader_name: str, new_loader: DataLoader) -> None
     def _attach_preprocess_to_model(
         self, model: 'Task', stage: Optional[RunningStage] = None, device_transform_only: bool = False
     ) -> None:
+        device_collate_fn = torch.nn.Identity()
+
         if not stage:
             stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]
 
diff --git a/flash/utils/imports.py b/flash/utils/imports.py
index 5e17ba6d3e..5252a3e3d5 100644
--- a/flash/utils/imports.py
+++ b/flash/utils/imports.py
@@ -5,3 +5,4 @@
 _COCO_AVAILABLE = _module_available("pycocotools")
 _TIMM_AVAILABLE = _module_available("timm")
 _TORCHVISION_AVAILABLE = _module_available("torchvision")
+_TRANSFORMERS_AVAILABLE = _module_available("transformers")
diff --git a/tests/core/test_model.py b/tests/core/test_model.py
index d0b0048b23..450b662dbd 100644
--- a/tests/core/test_model.py
+++ b/tests/core/test_model.py
@@ -20,12 +20,16 @@
 import pytorch_lightning as pl
 import torch
 from PIL import Image
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch import nn, Tensor
 from torch.nn import functional as F
+from torch.utils.data import DataLoader
 
+import flash
 from flash.core.classification import ClassificationTask
 from flash.tabular import TabularClassifier
 from flash.text import SummarizationTask, TextClassifier
+from flash.utils.imports import _TRANSFORMERS_AVAILABLE
 from flash.vision import ImageClassificationData, ImageClassifier
 
 # ======== Mock functions ========
@@ -160,3 +164,64 @@ class Foo(ImageClassifier):
         backbones = None
 
     assert Foo.available_backbones() == []
+
+
+def test_optimization(tmpdir):
+
+    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax())
+    optim = torch.optim.Adam(model.parameters())
+    task = ClassificationTask(model, optimizer=optim, scheduler=None)
+
+    optimizer = task.configure_optimizers()
+    assert optimizer == optim
+
+    task = ClassificationTask(model, optimizer=torch.optim.Adadelta, optimizer_kwargs={"eps": 0.5}, scheduler=None)
+    optimizer = task.configure_optimizers()
+    assert isinstance(optimizer, torch.optim.Adadelta)
+    assert optimizer.defaults["eps"] == 0.5
+
+    task = ClassificationTask(
+        model,
+        optimizer=torch.optim.Adadelta,
+        scheduler=torch.optim.lr_scheduler.StepLR,
+        scheduler_kwargs={"step_size": 1}
+    )
+    optimizer, scheduler = task.configure_optimizers()
+    assert isinstance(optimizer[0], torch.optim.Adadelta)
+    assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR)
+
+    optim = torch.optim.Adadelta(model.parameters())
+    task = ClassificationTask(model, optimizer=optim, scheduler=torch.optim.lr_scheduler.StepLR(optim, step_size=1))
+    optimizer, scheduler = task.configure_optimizers()
+    assert isinstance(optimizer[0], torch.optim.Adadelta)
+    assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR)
+
+    if _TRANSFORMERS_AVAILABLE:
+        from transformers.optimization import get_linear_schedule_with_warmup
+
+        assert task.available_schedulers() == [
+            'constant_schedule', 'constant_schedule_with_warmup', 'cosine_schedule_with_warmup',
+            'cosine_with_hard_restarts_schedule_with_warmup', 'linear_schedule_with_warmup',
+            'polynomial_decay_schedule_with_warmup'
+        ]
+
+        optim = torch.optim.Adadelta(model.parameters())
+        with pytest.raises(MisconfigurationException, match="The LightningModule isn't attached to the trainer yet."):
+            task = ClassificationTask(model, optimizer=optim, scheduler="linear_schedule_with_warmup")
+            optimizer, scheduler = task.configure_optimizers()
+
+        task = ClassificationTask(
+            model,
+            optimizer=optim,
+            scheduler="linear_schedule_with_warmup",
+            scheduler_kwargs={"num_warmup_steps": 0.1},
+            loss_fn=F.nll_loss,
+        )
+        trainer = flash.Trainer(max_epochs=1, limit_train_batches=2)
+        ds = DummyDataset()
+        trainer.fit(task, train_dataloader=DataLoader(ds))
+        optimizer, scheduler = task.configure_optimizers()
+        assert isinstance(optimizer[0], torch.optim.Adadelta)
+        assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR)
+        expected = get_linear_schedule_with_warmup.__name__
+        assert scheduler[0].lr_lambdas[0].__qualname__.split('.')[0] == expected
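
Illustrative usage sketch (not part of the diff above): it shows how the new `optimizer`, `optimizer_kwargs`, `scheduler` and `scheduler_kwargs` arguments fit together, mirroring `tests/core/test_model.py`. The model and data module are placeholders, the string scheduler names require `transformers` to be installed, and a named scheduler is only built inside `configure_optimizers()` once the task is attached to a trainer, so that a float `num_warmup_steps` can be resolved against the estimated number of training steps.

    import torch
    from torch import nn
    from torch.nn import functional as F

    import flash
    from flash.core.classification import ClassificationTask

    # Placeholder model; any nn.Module works here.
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=-1))

    # Names registered from `transformers.optimization` (empty list if transformers is missing).
    print(ClassificationTask.available_schedulers())

    # A torch scheduler class is instantiated with `scheduler_kwargs`...
    task = ClassificationTask(
        model,
        optimizer=torch.optim.Adadelta,
        scheduler=torch.optim.lr_scheduler.StepLR,
        scheduler_kwargs={"step_size": 1},
    )

    # ...while a registered name takes a float warmup ratio that is converted
    # into steps from the trainer's estimated training length.
    task = ClassificationTask(
        model,
        optimizer=torch.optim.Adam,
        scheduler="linear_schedule_with_warmup",
        scheduler_kwargs={"num_warmup_steps": 0.1},
        loss_fn=F.nll_loss,
    )
    trainer = flash.Trainer(max_epochs=1)
    # trainer.fit(task, datamodule=...)  # the scheduler is built here, not at construction time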
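
A second, more speculative sketch: because `Task.schedulers` is the same `FlashRegistry` populated in `flash/core/schedulers.py`, and `_instantiate_scheduler` invokes a registered entry as `scheduler_fn(optimizer, num_warmup_steps, num_training_steps)`, a custom factory with that signature could presumably be registered alongside the `transformers` helpers. The function name and schedule below are made up for illustration.

    from torch.optim.lr_scheduler import LambdaLR

    from flash.core.model import Task


    def my_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
        # Illustrative schedule: linear warmup, then linear decay to zero.
        def lr_lambda(step):
            if step < num_warmup_steps:
                return float(step) / float(max(1, num_warmup_steps))
            return max(0.0, float(num_training_steps - step) / float(max(1, num_training_steps - num_warmup_steps)))

        return LambdaLR(optimizer, lr_lambda)


    # Register it the same way flash/core/schedulers.py registers the transformers helpers.
    Task.schedulers(my_schedule_with_warmup, name="my_schedule_with_warmup")

    # It can then be selected by name; `num_warmup_steps` must be a float ratio,
    # as enforced by `Task._compute_warmup`, e.g.:
    # task = ClassificationTask(model, scheduler="my_schedule_with_warmup",
    #                           scheduler_kwargs={"num_warmup_steps": 0.1})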