Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/5487 auto lr ordering #5638

Merged
merged 35 commits into from
Feb 1, 2021
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
7952b79
started to write failing test. just getting into the framework...
noamzilo Jan 15, 2021
7e2fa62
started to write failing test. just getting into the framework...
noamzilo Jan 17, 2021
edc9ed9
added failing test for misconfiguration of lr finder
noamzilo Jan 23, 2021
d24b797
made test startup quickly. making sure without the fix it also fails …
noamzilo Jan 24, 2021
b51a754
improved test
noamzilo Jan 24, 2021
0ca79c8
Merge branch 'master' into bugfix/5487_auto_lr_ordering
noamzilo Jan 24, 2021
c34042c
fixed for linter
noamzilo Jan 24, 2021
0ccefb2
Merge branch 'bugfix/5487_auto_lr_ordering' of https://github.com/noa…
noamzilo Jan 24, 2021
e7a6c36
fixed for linter
noamzilo Jan 24, 2021
874178c
yet another fix for the linter
noamzilo Jan 24, 2021
404bc9c
yet another fix for the linter
noamzilo Jan 24, 2021
7f1bae7
Merge branch 'master' into bugfix/5487_auto_lr_ordering
noamzilo Jan 25, 2021
c6dedc8
fixed comment by @carmocca
noamzilo Jan 25, 2021
03756f2
fixed comment by @carmocca
noamzilo Jan 25, 2021
d9042e0
Merge branch 'master' into bugfix/5487_auto_lr_ordering
noamzilo Jan 27, 2021
8e6862d
Fix test
carmocca Jan 28, 2021
cc81a31
Merge branch 'master' into bugfix/5487_auto_lr_ordering
carmocca Jan 28, 2021
78500a4
chlog
Borda Jan 29, 2021
6e9382a
Apply suggestions from code review
Borda Jan 29, 2021
70eb21f
Merge branch 'master' into bugfix/5487_auto_lr_ordering
mergify[bot] Jan 29, 2021
3bc6ae1
Merge branch 'master' into bugfix/5487_auto_lr_ordering
mergify[bot] Jan 29, 2021
fa7255e
Fix test
carmocca Jan 29, 2021
ecbc80d
Merge branch 'master' into bugfix/5487_auto_lr_ordering
mergify[bot] Jan 29, 2021
bc6055c
Merge branch 'master' into bugfix/5487_auto_lr_ordering
mergify[bot] Jan 30, 2021
8d01dd3
Merge branch 'master' into bugfix/5487_auto_lr_ordering
mergify[bot] Jan 31, 2021
7ee4056
Merge branch 'master' into bugfix/5487_auto_lr_ordering
mergify[bot] Jan 31, 2021
15ec3cd
Update pytorch_lightning/tuner/lr_finder.py
carmocca Feb 1, 2021
1008706
Update pytorch_lightning/tuner/lr_finder.py
carmocca Feb 1, 2021
030d6e8
Update tests/trainer/test_lr_finder.py
carmocca Feb 1, 2021
70359a3
Update pytorch_lightning/tuner/lr_finder.py
carmocca Feb 1, 2021
99f09f7
Update pytorch_lightning/tuner/lr_finder.py
carmocca Feb 1, 2021
cff6bc6
Merge branch 'master' into bugfix/5487_auto_lr_ordering
mergify[bot] Feb 1, 2021
035b72f
Update pytorch_lightning/tuner/lr_finder.py
carmocca Feb 1, 2021
71b36d2
Update tests/trainer/test_lr_finder.py
carmocca Feb 1, 2021
042bfda
Merge branch 'master' into bugfix/5487_auto_lr_ordering
mergify[bot] Feb 1, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Remove unnecessary intermediate layers in Dockerfiles ([#5697](https://github.com/PyTorchLightning/pytorch-lightning/pull/5697))


- Fixed auto learning rate ordering ([#5638](https://github.com/PyTorchLightning/pytorch-lightning/pull/5638))


## [1.1.6] - 2021-01-26

### Changed
Expand Down
43 changes: 23 additions & 20 deletions pytorch_lightning/tuner/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import importlib
import logging
import os
from typing import List, Optional, Sequence, Union, Callable
from typing import List, Optional, Sequence, Union, Callable, Any
from functools import wraps

import numpy as np
Expand Down Expand Up @@ -42,34 +42,37 @@
log = logging.getLogger(__name__)


def __determine_lr_attr_name(trainer, model: LightningModule) -> str:
    """Resolve which attribute on ``model`` (or ``model.hparams``) holds the learning rate.

    If ``trainer.auto_lr_find`` is a string, that exact field must exist on the
    model; otherwise the conventional names ``'lr'`` and ``'learning_rate'`` are
    tried in order.

    Args:
        trainer: the Trainer whose ``auto_lr_find`` flag configures the lookup.
        model: the model whose attributes (possibly nested under ``hparams``)
            are searched via ``lightning_hasattr``.

    Returns:
        The name of the attribute that should receive the found learning rate.

    Raises:
        MisconfigurationException: if no suitable attribute can be found.
    """
    if isinstance(trainer.auto_lr_find, str):
        # The user explicitly named the field that should receive the lr.
        if not lightning_hasattr(model, trainer.auto_lr_find):
            raise MisconfigurationException(
                f'`auto_lr_find` was set to {trainer.auto_lr_find}, however'
                ' could not find this as a field in `model` or `model.hparams`.')
        return trainer.auto_lr_find

    # auto_lr_find is simply True: fall back to the conventional names.
    attr_options = ('lr', 'learning_rate')
    for attr in attr_options:
        if lightning_hasattr(model, attr):
            return attr

    raise MisconfigurationException(
        'When auto_lr_find is set to True, expects that `model` or'
        f' `model.hparams` either has one of these fields {attr_options}'
        ' that can be overridden')


def _run_lr_finder_internally(trainer, model: LightningModule):
    """Call the lr finder internally during ``Trainer.fit()``.

    Runs ``lr_find`` and writes the suggested learning rate back onto the
    model attribute resolved by ``__determine_lr_attr_name``.
    """
    # Resolve (and validate) the lr attribute BEFORE running lr_find, so a
    # misconfigured model fails fast even when lr_find itself is skipped
    # (e.g. under fast_dev_run).
    lr_attr_name = __determine_lr_attr_name(trainer, model)
    lr_finder = lr_find(trainer, model)

    # lr_find returns None when the search was skipped; nothing to set.
    if lr_finder is None:
        return

    lr = lr_finder.suggestion()

    # TODO: log lr.results to self.logger
    lightning_setattr(model, lr_attr_name, lr)

    log.info(f'Learning rate set to {lr}')


Expand Down
1 change: 1 addition & 0 deletions tests/trainer/flags/test_fast_dev_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg):
""" Test that tuner algorithms are skipped if fast dev run is enabled """

model = BoringModel()
model.lr = 0.1 # avoid no-lr-found exception
carmocca marked this conversation as resolved.
Show resolved Hide resolved
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
Expand Down
1 change: 0 additions & 1 deletion tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,7 +1148,6 @@ def test_replace_sampler_with_multiprocessing_context(tmpdir):
train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True)

class ExtendedBoringModel(BoringModel):

def train_dataloader(self):
return train

Expand Down
11 changes: 11 additions & 0 deletions tests/trainer/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.
import os
from copy import deepcopy

import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel
from tests.base import EvalModelTemplate
from tests.base.datamodules import TrialMNISTDataModule

Expand Down Expand Up @@ -262,3 +264,12 @@ def test_suggestion_with_non_finite_values(tmpdir):

assert before_lr == after_lr, \
'Learning rate was altered because of non-finite loss values'


def test_lr_finder_fails_fast_on_bad_config(tmpdir):
    """ Test that tune fails if the model does not have a lr BEFORE running lr find """
    # note: this did not raise an exception before #5638 because lr_find is skipped
    # during fast_dev_run and the lr attribute check was done after lr_find
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, auto_lr_find=True)
    # BoringModel has neither `lr` nor `learning_rate`, so attribute resolution
    # must raise before the (skipped) lr search ever runs.
    with pytest.raises(MisconfigurationException, match='either has one of these fields'):
        trainer.tune(BoringModel())