update lr_finder to check for attribute if not running fast_dev_run #5990

Merged · 7 commits · Feb 17, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -190,6 +190,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Refactored `EpochResultStore` ([#5522](https://github.com/PyTorchLightning/pytorch-lightning/pull/5522))


- Update `lr_finder` to check for attribute if not running `fast_dev_run` ([#5990](https://github.com/PyTorchLightning/pytorch-lightning/pull/5990))


- LightningOptimizer manual optimizer is more flexible and expose `toggle_model` ([#5771](https://github.com/PyTorchLightning/pytorch-lightning/pull/5771))


8 changes: 4 additions & 4 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -63,10 +63,10 @@ def scale_batch_size(
It is expected that the user has provided a model or datamodule that has a hyperparameter
with that name. We will look for this attribute name in the following places

- `model`
- `model.hparams`
- `model.datamodule`
- `trainer.datamodule` (the datamodule passed to the tune method)
- ``model``
- ``model.hparams``
- ``model.datamodule``
- ``trainer.datamodule`` (the datamodule passed to the tune method)

**fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
or datamodule.
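To make the lookup order in this docstring concrete, here is a rough sketch using the `lightning_hasattr` / `lightning_setattr` helpers that the tuner relies on elsewhere in this PR. The function name and error message below are illustrative only and are not part of the change.

```python
# Illustrative sketch, not code from this PR: resolve a tunable attribute in
# the places listed above and write the tuned value back onto it.
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr


def set_tuned_value(model, attr_name: str, value) -> None:
    # lightning_hasattr searches the model, model.hparams and the attached
    # datamodule, mirroring the lookup order documented above
    if not lightning_hasattr(model, attr_name):
        raise MisconfigurationException(
            f'Field {attr_name} not found in `model`, `model.hparams` or the datamodule'
        )
    lightning_setattr(model, attr_name, value)
```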
50 changes: 25 additions & 25 deletions pytorch_lightning/tuner/lr_finder.py
@@ -60,21 +60,6 @@ def _determine_lr_attr_name(trainer, model: LightningModule) -> str:
)


def _run_lr_finder_internally(trainer, model: LightningModule):
""" Call lr finder internally during Trainer.fit() """
lr_attr_name = _determine_lr_attr_name(trainer, model)
lr_finder = lr_find(trainer, model)
if lr_finder is None:
return

lr = lr_finder.suggestion()

# TODO: log lr.results to self.logger
lightning_setattr(model, lr_attr_name, lr)

log.info(f'Learning rate set to {lr}')


def lr_find(
trainer,
model: LightningModule,
@@ -86,16 +71,17 @@ def lr_find(
mode: str = 'exponential',
early_stop_threshold: float = 4.0,
datamodule: Optional[LightningDataModule] = None,
update_attr: bool = False,
):
r"""
`lr_find` enables the user to do a range test of good initial learning rates,
``lr_find`` enables the user to do a range test of good initial learning rates,
to reduce the amount of guesswork in picking a good starting learning rate.

Args:
model: Model to do range testing for

train_dataloader: A PyTorch
`DataLoader` with training samples. If the model has
``DataLoader`` with training samples. If the model has
a predefined train_dataloader method, this will be skipped.

min_lr: minimum learning rate to investigate
@@ -104,19 +90,21 @@

num_training: number of learning rates to test

mode: search strategy, either 'linear' or 'exponential'. If set to
'linear' the learning rate will be searched by linearly increasing
after each batch. If set to 'exponential', will increase learning
rate exponentially.
mode: Search strategy to update learning rate after each batch:

- ``'exponential'`` (default): Will increase the learning rate exponentially.
- ``'linear'``: Will increase the learning rate linearly.

early_stop_threshold: threshold for stopping the search. If the
loss at any point is larger than early_stop_threshold*best_loss
then the search is stopped. To disable, set to None.

datamodule: An optional `LightningDataModule` which holds the training
and validation dataloader(s). Note that the `train_dataloader` and
`val_dataloaders` parameters cannot be used at the same time as
this parameter, or a `MisconfigurationException` will be raised.
datamodule: An optional ``LightningDataModule`` which holds the training
and validation dataloader(s). Note that the ``train_dataloader`` and
``val_dataloaders`` parameters cannot be used at the same time as
this parameter, or a ``MisconfigurationException`` will be raised.

update_attr: Whether to update the learning rate attribute or not.


Example::
@@ -144,6 +132,10 @@ def lr_find(
rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning)
return

# Determine lr attr
if update_attr:
lr_attr_name = _determine_lr_attr_name(trainer, model)

save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')

__lr_finder_dump_params(trainer, model)
@@ -200,6 +192,14 @@ def lr_find(
if trainer.progress_bar_callback:
trainer.progress_bar_callback.enable()

# Update lr attr if required
if update_attr:
lr = lr_finder.suggestion()

# TODO: log lr.results to self.logger
lightning_setattr(model, lr_attr_name, lr)
log.info(f'Learning rate set to {lr}')

return lr_finder
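For a concrete picture of the new `update_attr` flag, here is a usage sketch built from the test helpers touched in this PR (`ClassificationModel` exposes `self.lr`, `ClassifDataModule` provides the loaders); the trainer arguments are placeholders.

```python
# Sketch of calling lr_find directly with the new update_attr flag.
from pytorch_lightning import Trainer
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.simple_models import ClassificationModel

model = ClassificationModel()   # defines `self.lr`, so the attribute check passes
dm = ClassifDataModule()
trainer = Trainer(max_steps=2)

lr_finder = trainer.tuner.lr_find(model, datamodule=dm, update_attr=True)

# With update_attr=True the suggestion has already been written back to
# model.lr via lightning_setattr; with the default update_attr=False the
# caller assigns it manually, as the datamodule test further below does.
print(model.lr, lr_finder.suggestion())
```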


19 changes: 9 additions & 10 deletions pytorch_lightning/tuner/tuning.py
@@ -20,7 +20,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size
from pytorch_lightning.tuner.lr_finder import _run_lr_finder_internally, lr_find
from pytorch_lightning.tuner.lr_finder import lr_find


class Tuner:
@@ -53,7 +53,7 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule):

# Run learning rate finder:
if self.trainer.auto_lr_find:
self.internal_find_lr(model)
self.lr_find(model, update_attr=True)

def scale_batch_size(
self,
@@ -92,10 +92,10 @@ def scale_batch_size(
It is expected that the user has provided a model or datamodule that has a hyperparameter
with that name. We will look for this attribute name in the following places

- `model`
- `model.hparams`
- `model.datamodule`
- `trainer.datamodule` (the datamodule passed to the tune method)
- ``model``
- ``model.hparams``
- ``model.datamodule``
- ``trainer.datamodule`` (the datamodule passed to the tune method)

**fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
or datamodule.
Expand All @@ -122,7 +122,8 @@ def lr_find(
num_training: int = 100,
mode: str = 'exponential',
early_stop_threshold: float = 4.0,
datamodule: Optional[LightningDataModule] = None
datamodule: Optional[LightningDataModule] = None,
update_attr: bool = False,
):
return lr_find(
self.trainer,
@@ -135,10 +136,8 @@
mode,
early_stop_threshold,
datamodule,
update_attr,
)

def internal_find_lr(self, model: LightningModule):
return _run_lr_finder_internally(self.trainer, model)

def pick_multiple_gpus(self, num_gpus: int):
return pick_multiple_gpus(num_gpus)
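As a rough end-to-end picture of the `auto_lr_find` path after this change: `Trainer.tune()` reaches `Tuner.tune()`, which now calls the public `self.lr_find(model, update_attr=True)` instead of the removed `_run_lr_finder_internally`. The data setup below is an assumption for illustration, not part of the diff.

```python
# Sketch of the auto_lr_find flow after this PR.
from pytorch_lightning import Trainer
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.simple_models import ClassificationModel

model = ClassificationModel()
trainer = Trainer(max_steps=2, auto_lr_find=True)

# tune() runs the finder and, because update_attr=True, writes the
# suggested learning rate straight onto model.lr.
trainer.tune(model, datamodule=ClassifDataModule())
print(model.lr)
```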
3 changes: 2 additions & 1 deletion tests/helpers/simple_models.py
@@ -22,8 +22,9 @@
class ClassificationModel(LightningModule):

def __init__(self, lr=0.01):
self.lr = lr
super().__init__()

self.lr = lr
for i in range(3):
setattr(self, f"layer_{i}", nn.Linear(32, 32))
setattr(self, f"layer_{i}a", torch.nn.ReLU())
17 changes: 7 additions & 10 deletions tests/trainer/test_lr_finder.py
@@ -21,7 +21,8 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.helpers import BoringModel
from tests.helpers.datamodules import TrialMNISTDataModule
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.simple_models import ClassificationModel


def test_error_on_more_than_1_optimizer(tmpdir):
@@ -180,12 +181,10 @@ def test_datamodule_parameter(tmpdir):
""" Test that the datamodule parameter works """

# trial datamodule
dm = TrialMNISTDataModule(tmpdir)
dm = ClassifDataModule()
model = ClassificationModel()

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)

before_lr = hparams.get('learning_rate')
before_lr = model.lr
# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir,
@@ -194,7 +193,7 @@

lrfinder = trainer.tuner.lr_find(model, datamodule=dm)
after_lr = lrfinder.suggestion()
model.learning_rate = after_lr
model.lr = after_lr

assert before_lr != after_lr, \
'Learning rate was not altered after running learning rate finder'
@@ -271,8 +270,6 @@ def test_suggestion_with_non_finite_values(tmpdir):

def test_lr_finder_fails_fast_on_bad_config(tmpdir):
""" Test that tune fails if the model does not have a lr BEFORE running lr find """
# note: this did not raise an exception before #5638 because lr_find is skipped
# during fast_dev_run and the lr attribute check was done after lr_find
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, auto_lr_find=True)
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, auto_lr_find=True)
with pytest.raises(MisconfigurationException, match='should have one of these fields'):
trainer.tune(BoringModel())
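For completeness, a sketch of what the fails-fast check is looking for: a learning-rate field the tuner can resolve before running the finder. The accepted field names (e.g. `lr` or `learning_rate`) follow from `_determine_lr_attr_name` and are stated here as an assumption; `training_step` and dataloaders are omitted for brevity.

```python
# Hypothetical minimal module that passes the attribute check exercised above;
# BoringModel fails it because it defines no lr / learning_rate field.
import torch
from pytorch_lightning import LightningModule


class TunableModel(LightningModule):

    def __init__(self, lr: float = 0.01):
        super().__init__()   # called before attribute assignment, matching the helper fix above
        self.lr = lr         # the field auto_lr_find resolves and overwrites
        self.layer = torch.nn.Linear(32, 2)

    def configure_optimizers(self):
        # the tuned value takes effect because the optimizer reads self.lr
        return torch.optim.SGD(self.parameters(), lr=self.lr)

# training_step and dataloaders are omitted; trainer.tune(TunableModel(), ...)
# would pass the check that makes trainer.tune(BoringModel()) raise above.
```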