From 5092f499211674f3308f4fe8750e9f0c96762668 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 16 Feb 2021 01:50:34 +0530 Subject: [PATCH 1/3] ref lr_finder a bit --- pytorch_lightning/tuner/batch_size_scaling.py | 8 ++-- pytorch_lightning/tuner/lr_finder.py | 48 +++++++++---------- pytorch_lightning/tuner/tuning.py | 19 ++++---- tests/helpers/simple_models.py | 3 +- tests/trainer/test_lr_finder.py | 17 +++---- 5 files changed, 46 insertions(+), 49 deletions(-) diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 56e853385c68e..3a52b6dd2e8fa 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -63,10 +63,10 @@ def scale_batch_size( It is expected that the user has provided a model or datamodule that has a hyperparameter with that name. We will look for this attribute name in the following places - - `model` - - `model.hparams` - - `model.datamodule` - - `trainer.datamodule` (the datamodule passed to the tune method) + - ``model`` + - ``model.hparams`` + - ``model.datamodule`` + - ``trainer.datamodule`` (the datamodule passed to the tune method) **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader or datamodule. diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 83c0d51089bd9..8903721ff70f9 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -60,21 +60,6 @@ def _determine_lr_attr_name(trainer, model: LightningModule) -> str: ) -def _run_lr_finder_internally(trainer, model: LightningModule): - """ Call lr finder internally during Trainer.fit() """ - lr_attr_name = _determine_lr_attr_name(trainer, model) - lr_finder = lr_find(trainer, model) - if lr_finder is None: - return - - lr = lr_finder.suggestion() - - # TODO: log lr.results to self.logger - lightning_setattr(model, lr_attr_name, lr) - - log.info(f'Learning rate set to {lr}') - - def lr_find( trainer, model: LightningModule, @@ -86,16 +71,17 @@ def lr_find( mode: str = 'exponential', early_stop_threshold: float = 4.0, datamodule: Optional[LightningDataModule] = None, + update_attr: bool = False, ): r""" - `lr_find` enables the user to do a range test of good initial learning rates, + ``lr_find`` enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking a good starting learning rate. Args: model: Model to do range testing for train_dataloader: A PyTorch - `DataLoader` with training samples. If the model has + ``DataLoader`` with training samples. If the model has a predefined train_dataloader method, this will be skipped. min_lr: minimum learning rate to investigate @@ -104,19 +90,21 @@ def lr_find( num_training: number of learning rates to test - mode: search strategy, either 'linear' or 'exponential'. If set to - 'linear' the learning rate will be searched by linearly increasing - after each batch. If set to 'exponential', will increase learning - rate exponentially. + mode: Search strategy to update learning rate after each batch: + + - ``'exponential'`` [default]: Will increase learning rate exponentially. + - ``'linear'``: Will increate learning rate linearly. early_stop_threshold: threshold for stopping the search. If the loss at any point is larger than early_stop_threshold*best_loss then the search is stopped. To disable, set to None. 
- datamodule: An optional `LightningDataModule` which holds the training + datamodule: An optional ``LightningDataModule`` which holds the training and validation dataloader(s). Note that the `train_dataloader` and - `val_dataloaders` parameters cannot be used at the same time as - this parameter, or a `MisconfigurationException` will be raised. + ``val_dataloaders`` parameters cannot be used at the same time as + this parameter, or a ``MisconfigurationException`` will be raised. + + update_attr: Whether to update the learning rate attribute or not. Example:: @@ -144,6 +132,10 @@ def lr_find( rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning) return + # Determine lr attr + if update_attr: + lr_attr_name = _determine_lr_attr_name(trainer, model) + save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt') __lr_finder_dump_params(trainer, model) @@ -200,6 +192,14 @@ def lr_find( if trainer.progress_bar_callback: trainer.progress_bar_callback.enable() + # Update lr attr if required + if update_attr: + lr = lr_finder.suggestion() + + # TODO: log lr.results to self.logger + lightning_setattr(model, lr_attr_name, lr) + log.info(f'Learning rate set to {lr}') + return lr_finder diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 314821bd81e02..06475547b03f2 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -20,7 +20,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size -from pytorch_lightning.tuner.lr_finder import _run_lr_finder_internally, lr_find +from pytorch_lightning.tuner.lr_finder import lr_find class Tuner: @@ -53,7 +53,7 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule): # Run learning rate finder: if self.trainer.auto_lr_find: - self.internal_find_lr(model) + self.lr_find(model, update_attr=True) def scale_batch_size( self, @@ -92,10 +92,10 @@ def scale_batch_size( It is expected that the user has provided a model or datamodule that has a hyperparameter with that name. We will look for this attribute name in the following places - - `model` - - `model.hparams` - - `model.datamodule` - - `trainer.datamodule` (the datamodule passed to the tune method) + - ``model`` + - ``model.hparams`` + - ``model.datamodule`` + - ``trainer.datamodule`` (the datamodule passed to the tune method) **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader or datamodule. 
@@ -122,7 +122,8 @@ def lr_find( num_training: int = 100, mode: str = 'exponential', early_stop_threshold: float = 4.0, - datamodule: Optional[LightningDataModule] = None + datamodule: Optional[LightningDataModule] = None, + update_attr: bool = False, ): return lr_find( self.trainer, @@ -135,10 +136,8 @@ def lr_find( mode, early_stop_threshold, datamodule, + update_attr, ) - def internal_find_lr(self, model: LightningModule): - return _run_lr_finder_internally(self.trainer, model) - def pick_multiple_gpus(self, num_gpus: int): return pick_multiple_gpus(num_gpus) diff --git a/tests/helpers/simple_models.py b/tests/helpers/simple_models.py index c33c470d043b7..1abeb1f00206a 100644 --- a/tests/helpers/simple_models.py +++ b/tests/helpers/simple_models.py @@ -22,8 +22,9 @@ class ClassificationModel(LightningModule): def __init__(self, lr=0.01): - self.lr = lr super().__init__() + + self.lr = lr for i in range(3): setattr(self, f"layer_{i}", nn.Linear(32, 32)) setattr(self, f"layer_{i}a", torch.nn.ReLU()) diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py index 82a76bc229f0d..863cb7d0e838b 100644 --- a/tests/trainer/test_lr_finder.py +++ b/tests/trainer/test_lr_finder.py @@ -21,7 +21,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate from tests.helpers import BoringModel -from tests.helpers.datamodules import TrialMNISTDataModule +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.simple_models import ClassificationModel def test_error_on_more_than_1_optimizer(tmpdir): @@ -180,12 +181,10 @@ def test_datamodule_parameter(tmpdir): """ Test that the datamodule parameter works """ # trial datamodule - dm = TrialMNISTDataModule(tmpdir) + dm = ClassifDataModule() + model = ClassificationModel() - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) - - before_lr = hparams.get('learning_rate') + before_lr = model.lr # logger file to get meta trainer = Trainer( default_root_dir=tmpdir, @@ -194,7 +193,7 @@ def test_datamodule_parameter(tmpdir): lrfinder = trainer.tuner.lr_find(model, datamodule=dm) after_lr = lrfinder.suggestion() - model.learning_rate = after_lr + model.lr = after_lr assert before_lr != after_lr, \ 'Learning rate was not altered after running learning rate finder' @@ -271,8 +270,6 @@ def test_suggestion_with_non_finite_values(tmpdir): def test_lr_finder_fails_fast_on_bad_config(tmpdir): """ Test that tune fails if the model does not have a lr BEFORE running lr find """ - # note: this did not raise an exception before #5638 because lr_find is skipped - # during fast_dev_run and the lr attribute check was done after lr_find - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, auto_lr_find=True) + trainer = Trainer(default_root_dir=tmpdir, max_steps=2, auto_lr_find=True) with pytest.raises(MisconfigurationException, match='should have one of these fields'): trainer.tune(BoringModel()) From 90b591164d60ad75055ab3d114b6f4f38d846c4d Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 16 Feb 2021 02:46:14 +0530 Subject: [PATCH 2/3] chlog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 306500c3e6f42..b2bb84c792d83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -190,6 +190,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Refactored `EpochResultStore` ([#5522](https://github.com/PyTorchLightning/pytorch-lightning/pull/5522)) +- Update `lr_finder` to check for attribute if not running `fast_dev_run` ([#5990](https://github.com/PyTorchLightning/pytorch-lightning/pull/5990)) + + ### Deprecated - Function `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) From 1f293fa734344e23303beab031bec3959e419cb6 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Feb 2021 01:07:27 +0100 Subject: [PATCH 3/3] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/tuner/lr_finder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 8903721ff70f9..cf29799a05a5b 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -92,15 +92,15 @@ def lr_find( mode: Search strategy to update learning rate after each batch: - - ``'exponential'`` [default]: Will increase learning rate exponentially. - - ``'linear'``: Will increate learning rate linearly. + - ``'exponential'`` (default): Will increase the learning rate exponentially. + - ``'linear'``: Will increase the learning rate linearly. early_stop_threshold: threshold for stopping the search. If the loss at any point is larger than early_stop_threshold*best_loss then the search is stopped. To disable, set to None. datamodule: An optional ``LightningDataModule`` which holds the training - and validation dataloader(s). Note that the `train_dataloader` and + and validation dataloader(s). Note that the ``train_dataloader`` and ``val_dataloaders`` parameters cannot be used at the same time as this parameter, or a ``MisconfigurationException`` will be raised.
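
---

The series above removes the internal `_run_lr_finder_internally` helper and instead exposes an `update_attr` flag on `lr_find`, which `Tuner.tune()` now forwards as `update_attr=True` whenever `auto_lr_find` is enabled. The sketch below is not part of the patch; it is a minimal illustration of how the reworked API could be exercised directly, assuming a PyTorch Lightning build that includes these changes. `DemoModel`, the random dataset, and `num_training=20` are illustrative choices only; the module just needs a single optimizer and an `lr` (or `learning_rate`) attribute so that `_determine_lr_attr_name` can resolve it.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class DemoModel(LightningModule):
    """Minimal module with an ``lr`` attribute that the lr finder can update."""

    def __init__(self, lr=0.01):
        super().__init__()
        self.lr = lr  # attribute overwritten by lr_find(..., update_attr=True)
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)


# A small random classification dataset, just to give the range test something to step through.
dataset = TensorDataset(torch.randn(200, 32), torch.randint(0, 2, (200,)))
train_dl = DataLoader(dataset, batch_size=10)

model = DemoModel()
trainer = Trainer()

# Run the range test and write the suggestion back to ``model.lr`` in one call;
# this mirrors what Tuner.tune() now does when ``auto_lr_find=True``.
lr_finder = trainer.tuner.lr_find(model, train_dataloader=train_dl, num_training=20, update_attr=True)

print(lr_finder.suggestion())  # suggested initial learning rate
print(model.lr)                # already updated, because update_attr=True
```

Leaving `update_attr` at its default of `False` keeps the previous behaviour of a manual `trainer.tuner.lr_find(...)` call: the finder only returns its results, and applying `lr_finder.suggestion()` to the model remains the caller's responsibility.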