Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
[feat] Add support for schedulers (#232)
Browse files Browse the repository at this point in the history
* add support for schedulers

* update changelog

* resolve typing

* update task

* change for log softmax

* udpate on comments
  • Loading branch information
tchaton committed Apr 21, 2021
1 parent 349c88c commit 79f271e
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 5 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Switch to use `torchmetrics` ([#169](https://github.com/PyTorchLightning/lightning-flash/pull/169))

- Better support for `optimizer` and `schedulers` ([#232](https://github.com/PyTorchLightning/lightning-flash/pull/232))



### Fixed

Expand All @@ -28,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added `RetinaNet` & `backbones` to `ObjectDetector` Task ([#121](https://github.com/PyTorchLightning/lightning-flash/pull/121))
- Added .csv image loading utils ([#116](https://github.com/PyTorchLightning/lightning-flash/pull/116),
- Added .csv image loading utils ([#116](https://github.com/PyTorchLightning/lightning-flash/pull/116),
[#117](https://github.com/PyTorchLightning/lightning-flash/pull/117),
[#118](https://github.com/PyTorchLightning/lightning-flash/pull/118))

Expand Down
87 changes: 83 additions & 4 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer

from flash.core.registry import FlashRegistry
from flash.core.schedulers import _SCHEDULERS_REGISTRY
from flash.core.utils import get_callable_dict
from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess

Expand Down Expand Up @@ -64,11 +68,16 @@ class Task(LightningModule):
postprocess: :class:`~flash.data.process.Postprocess` to use as the default for this task.
"""

schedulers: FlashRegistry = _SCHEDULERS_REGISTRY

def __init__(
self,
model: Optional[nn.Module] = None,
loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
scheduler_kwargs: Optional[Dict[str, Any]] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
learning_rate: float = 5e-5,
preprocess: Preprocess = None,
Expand All @@ -78,7 +87,11 @@ def __init__(
if model is not None:
self.model = model
self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn)
self.optimizer_cls = optimizer
self.optimizer = optimizer
self.scheduler = scheduler
self.optimizer_kwargs = optimizer_kwargs or {}
self.scheduler_kwargs = scheduler_kwargs or {}

self.metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics))
self.learning_rate = learning_rate
# TODO: should we save more? Bug on some regarding yaml if we save metrics
Expand Down Expand Up @@ -168,8 +181,14 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
batch = torch.stack(batch)
return self(batch)

def configure_optimizers(self) -> torch.optim.Optimizer:
return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]:
optimizer = self.optimizer
if not isinstance(self.optimizer, Optimizer):
self.optimizer_kwargs["lr"] = self.learning_rate
optimizer = optimizer(filter(lambda p: p.requires_grad, self.parameters()), **self.optimizer_kwargs)
if self.scheduler:
return [optimizer], [self._instantiate_scheduler(optimizer)]
return optimizer

def configure_finetune_callback(self) -> List[Callback]:
return []
Expand Down Expand Up @@ -323,3 +342,63 @@ def available_models(cls) -> List[str]:
if registry is None:
return []
return registry.available_keys()

@classmethod
def available_schedulers(cls) -> List[str]:
registry: Optional[FlashRegistry] = getattr(cls, "schedulers", None)
if registry is None:
return []
return registry.available_keys()

def get_num_training_steps(self) -> int:
"""Total training steps inferred from datamodule and devices."""
if not getattr(self, "trainer", None):
raise MisconfigurationException("The LightningModule isn't attached to the trainer yet.")
if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
dataset_size = self.trainer.limit_train_batches
elif isinstance(self.trainer.limit_train_batches, float):
# limit_train_batches is a percentage of batches
dataset_size = len(self.train_dataloader())
dataset_size = int(dataset_size * self.trainer.limit_train_batches)
else:
dataset_size = len(self.train_dataloader())

num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
if self.trainer.tpu_cores:
num_devices = max(num_devices, self.trainer.tpu_cores)

effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs

if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
return self.trainer.max_steps
return max_estimated_steps

def _compute_warmup(self, num_training_steps: int, num_warmup_steps: Union[int, float]) -> int:
if not isinstance(num_warmup_steps, float) or (num_warmup_steps > 1 or num_warmup_steps < 0):
raise MisconfigurationException(
"`num_warmup_steps` should be provided as float between 0 and 1 in `scheduler_kwargs`"
)
if isinstance(num_warmup_steps, float):
# Convert float values to percentage of training steps to use as warmup
num_warmup_steps *= num_training_steps
return round(num_warmup_steps)

def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler:
scheduler = self.scheduler
if isinstance(scheduler, _LRScheduler):
return scheduler
if isinstance(scheduler, str):
scheduler_fn = self.schedulers.get(self.scheduler)
num_training_steps: int = self.get_num_training_steps()
num_warmup_steps: int = self._compute_warmup(
num_training_steps=num_training_steps,
num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"),
)
return scheduler_fn(optimizer, num_warmup_steps, num_training_steps)
elif issubclass(scheduler, _LRScheduler):
return scheduler(optimizer, **self.scheduler_kwargs)
raise MisconfigurationException(
"scheduler can be a scheduler, a scheduler type with `scheduler_kwargs` "
f"or a built-in scheduler in {self.available_schedulers()}"
)
14 changes: 14 additions & 0 deletions flash/core/schedulers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Callable, List

from flash.core.registry import FlashRegistry
from flash.utils.imports import _TRANSFORMERS_AVAILABLE

_SCHEDULERS_REGISTRY = FlashRegistry("scheduler")

if _TRANSFORMERS_AVAILABLE:
from transformers import optimization
functions: List[Callable] = [
getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != 'get_scheduler')
]
for fn in functions:
_SCHEDULERS_REGISTRY(fn, name=fn.__name__[4:])
3 changes: 3 additions & 0 deletions flash/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import weakref
from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Set, Tuple, Type, TYPE_CHECKING, Union

import torch
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import imports
Expand Down Expand Up @@ -285,6 +286,8 @@ def _set_loader(model: 'Task', loader_name: str, new_loader: DataLoader) -> None
def _attach_preprocess_to_model(
self, model: 'Task', stage: Optional[RunningStage] = None, device_transform_only: bool = False
) -> None:
device_collate_fn = torch.nn.Identity()

if not stage:
stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]

Expand Down
1 change: 1 addition & 0 deletions flash/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
_COCO_AVAILABLE = _module_available("pycocotools")
_TIMM_AVAILABLE = _module_available("timm")
_TORCHVISION_AVAILABLE = _module_available("torchvision")
_TRANSFORMERS_AVAILABLE = _module_available("transformers")
65 changes: 65 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@
import pytorch_lightning as pl
import torch
from PIL import Image
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn, Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader

import flash
from flash.core.classification import ClassificationTask
from flash.tabular import TabularClassifier
from flash.text import SummarizationTask, TextClassifier
from flash.utils.imports import _TRANSFORMERS_AVAILABLE
from flash.vision import ImageClassificationData, ImageClassifier

# ======== Mock functions ========
Expand Down Expand Up @@ -160,3 +164,64 @@ class Foo(ImageClassifier):
backbones = None

assert Foo.available_backbones() == []


def test_optimization(tmpdir):

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax())
optim = torch.optim.Adam(model.parameters())
task = ClassificationTask(model, optimizer=optim, scheduler=None)

optimizer = task.configure_optimizers()
assert optimizer == optim

task = ClassificationTask(model, optimizer=torch.optim.Adadelta, optimizer_kwargs={"eps": 0.5}, scheduler=None)
optimizer = task.configure_optimizers()
assert isinstance(optimizer, torch.optim.Adadelta)
assert optimizer.defaults["eps"] == 0.5

task = ClassificationTask(
model,
optimizer=torch.optim.Adadelta,
scheduler=torch.optim.lr_scheduler.StepLR,
scheduler_kwargs={"step_size": 1}
)
optimizer, scheduler = task.configure_optimizers()
assert isinstance(optimizer[0], torch.optim.Adadelta)
assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR)

optim = torch.optim.Adadelta(model.parameters())
task = ClassificationTask(model, optimizer=optim, scheduler=torch.optim.lr_scheduler.StepLR(optim, step_size=1))
optimizer, scheduler = task.configure_optimizers()
assert isinstance(optimizer[0], torch.optim.Adadelta)
assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR)

if _TRANSFORMERS_AVAILABLE:
from transformers.optimization import get_linear_schedule_with_warmup

assert task.available_schedulers() == [
'constant_schedule', 'constant_schedule_with_warmup', 'cosine_schedule_with_warmup',
'cosine_with_hard_restarts_schedule_with_warmup', 'linear_schedule_with_warmup',
'polynomial_decay_schedule_with_warmup'
]

optim = torch.optim.Adadelta(model.parameters())
with pytest.raises(MisconfigurationException, match="The LightningModule isn't attached to the trainer yet."):
task = ClassificationTask(model, optimizer=optim, scheduler="linear_schedule_with_warmup")
optimizer, scheduler = task.configure_optimizers()

task = ClassificationTask(
model,
optimizer=optim,
scheduler="linear_schedule_with_warmup",
scheduler_kwargs={"num_warmup_steps": 0.1},
loss_fn=F.nll_loss,
)
trainer = flash.Trainer(max_epochs=1, limit_train_batches=2)
ds = DummyDataset()
trainer.fit(task, train_dataloader=DataLoader(ds))
optimizer, scheduler = task.configure_optimizers()
assert isinstance(optimizer[0], torch.optim.Adadelta)
assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR)
expected = get_linear_schedule_with_warmup.__name__
assert scheduler[0].lr_lambdas[0].__qualname__.split('.')[0] == expected

0 comments on commit 79f271e

Please sign in to comment.