Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Bugfix/5487 auto lr ordering #5638

Merged
merged 35 commits into from
Feb 1, 2021
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
7952b79
started to write failing test. just getting into the framework...
noamzilo Jan 15, 2021
7e2fa62
started to write failing test. just getting into the framework...
noamzilo Jan 17, 2021
edc9ed9
added failing test for misconfiguration of lr finder
noamzilo Jan 23, 2021
d24b797
made test startup quickly. making sure without the fix it also fails …
noamzilo Jan 24, 2021
b51a754
improved test
noamzilo Jan 24, 2021
0ca79c8
Merge branch 'master' into bugfix/5487_auto_lr_ordering
noamzilo Jan 24, 2021
c34042c
fixed for linter
noamzilo Jan 24, 2021
0ccefb2
Merge branch 'bugfix/5487_auto_lr_ordering' of https://github.com/noa…
noamzilo Jan 24, 2021
e7a6c36
fixed for linter
noamzilo Jan 24, 2021
874178c
yet another fix for the linter
noamzilo Jan 24, 2021
404bc9c
yet another fix for the linter
noamzilo Jan 24, 2021
7f1bae7
Merge branch 'master' into bugfix/5487_auto_lr_ordering
noamzilo Jan 25, 2021
c6dedc8
fixed comment by @carmocca
noamzilo Jan 25, 2021
03756f2
fixed comment by @carmocca
noamzilo Jan 25, 2021
d9042e0
Merge branch 'master' into bugfix/5487_auto_lr_ordering
noamzilo Jan 27, 2021
8e6862d
Fix test
carmocca Jan 28, 2021
cc81a31
Merge branch 'master' into bugfix/5487_auto_lr_ordering
carmocca Jan 28, 2021
78500a4
chlog
Borda Jan 29, 2021
6e9382a
Apply suggestions from code review
Borda Jan 29, 2021
70eb21f
Merge branch 'master' into bugfix/5487_auto_lr_ordering
mergify[bot] Jan 29, 2021
3bc6ae1
Merge branch 'master' into bugfix/5487_auto_lr_ordering
mergify[bot] Jan 29, 2021
fa7255e
Fix test
carmocca Jan 29, 2021
ecbc80d
Merge branch 'master' into bugfix/5487_auto_lr_ordering
mergify[bot] Jan 29, 2021
bc6055c
Merge branch 'master' into bugfix/5487_auto_lr_ordering
mergify[bot] Jan 30, 2021
8d01dd3
Merge branch 'master' into bugfix/5487_auto_lr_ordering
mergify[bot] Jan 31, 2021
7ee4056
Merge branch 'master' into bugfix/5487_auto_lr_ordering
mergify[bot] Jan 31, 2021
15ec3cd
Update pytorch_lightning/tuner/lr_finder.py
carmocca Feb 1, 2021
1008706
Update pytorch_lightning/tuner/lr_finder.py
carmocca Feb 1, 2021
030d6e8
Update tests/trainer/test_lr_finder.py
carmocca Feb 1, 2021
70359a3
Update pytorch_lightning/tuner/lr_finder.py
carmocca Feb 1, 2021
99f09f7
Update pytorch_lightning/tuner/lr_finder.py
carmocca Feb 1, 2021
cff6bc6
Merge branch 'master' into bugfix/5487_auto_lr_ordering
mergify[bot] Feb 1, 2021
035b72f
Update pytorch_lightning/tuner/lr_finder.py
carmocca Feb 1, 2021
71b36d2
Update tests/trainer/test_lr_finder.py
carmocca Feb 1, 2021
042bfda
Merge branch 'master' into bugfix/5487_auto_lr_ordering
mergify[bot] Feb 1, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions pytorch_lightning/tuner/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import importlib
import os
from typing import List, Optional, Sequence, Union, Callable
from typing import List, Optional, Sequence, Union, Callable, Any
from functools import wraps

import numpy as np
Expand All @@ -40,8 +40,28 @@
from tqdm import tqdm


def __choose_lr_assigner(trainer, model: LightningModule) -> Callable[[Any], None]:
    """Return a setter that writes a learning-rate value onto ``model``.

    Validating the configuration here, *before* the (expensive) lr search
    runs, lets a misconfigured ``auto_lr_find`` fail fast instead of after
    the search has already completed.

    Args:
        trainer: Trainer whose ``auto_lr_find`` flag selects the target field
            (a string names an explicit, possibly nested, attribute; any other
            truthy value falls back to the conventional field names).
        model: LightningModule whose attribute (or ``model.hparams`` entry)
            will receive the learning rate.

    Returns:
        A one-argument callable that assigns its argument to the chosen field.

    Raises:
        MisconfigurationException: If the requested field (or, by default,
            ``lr``/``learning_rate``) is not found on ``model`` or
            ``model.hparams``.
    """
    if isinstance(trainer.auto_lr_find, str):
        # The user named an explicit field; verify it exists up front.
        if not lightning_hasattr(model, trainer.auto_lr_find):
            raise MisconfigurationException(
                f'`auto_lr_find` was set to {trainer.auto_lr_find}, however'
                ' could not find this as a field in `model` or `model.hparams`.')
        return lambda val: lightning_setattr(model, trainer.auto_lr_find, val)

    # Default behaviour: try the two conventional field names in order.
    if lightning_hasattr(model, 'lr'):
        return lambda val: lightning_setattr(model, 'lr', val)
    if lightning_hasattr(model, 'learning_rate'):
        return lambda val: lightning_setattr(model, 'learning_rate', val)
    raise MisconfigurationException(
        'When auto_lr_find is set to True, expects that `model` or'
        ' `model.hparams` either has field `lr` or `learning_rate`'
        ' that can be overridden')


def _run_lr_finder_internally(trainer, model: LightningModule):
""" Call lr finder internally during Trainer.fit() """
lr_assigner = __choose_lr_assigner(trainer, model)

lr_finder = lr_find(trainer, model)

if lr_finder is None:
Expand All @@ -50,24 +70,8 @@ def _run_lr_finder_internally(trainer, model: LightningModule):
lr = lr_finder.suggestion()

# TODO: log lr.results to self.logger
if isinstance(trainer.auto_lr_find, str):
# Try to find requested field, may be nested
if lightning_hasattr(model, trainer.auto_lr_find):
lightning_setattr(model, trainer.auto_lr_find, lr)
else:
raise MisconfigurationException(
f'`auto_lr_find` was set to {trainer.auto_lr_find}, however'
' could not find this as a field in `model` or `model.hparams`.')
else:
if lightning_hasattr(model, 'lr'):
lightning_setattr(model, 'lr', lr)
elif lightning_hasattr(model, 'learning_rate'):
lightning_setattr(model, 'learning_rate', lr)
else:
raise MisconfigurationException(
'When auto_lr_find is set to True, expects that `model` or'
' `model.hparams` either has field `lr` or `learning_rate`'
' that can overridden')
lr_assigner(lr)

log.info(f'Learning rate set to {lr}')


Expand Down
1 change: 0 additions & 1 deletion tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,7 +1148,6 @@ def test_replace_sampler_with_multiprocessing_context(tmpdir):
train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True)

class ExtendedBoringModel(BoringModel):

def train_dataloader(self):
return train

Expand Down
30 changes: 30 additions & 0 deletions tests/trainer/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning import seed_everything
from tests.base import EvalModelTemplate
from tests.base.datamodules import TrialMNISTDataModule
from tests.base import BoringModel, RandomDataset
from torch.utils.data.sampler import SequentialSampler
from torch.utils.data.dataloader import DataLoader


def test_error_on_more_than_1_optimizer(tmpdir):
Expand Down Expand Up @@ -262,3 +266,29 @@ def test_suggestion_with_non_finite_values(tmpdir):

assert before_lr == after_lr, \
'Learning rate was altered because of non-finite loss values'


@pytest.mark.timeout(1)
def test_lr_finder_fails_fast_on_bad_config(tmpdir):
    """Test that misconfiguration of learning_rate or lr in model fails BEFORE lr optimization and not after it.

    ``BoringModel`` has neither an ``lr`` nor a ``learning_rate`` field, so
    ``auto_lr_find=True`` must raise immediately.  The 1-second timeout is the
    actual assertion of "fails fast": running the lr search first would blow
    past it.
    """
    train = RandomDataset(32, 64)
    # 'spawn' workers are slow to start; they only matter if the search runs,
    # which this test asserts never happens.
    train = DataLoader(train, batch_size=32, num_workers=1, multiprocessing_context='spawn', shuffle=True)

    class ExtendedBoringModel(BoringModel):
        def train_dataloader(self):
            return train

    model = ExtendedBoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        auto_lr_find=True,
        deterministic=True,
        overfit_batches=1,
    )
    train_data_loader = trainer.replace_sampler(train, SequentialSampler(train.dataset))

    with pytest.raises(MisconfigurationException):
        trainer.tune(model, train_dataloader=train_data_loader)