Bugfix: Scheduler monitor for manual optimization #7643

Merged
merged 33 commits into master from bugfix/scheduler_manual_optimization on Jul 27, 2021
Changes from 9 commits
Commits (33)
41da9d2
Bugfix for manual optimization and scheduler without monitor.
maxoppelt May 21, 2021
fa8a278
Added minimal example using a learning rate scheduler that requires a…
maxoppelt May 21, 2021
6f39b10
Another exception, if scheduler is not passed as dict.
maxoppelt May 21, 2021
f34edd0
Fixed testcode type.
maxoppelt May 21, 2021
6fc1c6e
Fixed description.
maxoppelt May 21, 2021
3c7bc39
Fixed package pytorch name.
maxoppelt May 21, 2021
03a2904
Fixed: Code Style Guide Pytorch Lightning
maxoppelt May 21, 2021
6a7e300
Combine logic for manual optimization into a seperate case.
maxoppelt May 21, 2021
0720cf6
Merged automatic commit to seperate case.
maxoppelt May 21, 2021
5edaa69
Update docs/source/common/optimizers.rst
maxoppelt May 24, 2021
5187ac1
Bugfix for manual optimization and scheduler without monitor.
maxoppelt May 21, 2021
147d0e9
Added minimal example using a learning rate scheduler that requires a…
maxoppelt May 21, 2021
cf4876f
Another exception, if scheduler is not passed as dict.
maxoppelt May 21, 2021
2ea560f
Fixed testcode type.
maxoppelt May 21, 2021
495a388
Fixed description.
maxoppelt May 21, 2021
176b5bb
Fixed package pytorch name.
maxoppelt May 21, 2021
6a0ba3f
Combine logic for manual optimization into a seperate case.
maxoppelt May 21, 2021
1a681ca
Fixed: Code Style Guide Pytorch Lightning
maxoppelt May 21, 2021
8a1d3f3
Added test for learning rate scheduling when automatic differentiatio…
maxoppelt Jun 25, 2021
e9002a8
Merged remote changes for schedulers when disabeling automatic differ…
maxoppelt Jun 25, 2021
f89c38c
Merge branch 'master' into bugfix/scheduler_manual_optimization
maxoppelt Jul 1, 2021
ae47f17
Merge branch 'master' into bugfix/scheduler_manual_optimization
maxoppelt Jul 5, 2021
547905e
Merge branch 'master' into bugfix/scheduler_manual_optimization
kaushikb11 Jul 27, 2021
f487591
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 27, 2021
bcabbcd
Update test
kaushikb11 Jul 27, 2021
60d8db9
Merge branch 'bugfix/scheduler_manual_optimization' of https://github…
kaushikb11 Jul 27, 2021
3bbedc7
Update docs
kaushikb11 Jul 27, 2021
30f98f4
Update tests/trainer/optimization/test_manual_optimization.py
kaushikb11 Jul 27, 2021
667d92f
Update docs
kaushikb11 Jul 27, 2021
aa11ea5
Merge branch 'bugfix/scheduler_manual_optimization' of https://github…
kaushikb11 Jul 27, 2021
6d4bfeb
Apply suggestions from code review
Borda Jul 27, 2021
42e2c28
Update test
kaushikb11 Jul 27, 2021
522408f
Merge branch 'bugfix/scheduler_manual_optimization' of https://github…
kaushikb11 Jul 27, 2021
17 changes: 17 additions & 0 deletions docs/source/common/optimizers.rst
@@ -230,6 +230,23 @@ If you want to call ``lr_scheduler.step()`` every ``n`` steps/epochs, do the following:
         if self.trainer.is_last_batch and (self.trainer.current_epoch + 1) % n == 0:
             sch.step()
 
+If you want to call schedulers that require a metric value after each epoch, consider doing the following:
+
+.. testcode::
+
+    def __init__(self):
+        super().__init__()
+        self.automatic_optimization = False
+
+    def training_epoch_end(self, outputs):
+        sch = self.lr_schedulers()
+
+        # If the selected scheduler is a ReduceLROnPlateau scheduler.
+        if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
+            loss = torch.stack([x['loss'] for x in outputs]).mean()
+            loss = self.all_gather(loss).mean()
+            sch.step(loss)
+
 -----
 
 Improve training speed with model toggling
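The docs snippet added above shows only the two methods the change touches. For context, here is a minimal, self-contained sketch of how the pattern could sit inside a full LightningModule under manual optimization; the class name, the dummy linear layer, and the SGD hyperparameters are illustrative assumptions, not part of the PR:

import torch
import pytorch_lightning as pl


class ManualPlateauModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # disable automatic optimization, as in the docs example above
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 2)  # placeholder model

    def training_step(self, batch, batch_idx):
        # manual optimization: zero, backward, and step the optimizer ourselves
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch).pow(2).mean()
        self.manual_backward(loss)
        opt.step()
        return {"loss": loss.detach()}

    def training_epoch_end(self, outputs):
        sch = self.lr_schedulers()
        if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # ReduceLROnPlateau needs a metric, so pass it the epoch-mean loss
            loss = torch.stack([x["loss"] for x in outputs]).mean()
            sch.step(loss)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return [optimizer], [scheduler]

Note that the bare ReduceLROnPlateau is returned without a monitor; with this PR that no longer raises under manual optimization, because the metric is passed directly to sch.step().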
29 changes: 17 additions & 12 deletions pytorch_lightning/trainer/optimizers.py
@@ -122,7 +122,23 @@ def configure_schedulers(
         lr_schedulers = []
         default_config = _get_default_scheduler_config()
         for scheduler in schedulers:
-            if isinstance(scheduler, dict):
+            if is_manual_optimization:
+                if isinstance(scheduler, dict):
+                    invalid_keys = {'interval', 'frequency', 'reduce_on_plateau', 'monitor', 'strict'}
+                    keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]
+
+                    if keys_to_warn:
+                        rank_zero_warn(
+                            f'The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored.'
+                            ' You need to call `lr_scheduler.step()` manually in manual optimization.',
+                            RuntimeWarning,
+                        )
+
+                    scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys}
+                    lr_schedulers.append({**default_config, **scheduler})
+                else:
+                    lr_schedulers.append({**default_config, 'scheduler': scheduler})
+            elif isinstance(scheduler, dict):
                 # check provided keys
                 extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
                 if extra_keys:
@@ -136,17 +152,6 @@
                         f'The "interval" key in lr scheduler dict must be "step" or "epoch"'
                         f' but is "{scheduler["interval"]}"'
                     )
-                if is_manual_optimization:
-                    invalid_keys = {'interval', 'frequency', 'reduce_on_plateau', 'monitor', 'strict'}
-                    keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]
-
-                    if keys_to_warn:
-                        rank_zero_warn(
-                            f'The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored.'
-                            ' You need to call `lr_scheduler.step()` manually in manual optimization.',
-                            RuntimeWarning,
-                        )
-
                 scheduler['reduce_on_plateau'] = isinstance(
                     scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
                 )
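To summarize the behavior these hunks implement: when automatic optimization is disabled, the scheduler-dict keys 'interval', 'frequency', 'reduce_on_plateau', 'monitor', and 'strict' are stripped with a RuntimeWarning, and bare schedulers are wrapped in the default config, because the user is expected to call lr_scheduler.step() manually. A rough sketch of a configure_optimizers whose 'monitor' key would now merely trigger that warning; the Adam hyperparameters and the 'val_loss' monitor are illustrative, not from the PR:

    def configure_optimizers(self):
        # assumes the LightningModule sets self.automatic_optimization = False
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        # With the code above, the 'monitor' key is ignored with a RuntimeWarning
        # under manual optimization; the scheduler has to be stepped explicitly,
        # e.g. in training_epoch_end as in the docs example.
        return {
            'optimizer': optimizer,
            'lr_scheduler': {'scheduler': scheduler, 'monitor': 'val_loss'},
        }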