Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix ReduceLROnPlateau #1251

Merged
merged 5 commits into from
Mar 28, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where `InstanceSegmentation` would fail if samples had an inconsistent number of bboxes, labels, and masks (these will now be treated as negative samples) ([#1222](https://github.com/PyTorchLightning/lightning-flash/pull/1222))

- Fixed a bug where using `ReduceLROnPlateau` would raise an error ([#1251](https://github.com/PyTorchLightning/lightning-flash/pull/1251))

## [0.7.0] - 2022-02-15

### Added
Expand Down
197 changes: 119 additions & 78 deletions docs/source/general/optimization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,42 @@ Setting an optimizer to a task
Each task has a built-in method :func:`~flash.core.model.Task.available_optimizers` which will list all the optimizers
registered with Flash.

.. doctest::

>>> from flash.core.classification import ClassificationTask
>>> ClassificationTask.available_optimizers()
['adadelta', ..., 'sgd']
>>> ClassificationTask.available_optimizers() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
[...'adadelta', ..., 'sgd'...]

To train / finetune a :class:`~flash.core.model.Task` of your choice, just pass the name of the optimizer as a string.

.. code-block:: python

from flash.image import ImageClassifier
.. doctest::

model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4)
>>> from flash.image import ImageClassifier
>>> model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4)
>>> model.configure_optimizers() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Adam ...


In order to customize specific parameters of the Optimizer, pass a tuple containing the optimizer name and a dictionary of kwargs.

.. code-block:: python
.. doctest::

from flash.image import ImageClassifier

model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=("Adam", {"amsgrad": True}), learning_rate=1e-4)
>>> from flash.image import ImageClassifier
>>> model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=("Adam", {"amsgrad": True}), learning_rate=1e-4)
>>> model.configure_optimizers() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Adam ( ... amsgrad: True ...)


An alternative to customizing an optimizer using a tuple is to pass it as a callable.

.. code-block:: python

from functools import partial
from torch.optim import Adam
from flash.image import ImageClassifier
.. doctest::

model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=partial(Adam, amsgrad=True), learning_rate=1e-4)
>>> from functools import partial
>>> from torch.optim import Adam
>>> from flash.image import ImageClassifier
>>> model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=partial(Adam, amsgrad=True), learning_rate=1e-4)
>>> model.configure_optimizers() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Adam ( ... amsgrad: True ...)


Setting a Learning Rate Scheduler
Expand All @@ -54,53 +59,60 @@ Setting a Learning Rate Scheduler
Each task has a built-in method :func:`~flash.core.model.Task.available_lr_schedulers` which will list all the learning
rate schedulers registered with Flash.

>>> from flash.core.classification import ClassificationTask
>>> ClassificationTask.available_lr_schedulers()
['lambdalr', ..., 'cosineannealingwarmrestarts']
.. doctest::

To train / finetune a :class:`~flash.core.model.Task` of your choice, just pass on a string.
>>> from flash.core.classification import ClassificationTask
>>> ClassificationTask.available_lr_schedulers() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
[...'cosineannealingwarmrestarts', ..., 'lambdalr'...]

.. code-block:: python
To train / finetune a :class:`~flash.core.model.Task` with a scheduler of your choice, just pass in the name:

from flash.image import ImageClassifier
.. doctest::

model = ImageClassifier(
num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4, lr_scheduler="constant_schedule"
)
>>> from flash.image import ImageClassifier
>>> model = ImageClassifier(
... num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4, lr_scheduler="constant_schedule"
... )
>>> model.configure_optimizers() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
([Adam ...], [{'scheduler': ...}])

.. note:: ``"constant_schedule"`` and a few other lr schedulers will be available only if you have installed the ``transformers`` library from Hugging Face.


In order to customize specific parameters of the LR Scheduler, pass a tuple containing the scheduler name and a dictionary of kwargs.

.. code-block:: python

from flash.image import ImageClassifier
.. doctest::

model = ImageClassifier(
num_classes=10,
backbone="resnet18",
optimizer="Adam",
learning_rate=1e-4,
lr_scheduler=("StepLR", {"step_size": 10}),
)
>>> from flash.image import ImageClassifier
>>> model = ImageClassifier(
... num_classes=10,
... backbone="resnet18",
... optimizer="Adam",
... learning_rate=1e-4,
... lr_scheduler=("StepLR", {"step_size": 10}),
... )
>>> scheduler = model.configure_optimizers()[1][0]["scheduler"]
>>> scheduler.step_size # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
10


An alternative to customizing the LR Scheduler using a tuple is to pass it as a callable.

.. code-block:: python

from functools import partial
from torch.optim.lr_scheduler import CyclicLR
from flash.image import ImageClassifier
.. doctest::

model = ImageClassifier(
num_classes=10,
backbone="resnet18",
optimizer="Adam",
learning_rate=1e-4,
lr_scheduler=partial(CyclicLR, step_size_up=1500, mode="exp_range", gamma=0.5),
)
>>> from functools import partial
>>> from torch.optim.lr_scheduler import CyclicLR
>>> from flash.image import ImageClassifier
>>> model = ImageClassifier(
... num_classes=10,
... backbone="resnet18",
... optimizer="SGD",
... learning_rate=1e-4,
... lr_scheduler=partial(CyclicLR, base_lr=0.001, max_lr=0.1, mode="exp_range", gamma=0.5),
... )
>>> scheduler = model.configure_optimizers()[1][0]["scheduler"]
>>> (scheduler.mode, scheduler.gamma) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
('exp_range', 0.5)


The ``lr_scheduler`` parameter also accepts a Lightning Scheduler configuration, which can be passed using a tuple.
Expand Down Expand Up @@ -136,17 +148,18 @@ When there are schedulers in which the ``.step()`` method is conditioned on a va
Flash requires that the Lightning Scheduler configuration contains the keyword ``"monitor"`` set to the metric name that the scheduler should be conditioned on.
Below is an example for this:

.. code-block:: python

from flash.image import ImageClassifier
.. doctest::

model = ImageClassifier(
num_classes=10,
backbone="resnet18",
optimizer="Adam",
learning_rate=1e-4,
lr_scheduler=("reducelronplateau", {"mode": "max"}, {"monitor": "val_accuracy"}),
)
>>> from flash.image import ImageClassifier
>>> model = ImageClassifier(
... num_classes=10,
... backbone="resnet18",
... optimizer="Adam",
... learning_rate=1e-4,
... lr_scheduler=("reducelronplateau", {"mode": "max"}, {"monitor": "val_accuracy"}),
... )
>>> model.configure_optimizers() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
([Adam ...], [{'scheduler': ..., 'monitor': 'val_accuracy', ...}])


.. note:: Do not set the ``"scheduler"`` key in the Lightning Scheduler configuration; it will be overridden with an instance of the provided scheduler.
Expand All @@ -159,18 +172,18 @@ Flash registry also provides the flexibility of registering functions. This featur

Using the ``optimizers`` and ``lr_schedulers`` decorators pertaining to each :class:`~flash.core.model.Task`, custom optimizer and LR scheduler recipes can be pre-registered.

.. code-block:: python

import torch
from flash.image import ImageClassifier
.. doctest::


@ImageClassifier.lr_schedulers
def my_flash_steplr_recipe(optimizer):
return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)


model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_flash_steplr_recipe")
>>> import torch
>>> from flash.image import ImageClassifier
>>> @ImageClassifier.lr_schedulers
... def my_flash_steplr_recipe(optimizer):
... return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
...
>>> model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_flash_steplr_recipe")
>>> scheduler = model.configure_optimizers()[1][0]["scheduler"]
>>> scheduler.step_size # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
10


Provider specific requirements
Expand All @@ -185,13 +198,41 @@ In order to use them in Flash, just provide ``num_warmup_steps`` as float betwee
that will be used as warmup steps. Flash's :class:`~flash.core.trainer.Trainer` will take care of computing the number of training steps and
number of warmup steps based on the flags that are set in the Trainer.

.. code-block:: python

from flash.image import ImageClassifier

model = ImageClassifier(
backbone="resnet18",
num_classes=2,
optimizer="Adam",
lr_scheduler=("cosine_schedule_with_warmup", {"num_warmup_steps": 0.1}),
)
.. testsetup::

import numpy as np
from PIL import Image

rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
_ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)]
_ = [rand_image.save(f"predict_image_{i}.png") for i in range(1, 4)]

.. doctest::

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["cat", "dog", "cat"],
... predict_files=["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> model = ImageClassifier(
... backbone="resnet18",
... num_classes=datamodule.num_classes,
... optimizer="Adam",
... lr_scheduler=("cosine_schedule_with_warmup", {"num_warmup_steps": 0.1}),
... )
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...

.. testcleanup::

import os

_ = [os.remove(f"image_{i}.png") for i in range(1, 4)]
_ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)]
6 changes: 0 additions & 6 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,12 +755,6 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]:
# 2) If return value is a dictionary, check for the lr_scheduler_config `only keys` and return the config.
lr_scheduler: Union[_LRScheduler, Dict[str, Any]] = lr_scheduler_fn(optimizer, **lr_scheduler_kwargs)

if not isinstance(lr_scheduler, (_LRScheduler, Dict)):
raise MisconfigurationException(
f"Please make sure that your custom configuration outputs either an LR Scheduler or a scheduler"
f" configuration with keys belonging to {list(default_scheduler_config.keys())}."
)

if isinstance(lr_scheduler, Dict):
dummy_config = default_scheduler_config
if not all(config_key in dummy_config.keys() for config_key in lr_scheduler.keys()):
Expand Down
15 changes: 0 additions & 15 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn, Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -442,18 +441,6 @@ def test_errors_and_exceptions_optimizers_and_schedulers():
task = ClassificationTask(model, optimizer="Adam", lr_scheduler="not_a_valid_key")
task.configure_optimizers()

@ClassificationTask.lr_schedulers
def i_will_create_a_misconfiguration_exception(optimizer):
return "Done. Created."

with pytest.raises(MisconfigurationException):
task = ClassificationTask(model, optimizer="Adam", lr_scheduler="i_will_create_a_misconfiguration_exception")
task.configure_optimizers()

with pytest.raises(MisconfigurationException):
task = ClassificationTask(model, optimizer="Adam", lr_scheduler=i_will_create_a_misconfiguration_exception)
task.configure_optimizers()

with pytest.raises(TypeError):
task = ClassificationTask(model, optimizer="Adam", lr_scheduler=["not", "a", "valid", "type"])
task.configure_optimizers()
Expand All @@ -464,8 +451,6 @@ def i_will_create_a_misconfiguration_exception(optimizer):
)
task.configure_optimizers()

pass


def test_classification_task_metrics():
train_dataset = FixedDataset([0, 1])
Expand Down