
[RFC] Add self.lr_schedulers() to LightningModule for manual optimization #6567

Merged
merged 4 commits into master from feat/add-lr_schedulers-in-manopt on Apr 9, 2021

Conversation

@akihironitta (Contributor) commented Mar 17, 2021

What does this PR do?

Part of #6379.

Before disabling lr_scheduler.step() in manual optimization in #6379, this PR adds self.lr_schedulers() so that users can call lr_scheduler.step() in LightningModule at arbitrary intervals in manual optimization.

Example:

class Model(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        # single scheduler
        scheduler = self.lr_schedulers()

        # multiple schedulers
        scheduler1, scheduler2 = self.lr_schedulers()
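
For example, a user could then step a scheduler at an arbitrary interval. A minimal sketch (the linear layer, loss, and the every-10-batches interval are illustrative, not part of this PR; note that until the follow-up PR disables automatic stepping, Lightning would still also call step() on its own):

import torch
from pytorch_lightning import LightningModule

class ManualModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        sch = self.lr_schedulers()

        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

        # step the scheduler every 10 batches instead of once per epoch
        if (batch_idx + 1) % 10 == 0:
            sch.step()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]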

TODO

  • [n/a] Update the docs — I'll update the docs in the follow-up PR that disables lr_scheduler.step() in manual optimization, since this PR by itself doesn't enable users to call step() in manual optimization.
  • Add a test

Before submitting

  • [RFC] Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • [n/a] Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the Review guidelines. In short, see the following bullet list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

Related to #6825.

@akihironitta added the feature (Is an improvement or enhancement) label Mar 17, 2021
@akihironitta added this to the 1.3 milestone Mar 17, 2021
Comment on lines +126 to +127
# ignore other keys "interval", "frequency", etc.
lr_schedulers = [s["scheduler"] for s in self.trainer.lr_schedulers]
@akihironitta (Contributor, Author) commented Mar 17, 2021

self.lr_schedulers() is supposed to be used in manual optimization, so even when dict keys like "interval" and "monitor" are defined in configure_optimizers(), this line ignores all of the keys except "scheduler". Related docs: https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html#learning-rate-scheduling
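
For reference, a hypothetical configure_optimizers() illustrating this behavior (the optimizer, scheduler, and key values here are made up):

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "step",  # ignored by self.lr_schedulers()
            "frequency": 2,      # ignored by self.lr_schedulers()
        },
    }

# in manual optimization, self.lr_schedulers() returns just `scheduler`,
# not the surrounding dict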

@akihironitta changed the title from "[RFC] Add self.lr_schedulers() to LightningModule for manual optimization" to "[RFC] Add self.lr_schedulers() to LightningModule for manual optimization [WIP]" Mar 17, 2021
@akihironitta force-pushed the feat/add-lr_schedulers-in-manopt branch from b5128e0 to 5767487 on April 2, 2021 09:28
@codecov bot commented Apr 2, 2021

Codecov Report

Merging #6567 (07d25e9) into master (1bd5f36) will decrease coverage by 5%.
The diff coverage is 71%.

@@           Coverage Diff           @@
##           master   #6567    +/-   ##
=======================================
- Coverage      91%     87%    -5%     
=======================================
  Files         192     192            
  Lines       12190   12256    +66     
=======================================
- Hits        11144   10635   -509     
- Misses       1046    1621   +575     

@pep8speaks commented Apr 4, 2021

Hello @akihironitta! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2021-04-04 20:34:31 UTC

@akihironitta force-pushed the feat/add-lr_schedulers-in-manopt branch from 6cd0f73 to 5767487 on April 4, 2021 18:44
@akihironitta changed the title from "[RFC] Add self.lr_schedulers() to LightningModule for manual optimization [WIP]" to "[RFC] Add self.lr_schedulers() to LightningModule for manual optimization" Apr 4, 2021
@akihironitta marked this pull request as ready for review April 4, 2021 20:31
@carmocca added the ready (PRs ready to be merged) label Apr 7, 2021
@tchaton (Contributor) left a comment:

Looks neat!

@tchaton merged commit 5e4dfd7 into master Apr 9, 2021
@tchaton deleted the feat/add-lr_schedulers-in-manopt branch April 9, 2021 09:32
@maxoppelt (Contributor) commented May 19, 2021

I think this update introduced a new bug:

When automatic_optimization is disabled and you are using schedulers that require a metric, e.g. th.optim.lr_scheduler.ReduceLROnPlateau, an error is raised:

pytorch_lightning.utilities.exceptions.MisconfigurationException: The lr scheduler dict must include a monitor when a "ReduceLROnPlateau" scheduler is used. For example: {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "your_loss"}}

As soon as you define a monitor you get a warning:

/pytorch_lightning/utilities/distributed.py:69: RuntimeWarning: The lr scheduler dict contains the key(s) ['monitor'], but the keys will be ignored. You need to call "lr_scheduler.step()" manually in manual optimization.

PyTorch Lightning version 1.3.1

Edit: I do not know what the preferred solution might be, but changing the condition at L153 in optimizers.py to

    if scheduler['reduce_on_plateau'] and scheduler.get('monitor', None) is None and not is_manual_optimization:

solves this issue for me.
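
For context, a minimal configure_optimizers() that reproduces the problem (the Adam optimizer and learning rate are illustrative):

def configure_optimizers(self):
    # self.automatic_optimization = False is set in __init__
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    # without a "monitor" key this raises the MisconfigurationException above;
    # adding one avoids the error but triggers the RuntimeWarning instead
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler},
    }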

@awaelchli (Contributor) commented:

@maxoppelt The lr schedulers need to be manually stepped in manual optimization. For schedulers that require the val_loss, that means they need to be stepped in the validation_epoch_end hook. Therefore, the monitor key in the dict must be omitted, and we have to turn off the error message in manual optimization. Is that correct?
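
For example, a sketch of manual stepping in validation_epoch_end (assuming validation_step returns a dict containing a "val_loss" tensor):

def validation_epoch_end(self, outputs):
    sch = self.lr_schedulers()
    # compute the monitored metric ourselves and pass it to the plateau
    # scheduler directly, instead of relying on a "monitor" key
    val_loss = torch.stack([out["val_loss"] for out in outputs]).mean()
    sch.step(val_loss)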

@maxoppelt (Contributor) commented:

Yes, that is a possible solution: disable raising the MisconfigurationException when automatic_optimization is False.

Another design choice could be: Disable the warning and provide access to the monitor key in training_epoch_end/validation_epoch_end.

Minor remark on the documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html#learning-rate-scheduling-manual is misleading. Most schedulers have an epoch argument in their step method, i.e. they are meant to be stepped once per epoch, so one should not call scheduler.step() in training_step(). And passing epoch as an argument to the scheduler's step() emits an EPOCH_DEPRECATION_WARNING.
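
For illustration (a sketch; scheduler is any torch.optim.lr_scheduler instance):

scheduler.step()       # preferred
scheduler.step(epoch)  # emits EPOCH_DEPRECATION_WARNING in recent torch versions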

This could lead to misunderstandings when reading the docs. However, calling the scheduler in training_epoch_end() might be problematic when using multiple dataloaders or DDP training?

@awaelchli (Contributor) commented:

> Disable raising the MisconfigurationException when automatic_optimization is False.

My preference.

> Another design choice could be: Disable the warning and provide access to the monitor key in training_epoch_end/validation_epoch_end.

There is already a pattern for this: returning the value from the step method, or using torchmetrics. So I think we don't need another way. Or what, concretely, would your suggestion be?

> However, calling the scheduler in training_epoch_end() might be problematic when using multiple dataloaders or DDP training?

For multiple dataloaders, training_epoch_end() will receive the outputs for all of them, so I see no big problem here.
In DDP training we update the scheduler in each process, but when a metric is required, we probably want to update with the same metric value in all processes. We need to think of something here that requires minimal code changes from the user.
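
One possible pattern, sketched here under the assumption that each process computes a "val_loss", would be to all-reduce the metric before stepping, so every process sees the same value:

def validation_epoch_end(self, outputs):
    val_loss = torch.stack([out["val_loss"] for out in outputs]).mean()
    # average over all DDP processes so each one steps the scheduler
    # with an identical metric value
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.all_reduce(val_loss, op=torch.distributed.ReduceOp.SUM)
        val_loss /= torch.distributed.get_world_size()
    self.lr_schedulers().step(val_loss)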

Are you interested in sending a PR for the error message handling / doc improvements?
