diff --git a/CHANGELOG.md b/CHANGELOG.md index f3e57d1114..bb236747b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where `ObjectDetector`, `InstanceSegmentation`, and `KeypointDetector` would log train and validation metrics with the same name ([#1252](https://github.com/PyTorchLightning/lightning-flash/pull/1252)) +- Fixed a bug where using `ReduceLROnPlateau` would raise an error ([#1251](https://github.com/PyTorchLightning/lightning-flash/pull/1251)) + ## [0.7.0] - 2022-02-15 ### Added diff --git a/docs/source/general/optimization.rst b/docs/source/general/optimization.rst index 787675b2bf..69e1ff13c0 100644 --- a/docs/source/general/optimization.rst +++ b/docs/source/general/optimization.rst @@ -15,37 +15,42 @@ Setting an optimizer to a task Each task has a built-in method :func:`~flash.core.model.Task.available_optimizers` which will list all the optimizers registered with Flash. +.. doctest:: + >>> from flash.core.classification import ClassificationTask - >>> ClassificationTask.available_optimizers() - ['adadelta', ..., 'sgd'] + >>> ClassificationTask.available_optimizers() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + [...'adadelta', ..., 'sgd'...] To train / finetune a :class:`~flash.core.model.Task` of your choice, just pass on a string. -.. code-block:: python - - from flash.image import ImageClassifier +.. doctest:: - model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4) + >>> from flash.image import ImageClassifier + >>> model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4) + >>> model.configure_optimizers() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Adam ... In order to customize specific parameters of the Optimizer, pass along a dictionary of kwargs with the string as a tuple. -.. code-block:: python +.. 
doctest:: - from flash.image import ImageClassifier - - model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=("Adam", {"amsgrad": True}), learning_rate=1e-4) + >>> from flash.image import ImageClassifier + >>> model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=("Adam", {"amsgrad": True}), learning_rate=1e-4) + >>> model.configure_optimizers() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Adam ( ... amsgrad: True ...) An alternative to customizing an optimizer using a tuple is to pass it as a callable. -.. code-block:: python - - from functools import partial - from torch.optim import Adam - from flash.image import ImageClassifier +.. doctest:: - model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=partial(Adam, amsgrad=True), learning_rate=1e-4) + >>> from functools import partial + >>> from torch.optim import Adam + >>> from flash.image import ImageClassifier + >>> model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=partial(Adam, amsgrad=True), learning_rate=1e-4) + >>> model.configure_optimizers() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Adam ( ... amsgrad: True ...) Setting a Learning Rate Scheduler @@ -54,53 +59,60 @@ Setting a Learning Rate Scheduler Each task has a built-in method :func:`~flash.core.model.Task.available_lr_schedulers` which will list all the learning rate schedulers registered with Flash. - >>> from flash.core.classification import ClassificationTask - >>> ClassificationTask.available_lr_schedulers() - ['lambdalr', ..., 'cosineannealingwarmrestarts'] +.. doctest:: -To train / finetune a :class:`~flash.core.model.Task` of your choice, just pass on a string. + >>> from flash.core.classification import ClassificationTask + >>> ClassificationTask.available_lr_schedulers() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + [...'cosineannealingwarmrestarts', ..., 'lambdalr'...] -.. 
code-block:: python +To train / finetune a :class:`~flash.core.model.Task` with a scheduler of your choice, just pass in the name: - from flash.image import ImageClassifier +.. doctest:: - model = ImageClassifier( - num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4, lr_scheduler="constant_schedule" - ) + >>> from flash.image import ImageClassifier + >>> model = ImageClassifier( + ... num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4, lr_scheduler="constant_schedule" + ... ) + >>> model.configure_optimizers() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + ([Adam ...], [{'scheduler': ...}]) .. note:: ``"constant_schedule"`` and a few other lr schedulers will be available only if you have installed the ``transformers`` library from Hugging Face. In order to customize specific parameters of the LR Scheduler, pass along a dictionary of kwargs with the string as a tuple. -.. code-block:: python - - from flash.image import ImageClassifier +.. doctest:: - model = ImageClassifier( - num_classes=10, - backbone="resnet18", - optimizer="Adam", - learning_rate=1e-4, - lr_scheduler=("StepLR", {"step_size": 10}), - ) + >>> from flash.image import ImageClassifier + >>> model = ImageClassifier( + ... num_classes=10, + ... backbone="resnet18", + ... optimizer="Adam", + ... learning_rate=1e-4, + ... lr_scheduler=("StepLR", {"step_size": 10}), + ... ) + >>> scheduler = model.configure_optimizers()[1][0]["scheduler"] + >>> scheduler.step_size # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + 10 An alternative to customizing the LR Scheduler using a tuple is to pass it as a callable. -.. code-block:: python - - from functools import partial - from torch.optim.lr_scheduler import CyclicLR - from flash.image import ImageClassifier +.. 
doctest:: - model = ImageClassifier( - num_classes=10, - backbone="resnet18", - optimizer="Adam", - learning_rate=1e-4, - lr_scheduler=partial(CyclicLR, step_size_up=1500, mode="exp_range", gamma=0.5), - ) + >>> from functools import partial + >>> from torch.optim.lr_scheduler import CyclicLR + >>> from flash.image import ImageClassifier + >>> model = ImageClassifier( + ... num_classes=10, + ... backbone="resnet18", + ... optimizer="SGD", + ... learning_rate=1e-4, + ... lr_scheduler=partial(CyclicLR, base_lr=0.001, max_lr=0.1, mode="exp_range", gamma=0.5), + ... ) + >>> scheduler = model.configure_optimizers()[1][0]["scheduler"] + >>> (scheduler.mode, scheduler.gamma) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + ('exp_range', 0.5) Additionally, the ``lr_scheduler`` parameter also accepts the Lightning Scheduler configuration which can be passed on using a tuple. @@ -136,17 +148,18 @@ When there are schedulers in which the ``.step()`` method is conditioned on a va Flash requires that the Lightning Scheduler configuration contains the keyword ``"monitor"`` set to the metric name that the scheduler should be conditioned on. Below is an example for this: -.. code-block:: python - - from flash.image import ImageClassifier +.. doctest:: - model = ImageClassifier( - num_classes=10, - backbone="resnet18", - optimizer="Adam", - learning_rate=1e-4, - lr_scheduler=("reducelronplateau", {"mode": "max"}, {"monitor": "val_accuracy"}), - ) + >>> from flash.image import ImageClassifier + >>> model = ImageClassifier( + ... num_classes=10, + ... backbone="resnet18", + ... optimizer="Adam", + ... learning_rate=1e-4, + ... lr_scheduler=("reducelronplateau", {"mode": "max"}, {"monitor": "val_accuracy"}), + ... ) + >>> model.configure_optimizers() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + ([Adam ...], [{'scheduler': ..., 'monitor': 'val_accuracy', ...}]) .. 
note:: Do not set the ``"scheduler"`` key in the Lightning Scheduler configuration, it will be overridden with an instance of the provided scheduler key. @@ -159,18 +172,18 @@ Flash registry also provides the flexibility of registering functions. This featur Using the ``optimizers`` and ``lr_schedulers`` decorator pertaining to each :class:`~flash.core.model.Task`, custom optimizer and LR scheduler recipes can be pre-registered. -.. code-block:: python - - import torch - from flash.image import ImageClassifier +.. doctest:: - - @ImageClassifier.lr_schedulers - def my_flash_steplr_recipe(optimizer): - return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10) - - - model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_flash_steplr_recipe") + >>> import torch + >>> from flash.image import ImageClassifier + >>> @ImageClassifier.lr_schedulers + ... def my_flash_steplr_recipe(optimizer): + ... return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10) + ... + >>> model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_flash_steplr_recipe") + >>> scheduler = model.configure_optimizers()[1][0]["scheduler"] + >>> scheduler.step_size # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + 10 Provider specific requirements @@ -185,13 +198,41 @@ In order to use them in Flash, just provide ``num_warmup_steps`` as float betwee that will be used as warmup steps. Flash's :class:`~flash.core.trainer.Trainer` will take care of computing the number of training steps and number of warmup steps based on the flags that are set in the Trainer. -.. code-block:: python - - from flash.image import ImageClassifier - - model = ImageClassifier( - backbone="resnet18", - num_classes=2, - optimizer="Adam", - lr_scheduler=("cosine_schedule_with_warmup", {"num_warmup_steps": 0.1}), - ) +..
testsetup:: + + import numpy as np + from PIL import Image + + rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) + _ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)] + _ = [rand_image.save(f"predict_image_{i}.png") for i in range(1, 4)] + +.. doctest:: + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> datamodule = ImageClassificationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=["cat", "dog", "cat"], + ... predict_files=["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"], + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> model = ImageClassifier( + ... backbone="resnet18", + ... num_classes=datamodule.num_classes, + ... optimizer="Adam", + ... lr_scheduler=("cosine_schedule_with_warmup", {"num_warmup_steps": 0.1}), + ... ) + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + +.. testcleanup:: + + import os + + _ = [os.remove(f"image_{i}.png") for i in range(1, 4)] + _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] diff --git a/flash/core/model.py b/flash/core/model.py index 771e403b52..20403b7b2a 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -774,12 +774,6 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: # 2) If return value is a dictionary, check for the lr_scheduler_config `only keys` and return the config. 
lr_scheduler: Union[_LRScheduler, Dict[str, Any]] = lr_scheduler_fn(optimizer, **lr_scheduler_kwargs) - if not isinstance(lr_scheduler, (_LRScheduler, Dict)): - raise MisconfigurationException( - f"Please make sure that your custom configuration outputs either an LR Scheduler or a scheduler" - f" configuration with keys belonging to {list(default_scheduler_config.keys())}." - ) - if isinstance(lr_scheduler, Dict): dummy_config = default_scheduler_config if not all(config_key in dummy_config.keys() for config_key in lr_scheduler.keys()): diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 6a236aa6aa..4432b48274 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -24,7 +24,6 @@ import pytorch_lightning as pl import torch from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn, Tensor from torch.nn import functional as F from torch.utils.data import DataLoader @@ -443,18 +442,6 @@ def test_errors_and_exceptions_optimizers_and_schedulers(): task = ClassificationTask(model, optimizer="Adam", lr_scheduler="not_a_valid_key") task.configure_optimizers() - @ClassificationTask.lr_schedulers - def i_will_create_a_misconfiguration_exception(optimizer): - return "Done. Created." 
- - with pytest.raises(MisconfigurationException): - task = ClassificationTask(model, optimizer="Adam", lr_scheduler="i_will_create_a_misconfiguration_exception") - task.configure_optimizers() - - with pytest.raises(MisconfigurationException): - task = ClassificationTask(model, optimizer="Adam", lr_scheduler=i_will_create_a_misconfiguration_exception) - task.configure_optimizers() - with pytest.raises(TypeError): task = ClassificationTask(model, optimizer="Adam", lr_scheduler=["not", "a", "valid", "type"]) task.configure_optimizers() @@ -465,8 +452,6 @@ def i_will_create_a_misconfiguration_exception(optimizer): ) task.configure_optimizers() - pass - def test_classification_task_metrics(): train_dataset = FixedDataset([0, 1])