From 3d557bd62b9ee4580cee7eb4c686f13db62b7cdb Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 13:36:16 +0000 Subject: [PATCH 01/15] improve finetuning --- flash/core/finetuning.py | 153 ++++++++++++++++++ flash/core/model.py | 6 +- flash/core/trainer.py | 95 ++++------- .../finetuning/image_classification.py | 4 +- .../finetuning/text_classification.py | 2 +- tests/core/test_finetuning.py | 41 +++++ 6 files changed, 237 insertions(+), 64 deletions(-) create mode 100644 flash/core/finetuning.py create mode 100644 tests/core/test_finetuning.py diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py new file mode 100644 index 00000000000..11fdb0b407b --- /dev/null +++ b/flash/core/finetuning.py @@ -0,0 +1,153 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import List, Union + +import pytorch_lightning as pl +from pytorch_lightning.callbacks import BaseFinetuning +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch import nn +from torch.optim import Optimizer + +_EXCLUDE_PARAMTERS = ["self", "args", "kwargs"] + + +class NeverFreeze(BaseFinetuning): + pass + + +class NeverUnFreeze(BaseFinetuning): + + def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True): + self.attr_names = [attr_names] if isinstance(attr_names, str) else attr_names + self.train_bn = train_bn + + def freeze_before_training(self, pl_module: pl.LightningModule) -> None: + for attr_name in self.attr_names: + attr = getattr(pl_module, attr_name, None) + if attr is None or not isinstance(attr, nn.Module): + MisconfigurationException("To use NeverUnFreeze your model must have a {attr} attribute") + self.freeze(module=attr, train_bn=self.train_bn) + + +class FreezeUnFreeze(NeverUnFreeze): + + def __init__( + self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_at_epoch: int = 10 + ): + super().__init__(attr_names, train_bn) + self.unfreeze_at_epoch = unfreeze_at_epoch + + def finetunning_function( + self, + pl_module: pl.LightningModule, + epoch: int, + optimizer: Optimizer, + opt_idx: int, + ) -> None: + if epoch == self.unfreeze_at_epoch: + modules = [] + for attr_name in self.attr_names: + modules.append(getattr(pl_module, attr_name)) + + self.unfreeze_and_add_param_group( + module=modules, + optimizer=optimizer, + train_bn=self.train_bn, + ) + + +# NOTE: copied from: +# https://github.com/PyTorchLightning/pytorch-lightning/blob/9d165f6f5655a44f1e5cd02ab36f21bc14e2a604/pl_examples/domain_templates/computer_vision_fine_tuning.py#L66 +class MilestonesFinetuning(BaseFinetuning): + + def __init__(self, unfreeze_milestones: tuple = (5, 10), train_bn: bool = True, num_layers: int = 5): + self.unfreeze_milestones = unfreeze_milestones + self.train_bn = train_bn + self.num_layers = num_layers + + def freeze_before_training(self, pl_module: pl.LightningModule) -> None: + # TODO: might need some config to say which attribute is model + # maybe 
something like: + # self.freeze(module=pl_module.getattr(self.feature_attr), train_bn=self.train_bn) + # where self.feature_attr can be "backbone" or "feature_extractor", etc. + # (configured in init) + assert hasattr(pl_module, "backbone"), "To use MilestonesFinetuning your model must have a backbone attribute" + self.freeze(module=pl_module.backbone, train_bn=self.train_bn) + + def finetunning_function( + self, + pl_module: pl.LightningModule, + epoch: int, + optimizer: Optimizer, + opt_idx: int, + ) -> None: + backbone_modules = list(pl_module.backbone.modules()) + if epoch == self.unfreeze_milestones[0]: + # unfreeze 5 last layers + # TODO last N layers should be parameter + self.unfreeze_and_add_param_group( + module=backbone_modules[-self.num_layers:], + optimizer=optimizer, + train_bn=self.train_bn, + ) + + elif epoch == self.unfreeze_milestones[1]: + # unfreeze remaining layers + # TODO last N layers should be parameter + self.unfreeze_and_add_param_group( + module=backbone_modules[:-self.num_layers], + optimizer=optimizer, + train_bn=self.train_bn, + ) + + +def instantiate_cls(cls, kwargs): + parameters = list(inspect.signature(cls.__init__).parameters.keys()) + parameters = [p for p in parameters if p not in _EXCLUDE_PARAMTERS] + cls_kwargs = {} + for p in parameters: + if p in kwargs: + cls_kwargs[p] = kwargs.pop(p) + if len(kwargs) > 0: + raise MisconfigurationException(f"Available parameters are: {parameters}. Found {kwargs} left") + return cls(**cls_kwargs) + + +_DEFAULTS_FINETUNE_STRATEGIES = { + "never_freeze": NeverFreeze, + "never_unfreeze": NeverUnFreeze, + "freeze_unfreeze": FreezeUnFreeze, + "unfreeze_milestones": MilestonesFinetuning +} + + +def instantiate_default_finetuning_callbacks(kwargs): + finetune_strategy = kwargs.pop("finetune_strategy", None) + if isinstance(finetune_strategy, str): + finetune_strategy = finetune_strategy.lower() + if finetune_strategy in _DEFAULTS_FINETUNE_STRATEGIES: + return [instantiate_cls(_DEFAULTS_FINETUNE_STRATEGIES[finetune_strategy], kwargs)] + else: + msg = "\n Extra arguments can be: \n" + for n, cls in _DEFAULTS_FINETUNE_STRATEGIES.items(): + parameters = list(inspect.signature(cls.__init__).parameters.keys()) + parameters = [p for p in parameters if p not in _EXCLUDE_PARAMTERS] + msg += f"{n}: {parameters} \n" + raise MisconfigurationException( + f"finetune_strategy should be within {list(_DEFAULTS_FINETUNE_STRATEGIES)}" + f"{msg}" + f". Found {finetune_strategy}" + ) + return [] diff --git a/flash/core/model.py b/flash/core/model.py index 51b1a87d121..0bc675dd14c 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union import pytorch_lightning as pl import torch from torch import nn from flash.core.data import DataModule, DataPipeline +from flash.core.finetuning import instantiate_default_finetuning_callbacks from flash.core.utils import get_callable_dict @@ -150,3 +151,6 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: checkpoint["pipeline"] = self.data_pipeline + + def configure_finetune_callbacks(self, **kwargs) -> List: + return instantiate_default_finetuning_callbacks(kwargs) diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 7ce9f8cf75d..ebfd1a20427 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -15,55 +15,11 @@ from typing import List, Optional, Union import pytorch_lightning as pl -from pytorch_lightning.callbacks import BaseFinetuning -from torch.optim import Optimizer +from pytorch_lightning.callbacks import BaseFinetuning, Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import DataLoader - -# NOTE: copied from: -# https://github.com/PyTorchLightning/pytorch-lightning/blob/9d165f6f5655a44f1e5cd02ab36f21bc14e2a604/pl_examples/domain_templates/computer_vision_fine_tuning.py#L66 -class MilestonesFinetuningCallback(BaseFinetuning): - - def __init__(self, milestones: tuple = (5, 10), train_bn: bool = True): - self.milestones = milestones - self.train_bn = train_bn - - def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - # TODO: might need some config to say which attribute is model - # maybe something like: - # self.freeze(module=pl_module.getattr(self.feature_attr), train_bn=self.train_bn) - # where self.feature_attr can be "backbone" or "feature_extractor", etc. - # (configured in init) - assert hasattr( - pl_module, "backbone" - ), "To use MilestonesFinetuningCallback your model must have a backbone attribute" - self.freeze(module=pl_module.backbone, train_bn=self.train_bn) - - def finetunning_function( - self, - pl_module: pl.LightningModule, - epoch: int, - optimizer: Optimizer, - opt_idx: int, - ) -> None: - backbone_modules = list(pl_module.backbone.modules()) - if epoch == self.milestones[0]: - # unfreeze 5 last layers - # TODO last N layers should be parameter - self.unfreeze_and_add_param_group( - module=backbone_modules[-5:], - optimizer=optimizer, - train_bn=self.train_bn, - ) - - elif epoch == self.milestones[1]: - # unfreeze remaining layers - # TODO last N layers should be parameter - self.unfreeze_and_add_param_group( - module=backbone_modules[:-5], - optimizer=optimizer, - train_bn=self.train_bn, - ) +from flash.core.model import Task class Trainer(pl.Trainer): @@ -96,11 +52,12 @@ def fit( def finetune( self, - model: pl.LightningModule, + model: Task, train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[pl.LightningDataModule] = None, - unfreeze_milestones: tuple = (5, 10), + finetune_strategy: Optional[Union[str, Callback]] = None, + **callbacks_kwargs, ): r""" Runs the full optimization routine. Same as pytorch_lightning.Trainer().fit(), but unfreezes layers @@ -117,18 +74,36 @@ def finetune( val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. 
If the model has a predefined val_dataloaders method this will be skipped - unfreeze_milestones: A tuple of two integers. First value marks the epoch in which the last 5 - layers of the backbone will be unfrozen. The second value marks the epoch in which the full backbone will - be unfrozen. + finetune_strategy: Should either be a string or a finetuning callback subclassing + ``pytorch_lightning.callbacks.BaseFinetuning``. + + callbacks_kwargs: Those arguments will be provided to `model.configure_finetune_callbacks` + to instantiante your own finetuning callbacks. """ - if hasattr(model, "backbone"): - # TODO: if we find a finetuning callback in the trainer should we change it? - # or should we warn the user? - if not any(isinstance(c, BaseFinetuning) for c in self.callbacks): - # TODO: should pass config from arguments - self.callbacks.append(MilestonesFinetuningCallback(milestones=unfreeze_milestones)) - else: - warnings.warn("Warning: model does not have a 'backbone' attribute, will train normally") + if isinstance(finetune_strategy, Callback) and not isinstance(finetune_strategy, BaseFinetuning): + raise Exception("finetune_strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning`` Callback") + self._resolve_callbacks(model, finetune_strategy, **callbacks_kwargs) return super().fit(model, train_dataloader, val_dataloaders, datamodule) + + def _resolve_callbacks(self, model, finetune_strategy, **callbacks_kwargs): + if sum((isinstance(c, BaseFinetuning) for c in [finetune_strategy])) > 1: + raise MisconfigurationException("Only 1 callback subclassing `BaseFinetuning` should be provided.") + # provided callbacks are higher priorities than model callbacks. + callbacks = self.callbacks + if isinstance(finetune_strategy, str): + callbacks_kwargs["finetune_strategy"] = finetune_strategy + else: + callbacks = self._merge_callbacks(callbacks, [finetune_strategy]) + self.callbacks = self._merge_callbacks(callbacks, model.configure_finetune_callbacks(**callbacks_kwargs)) + + @staticmethod + def _merge_callbacks(current_callbacks: List, new_callbacks: List) -> List: + if len(new_callbacks): + return current_callbacks + new_callbacks_types = set(type(c) for c in new_callbacks) + current_callbacks_types = set(type(c) for c in current_callbacks) + override_types = new_callbacks_types.intersection(current_callbacks_types) + new_callbacks.extend(c for c in current_callbacks if type(c) not in override_types) + return new_callbacks diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index ef8dac9aa81..3e64a4867c9 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -18,10 +18,10 @@ model = ImageClassifier(num_classes=datamodule.num_classes) # 4. Create the trainer. Run once on data - trainer = flash.Trainer(max_epochs=1) + trainer = flash.Trainer(max_epochs=2) # 5. Train the model - trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1)) + trainer.finetune(model, datamodule=datamodule, finetune_strategy='freeze_unfreeze', unfreeze_at_epoch=1) # 6. Test the model trainer.test() diff --git a/flash_examples/finetuning/text_classification.py b/flash_examples/finetuning/text_classification.py index 2c0ff4b3f9f..445cbb43288 100644 --- a/flash_examples/finetuning/text_classification.py +++ b/flash_examples/finetuning/text_classification.py @@ -24,7 +24,7 @@ trainer = flash.Trainer(max_epochs=1) # 5. 
Fine-tune the model - trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1)) + trainer.finetune(model, datamodule=datamodule, finetune_strategy='never_freeze') # 6. Test model trainer.test() diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py new file mode 100644 index 00000000000..de49b77f7bb --- /dev/null +++ b/tests/core/test_finetuning.py @@ -0,0 +1,41 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch import nn +from torch.nn import functional as F + +from flash import ClassificationTask, Trainer +from flash.core.finetuning import NeverFreeze +from tests.core.test_model import DummyDataset + + +@pytest.mark.parametrize( + "finetune_strategy", + ['never_freeze', 'never_unfreeze', 'freeze_unfreeze', 'unfreeze_milestones', None, 'cls', 'chocolat'] +) +def test_finetuning(tmpdir: str, finetune_strategy): + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) + train_dl = torch.utils.data.DataLoader(DummyDataset()) + val_dl = torch.utils.data.DataLoader(DummyDataset()) + task = ClassificationTask(model, F.nll_loss) + trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) + if finetune_strategy == "cls": + finetune_strategy = NeverFreeze() + if finetune_strategy == 'chocolat': + with pytest.raises(MisconfigurationException, match="finetune_strategy should be within"): + trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy) + else: + trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy) From 72421021bb8eae7af114a1fc429228c8c4bda8a3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 13:39:00 +0000 Subject: [PATCH 02/15] update changelog --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a589789b76..57ab7ffc1f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,9 +11,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/pytorch-lightning/pull/9)) + +- Added `3 BaseFinetuning Callbacks` and `configure_finetuning_callbacks` ([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39)) + ### Changed ### Fixed -### Removed \ No newline at end of file +### Removed From fe83588acfc2d64d82087bd9a772cae01ef85149 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 14:19:26 +0000 Subject: [PATCH 03/15] update on comments --- flash/core/finetuning.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index 11fdb0b407b..cc84ff1c445 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -20,14 +20,14 @@ from torch import nn from torch.optim import Optimizer -_EXCLUDE_PARAMTERS = ["self", "args", "kwargs"] +_EXCLUDE_PARAMTERS = ("self", "args", "kwargs") class NeverFreeze(BaseFinetuning): pass -class NeverUnFreeze(BaseFinetuning): +class NeverUnfreeze(BaseFinetuning): def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True): self.attr_names = [attr_names] if isinstance(attr_names, str) else attr_names @@ -37,11 +37,11 @@ def freeze_before_training(self, pl_module: pl.LightningModule) -> None: for attr_name in self.attr_names: attr = getattr(pl_module, attr_name, None) if attr is None or not isinstance(attr, nn.Module): - MisconfigurationException("To use NeverUnFreeze your model must have a {attr} attribute") + MisconfigurationException("To use NeverUnfreeze your model must have a {attr} attribute") self.freeze(module=attr, train_bn=self.train_bn) -class FreezeUnFreeze(NeverUnFreeze): +class FreezeUnFreeze(NeverUnfreeze): def __init__( self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_at_epoch: int = 10 @@ -68,8 +68,6 @@ def finetunning_function( ) -# NOTE: copied from: -# https://github.com/PyTorchLightning/pytorch-lightning/blob/9d165f6f5655a44f1e5cd02ab36f21bc14e2a604/pl_examples/domain_templates/computer_vision_fine_tuning.py#L66 class MilestonesFinetuning(BaseFinetuning): def __init__(self, unfreeze_milestones: tuple = (5, 10), train_bn: bool = True, num_layers: int = 5): @@ -127,7 +125,7 @@ def instantiate_cls(cls, kwargs): _DEFAULTS_FINETUNE_STRATEGIES = { "never_freeze": NeverFreeze, - "never_unfreeze": NeverUnFreeze, + "never_unfreeze": NeverUnfreeze, "freeze_unfreeze": FreezeUnFreeze, "unfreeze_milestones": MilestonesFinetuning } From ed4cf423a3944857be2a5450b73176f22d3a814b Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 14:20:08 +0000 Subject: [PATCH 04/15] typo --- flash/core/finetuning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index cc84ff1c445..7ac2df6271f 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -41,7 +41,7 @@ def freeze_before_training(self, pl_module: pl.LightningModule) -> None: self.freeze(module=attr, train_bn=self.train_bn) -class FreezeUnFreeze(NeverUnfreeze): +class FreezeUnfreeze(NeverUnfreeze): def __init__( self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_at_epoch: int = 10 @@ -126,7 +126,7 @@ def instantiate_cls(cls, kwargs): _DEFAULTS_FINETUNE_STRATEGIES = { "never_freeze": NeverFreeze, "never_unfreeze": NeverUnfreeze, - "freeze_unfreeze": FreezeUnFreeze, + "freeze_unfreeze": FreezeUnfreeze, "unfreeze_milestones": MilestonesFinetuning } From 
80d2b96279a7f7ed940abc72aef5ff7ce7191027 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 16:05:10 +0000 Subject: [PATCH 05/15] update on comments --- CHANGELOG.md | 4 +- flash/core/finetuning.py | 83 +++++++++---------- flash/core/model.py | 4 - flash/core/trainer.py | 28 ++++--- .../finetuning/image_classification.py | 2 +- tests/core/test_finetuning.py | 10 ++- 6 files changed, 62 insertions(+), 69 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 57ab7ffc1f9..5d0695869b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,10 +12,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/pytorch-lightning/pull/9)) -- Added `3 BaseFinetuning Callbacks` and `configure_finetuning_callbacks` ([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39)) +- Added `configure_finetuning_callbacks` ([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39)) + ### Changed + ### Fixed diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index 7ac2df6271f..9f9e55e09bf 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -23,31 +23,40 @@ _EXCLUDE_PARAMTERS = ("self", "args", "kwargs") -class NeverFreeze(BaseFinetuning): +class NoFreeze(BaseFinetuning): pass -class NeverUnfreeze(BaseFinetuning): +def freeze_using_attr_names(pl_module, attr_names: List[str], train_bn: bool = True): + for attr_name in attr_names: + attr = getattr(pl_module, attr_name, None) + if attr is None or not isinstance(attr, nn.Module): + MisconfigurationException("To use Freeze your model must have a {attr} attribute") + BaseFinetuning.freeze(module=attr, train_bn=train_bn) + + +class FlashBaseBaseFinetuning(BaseFinetuning): def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True): self.attr_names = [attr_names] if isinstance(attr_names, str) else attr_names self.train_bn = train_bn def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - for attr_name in self.attr_names: - attr = getattr(pl_module, attr_name, None) - if attr is None or not isinstance(attr, nn.Module): - MisconfigurationException("To use NeverUnfreeze your model must have a {attr} attribute") - self.freeze(module=attr, train_bn=self.train_bn) + freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) -class FreezeUnfreeze(NeverUnfreeze): +class Freeze(FlashBaseBaseFinetuning): + pass - def __init__( - self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_at_epoch: int = 10 - ): + +class FreezeUnfreeze(FlashBaseBaseFinetuning): + + def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_epoch: int = 10): super().__init__(attr_names, train_bn) - self.unfreeze_at_epoch = unfreeze_at_epoch + self.unfreeze_epoch = unfreeze_epoch + + def freeze_before_training(self, pl_module: pl.LightningModule) -> None: + freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) def finetunning_function( self, @@ -56,7 +65,7 @@ def finetunning_function( optimizer: Optimizer, opt_idx: int, ) -> None: - if epoch == self.unfreeze_at_epoch: + if epoch == self.unfreeze_epoch: modules = [] for attr_name in self.attr_names: modules.append(getattr(pl_module, attr_name)) @@ -68,21 +77,22 @@ def finetunning_function( ) -class MilestonesFinetuning(BaseFinetuning): +class MilestonesFinetuning(FlashBaseBaseFinetuning): - def __init__(self, unfreeze_milestones: 
tuple = (5, 10), train_bn: bool = True, num_layers: int = 5): + def __init__( + self, + attr_names: Union[str, List[str]] = "backbone", + train_bn: bool = True, + unfreeze_milestones: tuple = (5, 10), + num_layers: int = 5 + ): self.unfreeze_milestones = unfreeze_milestones - self.train_bn = train_bn self.num_layers = num_layers + super().__init__(attr_names, train_bn) + def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - # TODO: might need some config to say which attribute is model - # maybe something like: - # self.freeze(module=pl_module.getattr(self.feature_attr), train_bn=self.train_bn) - # where self.feature_attr can be "backbone" or "feature_extractor", etc. - # (configured in init) - assert hasattr(pl_module, "backbone"), "To use MilestonesFinetuning your model must have a backbone attribute" - self.freeze(module=pl_module.backbone, train_bn=self.train_bn) + freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) def finetunning_function( self, @@ -111,41 +121,22 @@ def finetunning_function( ) -def instantiate_cls(cls, kwargs): - parameters = list(inspect.signature(cls.__init__).parameters.keys()) - parameters = [p for p in parameters if p not in _EXCLUDE_PARAMTERS] - cls_kwargs = {} - for p in parameters: - if p in kwargs: - cls_kwargs[p] = kwargs.pop(p) - if len(kwargs) > 0: - raise MisconfigurationException(f"Available parameters are: {parameters}. Found {kwargs} left") - return cls(**cls_kwargs) - - _DEFAULTS_FINETUNE_STRATEGIES = { - "never_freeze": NeverFreeze, - "never_unfreeze": NeverUnfreeze, + "no_freeze": NoFreeze, + "freeze": Freeze, "freeze_unfreeze": FreezeUnfreeze, "unfreeze_milestones": MilestonesFinetuning } -def instantiate_default_finetuning_callbacks(kwargs): - finetune_strategy = kwargs.pop("finetune_strategy", None) +def instantiate_default_finetuning_callbacks(finetune_strategy): if isinstance(finetune_strategy, str): finetune_strategy = finetune_strategy.lower() if finetune_strategy in _DEFAULTS_FINETUNE_STRATEGIES: - return [instantiate_cls(_DEFAULTS_FINETUNE_STRATEGIES[finetune_strategy], kwargs)] + return [_DEFAULTS_FINETUNE_STRATEGIES[finetune_strategy]()] else: - msg = "\n Extra arguments can be: \n" - for n, cls in _DEFAULTS_FINETUNE_STRATEGIES.items(): - parameters = list(inspect.signature(cls.__init__).parameters.keys()) - parameters = [p for p in parameters if p not in _EXCLUDE_PARAMTERS] - msg += f"{n}: {parameters} \n" raise MisconfigurationException( f"finetune_strategy should be within {list(_DEFAULTS_FINETUNE_STRATEGIES)}" - f"{msg}" f". 
Found {finetune_strategy}" ) return [] diff --git a/flash/core/model.py b/flash/core/model.py index 0bc675dd14c..40b7f28be72 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -18,7 +18,6 @@ from torch import nn from flash.core.data import DataModule, DataPipeline -from flash.core.finetuning import instantiate_default_finetuning_callbacks from flash.core.utils import get_callable_dict @@ -151,6 +150,3 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: checkpoint["pipeline"] = self.data_pipeline - - def configure_finetune_callbacks(self, **kwargs) -> List: - return instantiate_default_finetuning_callbacks(kwargs) diff --git a/flash/core/trainer.py b/flash/core/trainer.py index ebfd1a20427..9e7fecb57c3 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -19,6 +19,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import DataLoader +from flash.core.finetuning import instantiate_default_finetuning_callbacks from flash.core.model import Task @@ -57,7 +58,6 @@ def finetune( val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[pl.LightningDataModule] = None, finetune_strategy: Optional[Union[str, Callback]] = None, - **callbacks_kwargs, ): r""" Runs the full optimization routine. Same as pytorch_lightning.Trainer().fit(), but unfreezes layers @@ -76,27 +76,29 @@ def finetune( finetune_strategy: Should either be a string or a finetuning callback subclassing ``pytorch_lightning.callbacks.BaseFinetuning``. - - callbacks_kwargs: Those arguments will be provided to `model.configure_finetune_callbacks` - to instantiante your own finetuning callbacks. + Currently default strategies can be create with strings such as: + * ``no_freeze``, + * ``freeze`` + * ``freeze_unfreeze`` + * ``unfreeze_milestones`` """ - if isinstance(finetune_strategy, Callback) and not isinstance(finetune_strategy, BaseFinetuning): - raise Exception("finetune_strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning`` Callback") + if not isinstance(finetune_strategy, (BaseFinetuning, str)): + raise MisconfigurationException( + "finetune_strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning`` Callback or a str" + ) - self._resolve_callbacks(model, finetune_strategy, **callbacks_kwargs) + self._resolve_callbacks(finetune_strategy) return super().fit(model, train_dataloader, val_dataloaders, datamodule) - def _resolve_callbacks(self, model, finetune_strategy, **callbacks_kwargs): + def _resolve_callbacks(self, finetune_strategy): if sum((isinstance(c, BaseFinetuning) for c in [finetune_strategy])) > 1: raise MisconfigurationException("Only 1 callback subclassing `BaseFinetuning` should be provided.") - # provided callbacks are higher priorities than model callbacks. 
+ # todo: change to ``configure_callbacks`` when callbacks = self.callbacks if isinstance(finetune_strategy, str): - callbacks_kwargs["finetune_strategy"] = finetune_strategy - else: - callbacks = self._merge_callbacks(callbacks, [finetune_strategy]) - self.callbacks = self._merge_callbacks(callbacks, model.configure_finetune_callbacks(**callbacks_kwargs)) + finetune_strategy = instantiate_default_finetuning_callbacks(finetune_strategy) + self.callbacks = self._merge_callbacks(callbacks, [finetune_strategy]) @staticmethod def _merge_callbacks(current_callbacks: List, new_callbacks: List) -> List: diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index 3e64a4867c9..275d2d58ea1 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -21,7 +21,7 @@ trainer = flash.Trainer(max_epochs=2) # 5. Train the model - trainer.finetune(model, datamodule=datamodule, finetune_strategy='freeze_unfreeze', unfreeze_at_epoch=1) + trainer.finetune(model, datamodule=datamodule, finetune_strategy='freeze_unfreeze') # 6. Test the model trainer.test() diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py index de49b77f7bb..91667985a67 100644 --- a/tests/core/test_finetuning.py +++ b/tests/core/test_finetuning.py @@ -18,13 +18,12 @@ from torch.nn import functional as F from flash import ClassificationTask, Trainer -from flash.core.finetuning import NeverFreeze +from flash.core.finetuning import NoFreeze from tests.core.test_model import DummyDataset @pytest.mark.parametrize( - "finetune_strategy", - ['never_freeze', 'never_unfreeze', 'freeze_unfreeze', 'unfreeze_milestones', None, 'cls', 'chocolat'] + "finetune_strategy", ['no_freeze', 'freeze', 'freeze_unfreeze', 'unfreeze_milestones', None, 'cls', 'chocolat'] ) def test_finetuning(tmpdir: str, finetune_strategy): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) @@ -33,9 +32,12 @@ def test_finetuning(tmpdir: str, finetune_strategy): task = ClassificationTask(model, F.nll_loss) trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) if finetune_strategy == "cls": - finetune_strategy = NeverFreeze() + finetune_strategy = NoFreeze() if finetune_strategy == 'chocolat': with pytest.raises(MisconfigurationException, match="finetune_strategy should be within"): trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy) + elif finetune_strategy is None: + with pytest.raises(MisconfigurationException, match="finetune_strategy should"): + trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy) else: trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy) From d17df45b12b7e55ea216c7aff0dc7b9312c48419 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 16:09:51 +0000 Subject: [PATCH 06/15] update on comments --- flash/core/finetuning.py | 14 ++++++------ flash/core/trainer.py | 22 +++++++++---------- .../finetuning/image_classification.py | 3 ++- .../finetuning/text_classification.py | 2 +- tests/core/test_finetuning.py | 22 +++++++++---------- 5 files changed, 32 insertions(+), 31 deletions(-) diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index 9f9e55e09bf..f49356d8e85 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -129,14 +129,14 @@ def finetunning_function( } -def instantiate_default_finetuning_callbacks(finetune_strategy): - if isinstance(finetune_strategy, str): - 
finetune_strategy = finetune_strategy.lower() - if finetune_strategy in _DEFAULTS_FINETUNE_STRATEGIES: - return [_DEFAULTS_FINETUNE_STRATEGIES[finetune_strategy]()] +def instantiate_default_finetuning_callbacks(strategy): + if isinstance(strategy, str): + strategy = strategy.lower() + if strategy in _DEFAULTS_FINETUNE_STRATEGIES: + return [_DEFAULTS_FINETUNE_STRATEGIES[strategy]()] else: raise MisconfigurationException( - f"finetune_strategy should be within {list(_DEFAULTS_FINETUNE_STRATEGIES)}" - f". Found {finetune_strategy}" + f"strategy should be within {list(_DEFAULTS_FINETUNE_STRATEGIES)}" + f". Found {strategy}" ) return [] diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 9e7fecb57c3..9a1491f5372 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -57,7 +57,7 @@ def finetune( train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[pl.LightningDataModule] = None, - finetune_strategy: Optional[Union[str, Callback]] = None, + strategy: Optional[Union[str, BaseFinetuning]] = None, ): r""" Runs the full optimization routine. Same as pytorch_lightning.Trainer().fit(), but unfreezes layers @@ -74,7 +74,7 @@ def finetune( val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method this will be skipped - finetune_strategy: Should either be a string or a finetuning callback subclassing + strategy: Should either be a string or a finetuning callback subclassing ``pytorch_lightning.callbacks.BaseFinetuning``. Currently default strategies can be create with strings such as: * ``no_freeze``, @@ -83,22 +83,22 @@ def finetune( * ``unfreeze_milestones`` """ - if not isinstance(finetune_strategy, (BaseFinetuning, str)): + if not isinstance(strategy, (BaseFinetuning, str)): raise MisconfigurationException( - "finetune_strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning`` Callback or a str" + "strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning`` Callback or a str" ) - self._resolve_callbacks(finetune_strategy) + self._resolve_callbacks(strategy) return super().fit(model, train_dataloader, val_dataloaders, datamodule) - def _resolve_callbacks(self, finetune_strategy): - if sum((isinstance(c, BaseFinetuning) for c in [finetune_strategy])) > 1: + def _resolve_callbacks(self, strategy): + if sum((isinstance(c, BaseFinetuning) for c in [strategy])) > 1: raise MisconfigurationException("Only 1 callback subclassing `BaseFinetuning` should be provided.") - # todo: change to ``configure_callbacks`` when + # todo: change to ``configure_callbacks`` when merged to Lightning. 
callbacks = self.callbacks - if isinstance(finetune_strategy, str): - finetune_strategy = instantiate_default_finetuning_callbacks(finetune_strategy) - self.callbacks = self._merge_callbacks(callbacks, [finetune_strategy]) + if isinstance(strategy, str): + strategy = instantiate_default_finetuning_callbacks(strategy) + self.callbacks = self._merge_callbacks(callbacks, [strategy]) @staticmethod def _merge_callbacks(current_callbacks: List, new_callbacks: List) -> List: diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index 275d2d58ea1..7643e851311 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -1,5 +1,6 @@ import flash from flash.core.data import download_data +from flash.core.finetuning import FreezeUnfreeze from flash.vision import ImageClassificationData, ImageClassifier if __name__ == "__main__": @@ -21,7 +22,7 @@ trainer = flash.Trainer(max_epochs=2) # 5. Train the model - trainer.finetune(model, datamodule=datamodule, finetune_strategy='freeze_unfreeze') + trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1)) # 6. Test the model trainer.test() diff --git a/flash_examples/finetuning/text_classification.py b/flash_examples/finetuning/text_classification.py index 445cbb43288..dd07f46bf90 100644 --- a/flash_examples/finetuning/text_classification.py +++ b/flash_examples/finetuning/text_classification.py @@ -24,7 +24,7 @@ trainer = flash.Trainer(max_epochs=1) # 5. Fine-tune the model - trainer.finetune(model, datamodule=datamodule, finetune_strategy='never_freeze') + trainer.finetune(model, datamodule=datamodule, strategy='freeze') # 6. Test model trainer.test() diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py index 91667985a67..1e463855849 100644 --- a/tests/core/test_finetuning.py +++ b/tests/core/test_finetuning.py @@ -23,21 +23,21 @@ @pytest.mark.parametrize( - "finetune_strategy", ['no_freeze', 'freeze', 'freeze_unfreeze', 'unfreeze_milestones', None, 'cls', 'chocolat'] + "strategy", ['no_freeze', 'freeze', 'freeze_unfreeze', 'unfreeze_milestones', None, 'cls', 'chocolat'] ) -def test_finetuning(tmpdir: str, finetune_strategy): +def test_finetuning(tmpdir: str, strategy): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) train_dl = torch.utils.data.DataLoader(DummyDataset()) val_dl = torch.utils.data.DataLoader(DummyDataset()) task = ClassificationTask(model, F.nll_loss) trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) - if finetune_strategy == "cls": - finetune_strategy = NoFreeze() - if finetune_strategy == 'chocolat': - with pytest.raises(MisconfigurationException, match="finetune_strategy should be within"): - trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy) - elif finetune_strategy is None: - with pytest.raises(MisconfigurationException, match="finetune_strategy should"): - trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy) + if strategy == "cls": + strategy = NoFreeze() + if strategy == 'chocolat': + with pytest.raises(MisconfigurationException, match="strategy should be within"): + trainer.finetune(task, train_dl, val_dl, strategy=strategy) + elif strategy is None: + with pytest.raises(MisconfigurationException, match="strategy should"): + trainer.finetune(task, train_dl, val_dl, strategy=strategy) else: - trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy) + 
trainer.finetune(task, train_dl, val_dl, strategy=strategy) From b4bffaf5e0a1b59b4c675a974d23e003ba2e351d Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 16:16:21 +0000 Subject: [PATCH 07/15] update finetuning --- CHANGELOG.md | 2 +- flash/core/finetuning.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d0695869b9..feaf4c5f64e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/pytorch-lightning/pull/9)) -- Added `configure_finetuning_callbacks` ([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39)) +- Added `strategy` to `trainer.finetune` with `NoFreeze`, `Freeze`, `FreezeUnfreeze`, `UnfreezeMilestones` ([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39)) ### Changed diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index f49356d8e85..e3ed6701ae5 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -77,7 +77,7 @@ def finetunning_function( ) -class MilestonesFinetuning(FlashBaseBaseFinetuning): +class UnfreezeMilestones(FlashBaseBaseFinetuning): def __init__( self, @@ -125,7 +125,7 @@ def finetunning_function( "no_freeze": NoFreeze, "freeze": Freeze, "freeze_unfreeze": FreezeUnfreeze, - "unfreeze_milestones": MilestonesFinetuning + "unfreeze_milestones": UnfreezeMilestones } From 2b955f1fda5a58a2fe627a0f2763e1f7c8237e6e Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 16:17:08 +0000 Subject: [PATCH 08/15] typo --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index feaf4c5f64e..dd92c849c0e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/pytorch-lightning/pull/9)) -- Added `strategy` to `trainer.finetune` with `NoFreeze`, `Freeze`, `FreezeUnfreeze`, `UnfreezeMilestones` ([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39)) +- Added `strategy` to `trainer.finetune` with `NoFreeze`, `Freeze`, `FreezeUnfreeze`, `UnfreezeMilestones` Callbacks([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39)) ### Changed From cb6a905de6b8f9dbdcb279e6396c9556c853756c Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 16:19:05 +0000 Subject: [PATCH 09/15] update --- flash/core/finetuning.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index e3ed6701ae5..ff569a78354 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -27,14 +27,6 @@ class NoFreeze(BaseFinetuning): pass -def freeze_using_attr_names(pl_module, attr_names: List[str], train_bn: bool = True): - for attr_name in attr_names: - attr = getattr(pl_module, attr_name, None) - if attr is None or not isinstance(attr, nn.Module): - MisconfigurationException("To use Freeze your model must have a {attr} attribute") - BaseFinetuning.freeze(module=attr, train_bn=train_bn) - - class FlashBaseBaseFinetuning(BaseFinetuning): def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True): @@ -42,7 +34,15 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo self.train_bn = train_bn def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) + self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) + + @staticmethod + def freeze_using_attr_names(pl_module, attr_names: List[str], train_bn: bool = True): + for attr_name in attr_names: + attr = getattr(pl_module, attr_name, None) + if attr is None or not isinstance(attr, nn.Module): + MisconfigurationException("To use Freeze your model must have a {attr} attribute") + BaseFinetuning.freeze(module=attr, train_bn=train_bn) class Freeze(FlashBaseBaseFinetuning): @@ -56,7 +56,7 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo self.unfreeze_epoch = unfreeze_epoch def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) + self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) def finetunning_function( self, @@ -92,7 +92,7 @@ def __init__( super().__init__(attr_names, train_bn) def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) + self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) def finetunning_function( self, From 6ab9daf7e7b83872fb0eb8828d3279ca49285313 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 17:01:40 +0000 Subject: [PATCH 10/15] update --- CHANGELOG.md | 4 +-- flash/core/finetuning.py | 44 ++++++++++++++++----------- flash/core/model.py | 5 +++- flash/core/trainer.py | 56 +++++++++++++++++++++++------------ tests/core/test_finetuning.py | 7 ++--- tests/core/test_trainer.py | 2 +- 6 files changed, 72 insertions(+), 46 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dd92c849c0e..87b983c22e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,10 
+9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/pytorch-lightning/pull/9)) +- Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/lightning-flash/pull/9)) -- Added `strategy` to `trainer.finetune` with `NoFreeze`, `Freeze`, `FreezeUnfreeze`, `UnfreezeMilestones` Callbacks([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39)) +- Added `strategy` to `trainer.finetune` with `NoFreeze`, `Freeze`, `FreezeUnfreeze`, `UnfreezeMilestones` Callbacks([#39](https://github.com/PyTorchLightning/lightning-flash/pull/39)) ### Changed diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index ff569a78354..5e6298b03b0 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -16,6 +16,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import BaseFinetuning +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn from torch.optim import Optimizer @@ -23,13 +24,22 @@ _EXCLUDE_PARAMTERS = ("self", "args", "kwargs") -class NoFreeze(BaseFinetuning): - pass +class FlashBaseFinetuning(BaseFinetuning): + def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True): + r""" -class FlashBaseBaseFinetuning(BaseFinetuning): + FlashBaseFinetuning can be used to create a custom Flash Finetuning Callback. + + Override ``finetunning_function`` to put your unfreeze logic. + + Args: + attr_names: Name(s) of the module attributes of the model to be frozen. + + train_bn: Wether to train Batch Norm layer + + """ - def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True): self.attr_names = [attr_names] if isinstance(attr_names, str) else attr_names self.train_bn = train_bn @@ -41,15 +51,11 @@ def freeze_using_attr_names(pl_module, attr_names: List[str], train_bn: bool = T for attr_name in attr_names: attr = getattr(pl_module, attr_name, None) if attr is None or not isinstance(attr, nn.Module): - MisconfigurationException("To use Freeze your model must have a {attr} attribute") + MisconfigurationException(f"Your model must have a {attr} attribute") BaseFinetuning.freeze(module=attr, train_bn=train_bn) -class Freeze(FlashBaseBaseFinetuning): - pass - - -class FreezeUnfreeze(FlashBaseBaseFinetuning): +class FreezeUnfreeze(FlashBaseFinetuning): def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_epoch: int = 10): super().__init__(attr_names, train_bn) @@ -77,7 +83,7 @@ def finetunning_function( ) -class UnfreezeMilestones(FlashBaseBaseFinetuning): +class UnfreezeMilestones(FlashBaseFinetuning): def __init__( self, @@ -122,21 +128,23 @@ def finetunning_function( _DEFAULTS_FINETUNE_STRATEGIES = { - "no_freeze": NoFreeze, - "freeze": Freeze, + "no_freeze": BaseFinetuning, + "freeze": FlashBaseFinetuning, "freeze_unfreeze": FreezeUnfreeze, "unfreeze_milestones": UnfreezeMilestones } def instantiate_default_finetuning_callbacks(strategy): + if strategy is None: + strategy = "no_freeze" + rank_zero_warn("strategy is None. 
Setting strategy to `no_freeze` by default.", UserWarning) if isinstance(strategy, str): strategy = strategy.lower() if strategy in _DEFAULTS_FINETUNE_STRATEGIES: return [_DEFAULTS_FINETUNE_STRATEGIES[strategy]()] - else: - raise MisconfigurationException( - f"strategy should be within {list(_DEFAULTS_FINETUNE_STRATEGIES)}" - f". Found {strategy}" - ) + raise MisconfigurationException( + f"strategy should be within {list(_DEFAULTS_FINETUNE_STRATEGIES)}" + f". Found {strategy}" + ) return [] diff --git a/flash/core/model.py b/flash/core/model.py index 40b7f28be72..3607878ac86 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union import pytorch_lightning as pl import torch @@ -150,3 +150,6 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: checkpoint["pipeline"] = self.data_pipeline + + def configure_finetune_callback(self): + return [] diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 9a1491f5372..75c74f6190d 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -15,12 +15,13 @@ from typing import List, Optional, Union import pytorch_lightning as pl -from pytorch_lightning.callbacks import BaseFinetuning, Callback +from pytorch_lightning.callbacks import BaseFinetuning +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import DataLoader -from flash.core.finetuning import instantiate_default_finetuning_callbacks -from flash.core.model import Task +from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, instantiate_default_finetuning_callbacks class Trainer(pl.Trainer): @@ -53,7 +54,7 @@ def fit( def finetune( self, - model: Task, + model: LightningModule, train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[pl.LightningDataModule] = None, @@ -76,29 +77,46 @@ def finetune( strategy: Should either be a string or a finetuning callback subclassing ``pytorch_lightning.callbacks.BaseFinetuning``. 
- Currently default strategies can be create with strings such as: + Currently, default strategies can be enabled with these strings: * ``no_freeze``, - * ``freeze`` - * ``freeze_unfreeze`` + * ``freeze``, + * ``freeze_unfreeze``, * ``unfreeze_milestones`` """ - if not isinstance(strategy, (BaseFinetuning, str)): + self._resolve_callbacks(model, strategy) + return super().fit(model, train_dataloader, val_dataloaders, datamodule) + + def _resolve_callbacks(self, model, strategy): + if strategy is not None and not isinstance(strategy, (str, BaseFinetuning)): raise MisconfigurationException( - "strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning`` Callback or a str" + "strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning``" + f"callback or a str within {list(_DEFAULTS_FINETUNE_STRATEGIES.keys())}" ) - self._resolve_callbacks(strategy) - return super().fit(model, train_dataloader, val_dataloaders, datamodule) - - def _resolve_callbacks(self, strategy): - if sum((isinstance(c, BaseFinetuning) for c in [strategy])) > 1: - raise MisconfigurationException("Only 1 callback subclassing `BaseFinetuning` should be provided.") - # todo: change to ``configure_callbacks`` when merged to Lightning. callbacks = self.callbacks - if isinstance(strategy, str): - strategy = instantiate_default_finetuning_callbacks(strategy) - self.callbacks = self._merge_callbacks(callbacks, [strategy]) + + if isinstance(strategy, BaseFinetuning): + callback = strategy + else: + # todo: change to ``configure_callbacks`` when merged to Lightning. + model_callback = model.configure_finetune_callback() + if len(model_callback) > 1: + raise MisconfigurationException( + f"{model} configure_finetune_callback should create a list with only 1 callback" + ) + if len(model_callback) == 1: + if strategy is not None: + rank_zero_warn( + "The model contains a default finetune callback. " + f"The provided {strategy} will be overriden. " + "HINT: Provide a `BaseFinetuning callback as strategy to be prioritized. 
", UserWarning + ) + callback = [model_callback] + else: + callback = instantiate_default_finetuning_callbacks(strategy) + + self.callbacks = self._merge_callbacks(callbacks, [callback]) @staticmethod def _merge_callbacks(current_callbacks: List, new_callbacks: List) -> List: diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py index 1e463855849..e4062eac598 100644 --- a/tests/core/test_finetuning.py +++ b/tests/core/test_finetuning.py @@ -18,7 +18,7 @@ from torch.nn import functional as F from flash import ClassificationTask, Trainer -from flash.core.finetuning import NoFreeze +from flash.core.finetuning import FlashBaseFinetuning from tests.core.test_model import DummyDataset @@ -32,12 +32,9 @@ def test_finetuning(tmpdir: str, strategy): task = ClassificationTask(model, F.nll_loss) trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) if strategy == "cls": - strategy = NoFreeze() + strategy = FlashBaseFinetuning() if strategy == 'chocolat': with pytest.raises(MisconfigurationException, match="strategy should be within"): trainer.finetune(task, train_dl, val_dl, strategy=strategy) - elif strategy is None: - with pytest.raises(MisconfigurationException, match="strategy should"): - trainer.finetune(task, train_dl, val_dl, strategy=strategy) else: trainer.finetune(task, train_dl, val_dl, strategy=strategy) diff --git a/tests/core/test_trainer.py b/tests/core/test_trainer.py index f872de5e557..e0f63f19f28 100644 --- a/tests/core/test_trainer.py +++ b/tests/core/test_trainer.py @@ -51,5 +51,5 @@ def test_task_finetune(tmpdir: str): val_dl = torch.utils.data.DataLoader(DummyDataset()) task = ClassificationTask(model, F.nll_loss) trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) - result = trainer.finetune(task, train_dl, val_dl, unfreeze_milestones=(0, 0)) + result = trainer.finetune(task, train_dl, val_dl) assert result From 583ce0cb3ca01d0967a55294e421c2ea02765728 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 17:15:34 +0000 Subject: [PATCH 11/15] update notebooks --- flash/core/trainer.py | 23 +++++---- .../finetuning/image_classification.ipynb | 47 ++++++++--------- .../finetuning/text_classification.ipynb | 50 ++++++++++--------- 3 files changed, 64 insertions(+), 56 deletions(-) diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 75c74f6190d..12b801612fd 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -88,14 +88,15 @@ def finetune( return super().fit(model, train_dataloader, val_dataloaders, datamodule) def _resolve_callbacks(self, model, strategy): + """ + This function is used to select the `BaseFinetuning` to be used for finetuning. + """ if strategy is not None and not isinstance(strategy, (str, BaseFinetuning)): raise MisconfigurationException( "strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning``" f"callback or a str within {list(_DEFAULTS_FINETUNE_STRATEGIES.keys())}" ) - callbacks = self.callbacks - if isinstance(strategy, BaseFinetuning): callback = strategy else: @@ -110,20 +111,24 @@ def _resolve_callbacks(self, model, strategy): rank_zero_warn( "The model contains a default finetune callback. " f"The provided {strategy} will be overriden. " - "HINT: Provide a `BaseFinetuning callback as strategy to be prioritized. ", UserWarning + "HINT: Provide a `BaseFinetuning` callback as strategy to make it prioritized. 
", UserWarning ) callback = [model_callback] else: callback = instantiate_default_finetuning_callbacks(strategy) - self.callbacks = self._merge_callbacks(callbacks, [callback]) + self.callbacks = self._merge_callbacks(self.callbacks, [callback]) @staticmethod - def _merge_callbacks(current_callbacks: List, new_callbacks: List) -> List: + def _merge_callbacks(old_callbacks: List, new_callbacks: List) -> List: + """ + This function keeps only 1 instance of each callback type, + extending new_callbacks with old_callbacks + """ if len(new_callbacks): - return current_callbacks + return old_callbacks new_callbacks_types = set(type(c) for c in new_callbacks) - current_callbacks_types = set(type(c) for c in current_callbacks) - override_types = new_callbacks_types.intersection(current_callbacks_types) - new_callbacks.extend(c for c in current_callbacks if type(c) not in override_types) + old_callbacks_types = set(type(c) for c in old_callbacks) + override_types = new_callbacks_types.intersection(old_callbacks_types) + new_callbacks.extend(c for c in old_callbacks if type(c) not in override_types) return new_callbacks diff --git a/flash_notebooks/finetuning/image_classification.ipynb b/flash_notebooks/finetuning/image_classification.ipynb index 1ee71ec15ef..1959e99df78 100644 --- a/flash_notebooks/finetuning/image_classification.ipynb +++ b/flash_notebooks/finetuning/image_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "dominican-savings", + "id": "thousand-manufacturer", "metadata": {}, "source": [ "\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "threaded-coffee", + "id": "smoking-probe", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by finetuning an ImageClassifier on [Hymenoptera Dataset](https://www.kaggle.com/ajayrana/hymenoptera-data) containing ants and bees images.\n", @@ -27,7 +27,9 @@ " \n", " - 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.\n", " \n", - " - 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. At training start, the backbone will be frozen, meaning its parameters won't be updated. Only the model head will be trained to properly distinguish ants and bees. On reaching first finetuning milestone, the backbone latest layers will be unfrozen and start to be trained. On reaching the second finetuning milestone, the remaining layers of the backend will be unfrozen and the entire model will be trained. In Flash, `trainer.finetune(..., unfreeze_milestones=(first_milestone, second_milestone))`.\n", + " - 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. 
If a one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is sub-classing `from pytorch_lightning.callbacks import BaseFinetuning`.\n", + " \n", + " \n", "\n", " \n", "\n", @@ -41,7 +43,7 @@ { "cell_type": "code", "execution_count": null, - "id": "handmade-timing", + "id": "thermal-fraction", "metadata": {}, "outputs": [], "source": [ @@ -52,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "through-edwards", + "id": "cognitive-haven", "metadata": {}, "outputs": [], "source": [ @@ -63,7 +65,7 @@ }, { "cell_type": "markdown", - "id": "hybrid-adapter", + "id": "afraid-straight", "metadata": {}, "source": [ "## 1. Download data\n", @@ -73,7 +75,7 @@ { "cell_type": "code", "execution_count": null, - "id": "amateur-disposal", + "id": "advisory-narrow", "metadata": {}, "outputs": [], "source": [ @@ -82,7 +84,7 @@ }, { "cell_type": "markdown", - "id": "front-metallic", + "id": "trying-group", "metadata": {}, "source": [ "

### 2. Load the data
\n", @@ -105,7 +107,7 @@ { "cell_type": "code", "execution_count": null, - "id": "hazardous-means", + "id": "stuck-composition", "metadata": {}, "outputs": [], "source": [ @@ -118,7 +120,7 @@ }, { "cell_type": "markdown", - "id": "defined-mouse", + "id": "irish-scenario", "metadata": {}, "source": [ "### 3. Build the model\n", @@ -130,7 +132,7 @@ { "cell_type": "code", "execution_count": null, - "id": "internal-playback", + "id": "opening-nomination", "metadata": {}, "outputs": [], "source": [ @@ -139,7 +141,7 @@ }, { "cell_type": "markdown", - "id": "million-tower", + "id": "breathing-element", "metadata": {}, "source": [ "### 4. Create the trainer. Run once on data\n", @@ -156,7 +158,7 @@ { "cell_type": "code", "execution_count": null, - "id": "centered-paris", + "id": "earlier-jordan", "metadata": {}, "outputs": [], "source": [ @@ -165,26 +167,25 @@ }, { "cell_type": "markdown", - "id": "special-fence", + "id": "extreme-scene", "metadata": {}, "source": [ - "### 5. Finetune the model\n", - "The `unfreeze_milestones=(0, 1)` will unfreeze the latest layers of the backbone on epoch `0` and the rest of the backbone on epoch `1`. " + "### 5. Finetune the model" ] }, { "cell_type": "code", "execution_count": null, - "id": "local-taylor", + "id": "tired-underground", "metadata": {}, "outputs": [], "source": [ - "trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))" + "trainer.finetune(model, datamodule=datamodule, strategy=\"freeze_unfreeze\")" ] }, { "cell_type": "markdown", - "id": "municipal-kentucky", + "id": "smooth-european", "metadata": {}, "source": [ "### 6. Test the model" @@ -193,7 +194,7 @@ { "cell_type": "code", "execution_count": null, - "id": "simplified-bundle", + "id": "sexual-tender", "metadata": {}, "outputs": [], "source": [ @@ -202,7 +203,7 @@ }, { "cell_type": "markdown", - "id": "familiar-territory", + "id": "athletic-nutrition", "metadata": {}, "source": [ "### 7. Save it!" @@ -211,7 +212,7 @@ { "cell_type": "code", "execution_count": null, - "id": "injured-mineral", + "id": "pleasant-canon", "metadata": {}, "outputs": [], "source": [ @@ -220,7 +221,7 @@ }, { "cell_type": "markdown", - "id": "tested-experience", + "id": "incident-basket", "metadata": {}, "source": [ "\n", diff --git a/flash_notebooks/finetuning/text_classification.ipynb b/flash_notebooks/finetuning/text_classification.ipynb index 8079575e0e4..28cb7487937 100644 --- a/flash_notebooks/finetuning/text_classification.ipynb +++ b/flash_notebooks/finetuning/text_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "digital-quilt", + "id": "prerequisite-straight", "metadata": {}, "source": [ "
\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "empty-request", + "id": "coastal-bible", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by finetunig a TextClassifier on [IMDB Dataset](https://www.imdb.com/interfaces/).\n", @@ -31,7 +31,7 @@ "- 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head, will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.\n", " \n", "\n", - "- 4. Train the target model on a target dataset, such as IMDB Dataset to learn to predict the associated sentiment of movie reviews. At training start, the backbone will be frozen, meaning its parameters won't be updated. Only the model head will be trained to between negative and positive reviews. On reaching first finetuning milestone, the backbone latest layers will be unfrozen and start to be trained. On reaching the second finetuning milestone, the remaining layers of the backend will be unfrozen and the entire model will be trained. In Flash, `unfreeze_milestones` controls those milestone and be used as such `trainer.finetune(..., unfreeze_milestones=(first_milestone, second_milestone))`.\n", + "- 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If a one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is sub-classing `from pytorch_lightning.callbacks import BaseFinetuning`.\n", "\n", "---\n", " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", @@ -42,7 +42,7 @@ }, { "cell_type": "markdown", - "id": "another-might", + "id": "sharp-techno", "metadata": {}, "source": [ "### Setup \n", @@ -52,7 +52,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ideal-summary", + "id": "posted-blair", "metadata": {}, "outputs": [], "source": [ @@ -63,7 +63,7 @@ { "cell_type": "code", "execution_count": null, - "id": "straight-commission", + "id": "double-swedish", "metadata": {}, "outputs": [], "source": [ @@ -74,7 +74,7 @@ }, { "cell_type": "markdown", - "id": "classical-snake", + "id": "outside-garlic", "metadata": {}, "source": [ "### 1. Download the data\n", @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "lined-standing", + "id": "tired-lender", "metadata": {}, "outputs": [], "source": [ @@ -93,7 +93,7 @@ }, { "cell_type": "markdown", - "id": "endangered-heavy", + "id": "daily-marijuana", "metadata": {}, "source": [ "

### 2. Load the data
\n", @@ -105,7 +105,7 @@ { "cell_type": "code", "execution_count": null, - "id": "posted-chosen", + "id": "standing-commons", "metadata": {}, "outputs": [], "source": [ @@ -121,7 +121,7 @@ }, { "cell_type": "markdown", - "id": "cognitive-compact", + "id": "fantastic-mortality", "metadata": { "jupyter": { "outputs_hidden": true @@ -138,7 +138,7 @@ { "cell_type": "code", "execution_count": null, - "id": "underlying-liberia", + "id": "prompt-azerbaijan", "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "democratic-interaction", + "id": "cubic-crystal", "metadata": { "jupyter": { "outputs_hidden": true @@ -160,7 +160,7 @@ { "cell_type": "code", "execution_count": null, - "id": "adopted-caution", + "id": "mineral-phrase", "metadata": {}, "outputs": [], "source": [ @@ -169,29 +169,31 @@ }, { "cell_type": "markdown", - "id": "integral-access", + "id": "brown-scoop", "metadata": { "jupyter": { "outputs_hidden": true } }, "source": [ - "### 5. Fine-tune the model" + "### 5. Fine-tune the model\n", + "\n", + "The backbone won't be freezed and the entire model will be finetuned on the imdb dataset " ] }, { "cell_type": "code", "execution_count": null, - "id": "enormous-botswana", + "id": "reliable-hampshire", "metadata": {}, "outputs": [], "source": [ - "trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))" + "trainer.finetune(model, datamodule=datamodule, strategy=\"no_freeze\")" ] }, { "cell_type": "markdown", - "id": "cellular-baking", + "id": "unlimited-duplicate", "metadata": { "jupyter": { "outputs_hidden": true @@ -204,7 +206,7 @@ { "cell_type": "code", "execution_count": null, - "id": "demanding-headline", + "id": "federal-quarter", "metadata": {}, "outputs": [], "source": [ @@ -213,7 +215,7 @@ }, { "cell_type": "markdown", - "id": "charged-investigator", + "id": "defensive-committee", "metadata": { "jupyter": { "outputs_hidden": true @@ -226,7 +228,7 @@ { "cell_type": "code", "execution_count": null, - "id": "early-ridge", + "id": "disciplinary-background", "metadata": {}, "outputs": [], "source": [ @@ -235,7 +237,7 @@ }, { "cell_type": "markdown", - "id": "detailed-direction", + "id": "increased-filter", "metadata": {}, "source": [ "\n", @@ -296,4 +298,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} From 6bf0064ab3b0184e864c4827d339f9f16aba19b8 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 17:18:18 +0000 Subject: [PATCH 12/15] update typo --- flash/core/finetuning.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index 5e6298b03b0..d94de89fd19 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import inspect from typing import List, Union import pytorch_lightning as pl @@ -21,8 +20,6 @@ from torch import nn from torch.optim import Optimizer -_EXCLUDE_PARAMTERS = ("self", "args", "kwargs") - class FlashBaseFinetuning(BaseFinetuning): @@ -46,13 +43,12 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo def freeze_before_training(self, pl_module: pl.LightningModule) -> None: self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) - @staticmethod - def freeze_using_attr_names(pl_module, attr_names: List[str], train_bn: bool = True): + def freeze_using_attr_names(self, pl_module, attr_names: List[str], train_bn: bool = True): for attr_name in attr_names: attr = getattr(pl_module, attr_name, None) if attr is None or not isinstance(attr, nn.Module): MisconfigurationException(f"Your model must have a {attr} attribute") - BaseFinetuning.freeze(module=attr, train_bn=train_bn) + self.freeze(module=attr, train_bn=train_bn) class FreezeUnfreeze(FlashBaseFinetuning): From 504ddfeab5c584099067414dc43e854235db40f4 Mon Sep 17 00:00:00 2001 From: chaton Date: Mon, 1 Feb 2021 17:40:53 +0000 Subject: [PATCH 13/15] Update flash_notebooks/finetuning/image_classification.ipynb MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- flash_notebooks/finetuning/image_classification.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_notebooks/finetuning/image_classification.ipynb b/flash_notebooks/finetuning/image_classification.ipynb index 1959e99df78..4cd82ec404c 100644 --- a/flash_notebooks/finetuning/image_classification.ipynb +++ b/flash_notebooks/finetuning/image_classification.ipynb @@ -27,7 +27,7 @@ " \n", " - 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.\n", " \n", - " - 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If a one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is sub-classing `from pytorch_lightning.callbacks import BaseFinetuning`.\n", + " - 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. 
If one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is sub-classing `pytorch_lightning.callbacks.BaseFinetuning`.\n", " \n", " \n", "\n", From 2f37a368ef996970319c9b25b3e1d2d113657a7d Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 17:45:29 +0000 Subject: [PATCH 14/15] resolve comments --- README.md | 16 +++---- flash/core/finetuning.py | 15 +------ flash/core/trainer.py | 10 +++-- .../finetuning/image_classification.py | 2 +- .../finetuning/text_classification.ipynb | 42 +++++++++---------- 5 files changed, 38 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index dfad9593289..62538270312 100644 --- a/README.md +++ b/README.md @@ -127,7 +127,7 @@ model = ImageClassifier(num_classes=datamodule.num_classes) trainer = flash.Trainer(max_epochs=1) # 5. Finetune the model -trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1)) +trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 7. Save it! trainer.save_checkpoint("image_classification_model.pt") @@ -151,13 +151,13 @@ Flash is built as a collection of community-built tasks. A task is highly opinio ### Example 1: Image classification Flash has an ImageClassification task to tackle any image classification problem. - +
View example To illustrate, Let's say we wanted to develop a model that could classify between ants and bees. - + - + Here we classify ants vs bees. ```python @@ -208,7 +208,7 @@ Flash has a TextClassification task to tackle any text classification problem.
View example To illustrate, say you wanted to classify movie reviews as positive or negative. - + ```python import flash from flash import download_data @@ -261,9 +261,9 @@ Flash has a TabularClassification task to tackle any tabular classification prob
View example - - To illustrate, say we want to build a model to predict if a passenger survived on the Titanic. - + + To illustrate, say we want to build a model to predict if a passenger survived on the Titanic. + ```python from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall import flash diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index d94de89fd19..c68f52fde00 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -57,9 +57,6 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo super().__init__(attr_names, train_bn) self.unfreeze_epoch = unfreeze_epoch - def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) - def finetunning_function( self, pl_module: pl.LightningModule, @@ -68,10 +65,7 @@ def finetunning_function( opt_idx: int, ) -> None: if epoch == self.unfreeze_epoch: - modules = [] - for attr_name in self.attr_names: - modules.append(getattr(pl_module, attr_name)) - + modules = [getattr(pl_module, attr_name) for attr_name in self.attr_names] self.unfreeze_and_add_param_group( module=modules, optimizer=optimizer, @@ -93,9 +87,6 @@ def __init__( super().__init__(attr_names, train_bn) - def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) - def finetunning_function( self, pl_module: pl.LightningModule, @@ -105,8 +96,7 @@ def finetunning_function( ) -> None: backbone_modules = list(pl_module.backbone.modules()) if epoch == self.unfreeze_milestones[0]: - # unfreeze 5 last layers - # TODO last N layers should be parameter + # unfreeze num_layers last layers self.unfreeze_and_add_param_group( module=backbone_modules[-self.num_layers:], optimizer=optimizer, @@ -115,7 +105,6 @@ def finetunning_function( elif epoch == self.unfreeze_milestones[1]: # unfreeze remaining layers - # TODO last N layers should be parameter self.unfreeze_and_add_param_group( module=backbone_modules[:-self.num_layers], optimizer=optimizer, diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 12b801612fd..e570d4ae2d1 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -61,6 +61,7 @@ def finetune( strategy: Optional[Union[str, BaseFinetuning]] = None, ): r""" + Runs the full optimization routine. Same as pytorch_lightning.Trainer().fit(), but unfreezes layers of the backbone throughout training layers of the backbone throughout training. @@ -77,11 +78,12 @@ def finetune( strategy: Should either be a string or a finetuning callback subclassing ``pytorch_lightning.callbacks.BaseFinetuning``. + Currently, default strategies can be enabled with these strings: - * ``no_freeze``, - * ``freeze``, - * ``freeze_unfreeze``, - * ``unfreeze_milestones`` + - ``no_freeze``, + - ``freeze``, + - ``freeze_unfreeze``, + - ``unfreeze_milestones`` """ self._resolve_callbacks(model, strategy) diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index 7643e851311..b5202c16611 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -18,7 +18,7 @@ # 3. Build the model model = ImageClassifier(num_classes=datamodule.num_classes) - # 4. Create the trainer. Run once on data + # 4. Create the trainer. Run twice on data trainer = flash.Trainer(max_epochs=2) # 5. 
Train the model diff --git a/flash_notebooks/finetuning/text_classification.ipynb b/flash_notebooks/finetuning/text_classification.ipynb index 28cb7487937..34411232aaa 100644 --- a/flash_notebooks/finetuning/text_classification.ipynb +++ b/flash_notebooks/finetuning/text_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "prerequisite-straight", + "id": "optical-barrel", "metadata": {}, "source": [ "\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "coastal-bible", + "id": "rolled-scoop", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by finetunig a TextClassifier on [IMDB Dataset](https://www.imdb.com/interfaces/).\n", @@ -31,7 +31,7 @@ "- 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head, will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.\n", " \n", "\n", - "- 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If a one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is sub-classing `from pytorch_lightning.callbacks import BaseFinetuning`.\n", + "- 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If a one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is sub-classing `pytorch_lightning.callbacks.BaseFinetuning`.\n", "\n", "---\n", " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", @@ -42,7 +42,7 @@ }, { "cell_type": "markdown", - "id": "sharp-techno", + "id": "pleasant-benchmark", "metadata": {}, "source": [ "### Setup \n", @@ -52,7 +52,7 @@ { "cell_type": "code", "execution_count": null, - "id": "posted-blair", + "id": "suspended-announcement", "metadata": {}, "outputs": [], "source": [ @@ -63,7 +63,7 @@ { "cell_type": "code", "execution_count": null, - "id": "double-swedish", + "id": "appreciated-internship", "metadata": {}, "outputs": [], "source": [ @@ -74,7 +74,7 @@ }, { "cell_type": "markdown", - "id": "outside-garlic", + "id": "excessive-private", "metadata": {}, "source": [ "### 1. Download the data\n", @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "tired-lender", + "id": "noted-father", "metadata": {}, "outputs": [], "source": [ @@ -93,7 +93,7 @@ }, { "cell_type": "markdown", - "id": "daily-marijuana", + "id": "naval-rogers", "metadata": {}, "source": [ "

### 2. Load the data
\n", @@ -105,7 +105,7 @@ { "cell_type": "code", "execution_count": null, - "id": "standing-commons", + "id": "monetary-album", "metadata": {}, "outputs": [], "source": [ @@ -121,7 +121,7 @@ }, { "cell_type": "markdown", - "id": "fantastic-mortality", + "id": "published-vision", "metadata": { "jupyter": { "outputs_hidden": true @@ -138,7 +138,7 @@ { "cell_type": "code", "execution_count": null, - "id": "prompt-azerbaijan", + "id": "focused-claim", "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "cubic-crystal", + "id": "primary-battery", "metadata": { "jupyter": { "outputs_hidden": true @@ -160,7 +160,7 @@ { "cell_type": "code", "execution_count": null, - "id": "mineral-phrase", + "id": "great-austria", "metadata": {}, "outputs": [], "source": [ @@ -169,7 +169,7 @@ }, { "cell_type": "markdown", - "id": "brown-scoop", + "id": "corporate-sequence", "metadata": { "jupyter": { "outputs_hidden": true @@ -184,7 +184,7 @@ { "cell_type": "code", "execution_count": null, - "id": "reliable-hampshire", + "id": "opponent-visit", "metadata": {}, "outputs": [], "source": [ @@ -193,7 +193,7 @@ }, { "cell_type": "markdown", - "id": "unlimited-duplicate", + "id": "sunrise-questionnaire", "metadata": { "jupyter": { "outputs_hidden": true @@ -206,7 +206,7 @@ { "cell_type": "code", "execution_count": null, - "id": "federal-quarter", + "id": "certain-pizza", "metadata": {}, "outputs": [], "source": [ @@ -215,7 +215,7 @@ }, { "cell_type": "markdown", - "id": "defensive-committee", + "id": "loose-march", "metadata": { "jupyter": { "outputs_hidden": true @@ -228,7 +228,7 @@ { "cell_type": "code", "execution_count": null, - "id": "disciplinary-background", + "id": "loose-culture", "metadata": {}, "outputs": [], "source": [ @@ -237,7 +237,7 @@ }, { "cell_type": "markdown", - "id": "increased-filter", + "id": "quarterly-dominican", "metadata": {}, "source": [ "\n", From 7ccfd814d043e7beed668dce3642159c87aaa649 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 18:04:54 +0000 Subject: [PATCH 15/15] remove set -e --- .github/workflows/ci-notebook.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci-notebook.yml b/.github/workflows/ci-notebook.yml index 97b57cb580a..daa5ec4e409 100644 --- a/.github/workflows/ci-notebook.yml +++ b/.github/workflows/ci-notebook.yml @@ -57,15 +57,14 @@ jobs: with: path: flash_examples/predict # This path is specific to Ubuntu # Look to see if there is a cache hit for the corresponding requirements file - key: flash-datasets_predict + key: flash-datasets_predict - name: Run Notebooks run: | - set -e jupyter nbconvert --to script flash_notebooks/finetuning/tabular_classification.ipynb jupyter nbconvert --to script flash_notebooks/predict/classify_image.ipynb jupyter nbconvert --to script flash_notebooks/predict/classify_tabular.ipynb ipython flash_notebooks/finetuning/tabular_classification.py ipython flash_notebooks/predict/classify_image.py - ipython flash_notebooks/predict/classify_tabular.py \ No newline at end of file + ipython flash_notebooks/predict/classify_tabular.py