Bugfix/5487 auto lr ordering (#5638)
* started to write failing test. just getting into the framework...

* started to write failing test. just getting into the framework...

* added failing test for misconfiguration of lr finder

* made test startup quickly. making sure without the fix it also fails slowly

* improved test

* fixed for linter

* fixed for linter

* yet another fix for the linter

* yet another fix for the linter

* fixed comment by @carmocca

* fixed comment by @carmocca

* Fix test

* chlog

* Apply suggestions from code review

* Fix test

* Update pytorch_lightning/tuner/lr_finder.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update pytorch_lightning/tuner/lr_finder.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update tests/trainer/test_lr_finder.py

* Update pytorch_lightning/tuner/lr_finder.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update pytorch_lightning/tuner/lr_finder.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update pytorch_lightning/tuner/lr_finder.py

* Update tests/trainer/test_lr_finder.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
6 people authored and tchaton committed Feb 5, 2021
1 parent 2780d59 commit 966da1e
Showing 5 changed files with 38 additions and 22 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -202,6 +202,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Remove unnecessary intermediate layers in Dockerfiles ([#5697](https://github.com/PyTorchLightning/pytorch-lightning/pull/5697))
 
 
+- Fixed auto learning rate ordering ([#5638](https://github.com/PyTorchLightning/pytorch-lightning/pull/5638))
+
+
 ## [1.1.6] - 2021-01-26
 
 ### Changed
44 changes: 24 additions & 20 deletions pytorch_lightning/tuner/lr_finder.py
@@ -14,7 +14,7 @@
 import importlib
 import os
 from functools import wraps
-from typing import Callable, List, Optional, Sequence, Union
+from typing import Any, Callable, List, Optional, Sequence, Union
 
 import numpy as np
 import torch
@@ -40,34 +40,38 @@
 from tqdm import tqdm
 
 
+def _determine_lr_attr_name(trainer, model: LightningModule) -> str:
+    if isinstance(trainer.auto_lr_find, str):
+        if not lightning_hasattr(model, trainer.auto_lr_find):
+            raise MisconfigurationException(
+                f'`auto_lr_find` was set to {trainer.auto_lr_find}, however'
+                ' could not find this as a field in `model` or `model.hparams`.'
+            )
+        return trainer.auto_lr_find
+
+    attr_options = ('lr', 'learning_rate')
+    for attr in attr_options:
+        if lightning_hasattr(model, attr):
+            return attr
+
+    raise MisconfigurationException(
+        'When `auto_lr_find=True`, either `model` or `model.hparams` should'
+        f' have one of these fields: {attr_options} overridden.'
+    )
+
+
 def _run_lr_finder_internally(trainer, model: LightningModule):
     """ Call lr finder internally during Trainer.fit() """
+    lr_attr_name = _determine_lr_attr_name(trainer, model)
     lr_finder = lr_find(trainer, model)
 
     if lr_finder is None:
         return
 
     lr = lr_finder.suggestion()
 
     # TODO: log lr.results to self.logger
-    if isinstance(trainer.auto_lr_find, str):
-        # Try to find requested field, may be nested
-        if lightning_hasattr(model, trainer.auto_lr_find):
-            lightning_setattr(model, trainer.auto_lr_find, lr)
-        else:
-            raise MisconfigurationException(
-                f'`auto_lr_find` was set to {trainer.auto_lr_find}, however'
-                ' could not find this as a field in `model` or `model.hparams`.')
-    else:
-        if lightning_hasattr(model, 'lr'):
-            lightning_setattr(model, 'lr', lr)
-        elif lightning_hasattr(model, 'learning_rate'):
-            lightning_setattr(model, 'learning_rate', lr)
-        else:
-            raise MisconfigurationException(
-                'When auto_lr_find is set to True, expects that `model` or'
-                ' `model.hparams` either has field `lr` or `learning_rate`'
-                ' that can overridden')
+    lightning_setattr(model, lr_attr_name, lr)
 
     log.info(f'Learning rate set to {lr}')
 
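The refactor above pulls the attribute lookup out of `_run_lr_finder_internally` into a standalone `_determine_lr_attr_name` helper and runs it before `lr_find`, so a misconfigured model fails before any learning-rate sweep or training starts. Below is a minimal, framework-free sketch of that lookup order; `has_attr` is an illustrative stand-in for Lightning's `lightning_hasattr`, which also looks inside `model.hparams`.

```python
# Illustrative sketch only: mirrors the lookup order of the new
# _determine_lr_attr_name helper without depending on PyTorch Lightning.
def resolve_lr_attr(auto_lr_find, model):
    def has_attr(obj, name):
        # Stand-in for lightning_hasattr: plain attribute or hparams entry.
        return hasattr(obj, name) or name in getattr(obj, "hparams", {})

    if isinstance(auto_lr_find, str):
        # The user named the field explicitly, so it must exist.
        if not has_attr(model, auto_lr_find):
            raise ValueError(
                f"`auto_lr_find` was set to {auto_lr_find!r}, but no such field "
                "exists on `model` or `model.hparams`"
            )
        return auto_lr_find

    # auto_lr_find=True: fall back to the conventional names, in order.
    for attr in ("lr", "learning_rate"):
        if has_attr(model, attr):
            return attr
    raise ValueError("`model` or `model.hparams` must define `lr` or `learning_rate`")
```

Because this check is now the first thing the internal LR finder does, it fires even when the sweep itself is skipped (for example under `fast_dev_run`), which is exactly what the new test further down exercises.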
1 change: 1 addition & 0 deletions tests/trainer/flags/test_fast_dev_run.py
@@ -16,6 +16,7 @@ def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg):
     """ Test that tuner algorithms are skipped if fast dev run is enabled """
 
     model = BoringModel()
+    model.lr = 0.1  # avoid no-lr-found exception
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=2,
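The one-line change above is needed because `BoringModel` defines neither `lr` nor `learning_rate`, and after the reordering the attribute check fires even though `fast_dev_run` skips the actual search. For reference, a hedged sketch of the two usual ways a user model can expose such a field to the tuner (class and field names here are illustrative):

```python
import torch
import pytorch_lightning as pl


class PlainAttrModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        self.lr = 1e-3  # plain attribute, discoverable as 'lr'

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)


class HparamsModel(pl.LightningModule):
    def __init__(self, learning_rate: float = 1e-3):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        self.save_hyperparameters()  # exposes self.hparams.learning_rate

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.hparams.learning_rate)
```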
1 change: 0 additions & 1 deletion tests/trainer/test_dataloaders.py
@@ -1232,7 +1232,6 @@ def test_replace_sampler_with_multiprocessing_context(tmpdir):
     train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True)
 
     class ExtendedBoringModel(BoringModel):
-
        def train_dataloader(self):
            return train
 
11 changes: 10 additions & 1 deletion tests/trainer/test_lr_finder.py
@@ -19,7 +19,7 @@
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import EvalModelTemplate
+from tests.base import BoringModel, EvalModelTemplate
 from tests.base.datamodules import TrialMNISTDataModule
 
 
@@ -265,3 +265,12 @@ def test_suggestion_with_non_finite_values(tmpdir):
 
     assert before_lr == after_lr, \
         'Learning rate was altered because of non-finite loss values'
+
+
+def test_lr_finder_fails_fast_on_bad_config(tmpdir):
+    """ Test that tune fails if the model does not have a lr BEFORE running lr find """
+    # note: this did not raise an exception before #5638 because lr_find is skipped
+    # during fast_dev_run and the lr attribute check was done after lr_find
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, auto_lr_find=True)
+    with pytest.raises(MisconfigurationException, match='should have one of these fields'):
+        trainer.tune(BoringModel())
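To see the new behaviour outside the Lightning test suite, the sketch below mirrors the new test with a throwaway module in place of the internal `BoringModel` helper. It is a hedged example assuming a PyTorch Lightning release from around this change (~1.2): a model that exposes no learning-rate field now fails at `tune()` even under `fast_dev_run`, while a model exposing `lr` passes the check (the sweep itself is still skipped).

```python
import pytest
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class TinyModel(LightningModule):
    """Throwaway module for illustration; `expose_lr` toggles the tuner field."""

    def __init__(self, expose_lr: bool = False):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        if expose_lr:
            self.lr = 0.1  # attribute the LR finder can discover and overwrite

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=getattr(self, "lr", 0.1))

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=4)


def test_misconfigured_model_fails_before_training(tmpdir):
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, auto_lr_find=True)
    with pytest.raises(MisconfigurationException, match='should have one of these fields'):
        trainer.tune(TinyModel(expose_lr=False))


def test_model_with_lr_attribute_passes_the_check(tmpdir):
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, auto_lr_find=True)
    trainer.tune(TinyModel(expose_lr=True))  # LR sweep is skipped under fast_dev_run
```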
