
Optimizers are broken with auto_lr_find=True since 1.1.4 #6285

Closed
indigoviolet opened this issue Mar 2, 2021 · 4 comments · Fixed by #6372
Labels
bug (Something isn't working) · help wanted (Open to be worked on) · priority: 0 (High priority task) · tuner

Comments

@indigoviolet

🐛 Bug

It seems that #5244 (which went out with 1.1.4) introduced a bad interaction with auto_lr_find=True.

Specifically, lightning_optimizers are now cached on the Trainer. When auto_lr_find=True updates the learning rate, the optimizers returned from configure_optimizers change, so the cached lightning_optimizers need to be refreshed -- but this no longer happens, because the optimizers are not re-wrapped in the general case.
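
A quick way to see the mismatch after tuning (a minimal sketch; it is the same identity check used in the repro script in the comments below, and lightning_optimizers / optimizers are attributes of the 1.2.x Trainer):

# After trainer.tune(...), the cached wrapper should still point at the
# optimizer that training will actually step; with this bug it does not.
assert trainer.lightning_optimizers[0].optimizer is trainer.optimizers[0]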

The outcome for me is that training simply doesn't converge, because we end up updating the wrong optimizer.

Please reproduce using the BoringModel

https://colab.research.google.com/drive/1PJGOBSUdl5_-U9O-fvo83V1On6_siwAC?usp=sharing

To Reproduce

See the Colab notebook linked above.

Expected behavior

Training should work!
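
Concretely, the expected flow is something like the sketch below (an illustrative example of the usual auto_lr_find pattern, not the Colab code; LitModel and RandomDataset here are made up for illustration):

import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.data = torch.randn(num_samples, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


class LitModel(pl.LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.lr = lr  # the lr finder looks for `lr` (or `learning_rate`) by default

    def training_step(self, batch, batch_idx):
        out = self.layer(batch)
        return torch.nn.functional.mse_loss(out, torch.ones_like(out))

    def configure_optimizers(self):
        # built from the (possibly tuner-updated) attribute
        return torch.optim.SGD(self.layer.parameters(), lr=self.lr)


if __name__ == "__main__":
    model = LitModel()
    train = DataLoader(RandomDataset(32, 6400), batch_size=32)
    trainer = pl.Trainer(max_epochs=1, auto_lr_find=True)
    trainer.tune(model, train)  # runs the lr finder and overwrites model.lr
    # fit() should then step the optimizer built from the updated model.lr,
    # i.e. the one returned by the fresh configure_optimizers call
    trainer.fit(model, train)

Under the bug, the cached LightningOptimizer still wraps an optimizer created before the tuner updated the lr, so training steps a different optimizer than the one configure_optimizers just returned.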

Environment

  • CUDA:
    • GPU:
      • Tesla T4
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.19.5
    • pyTorch_debug: False
    • pyTorch_version: 1.7.1+cu101
    • pytorch-lightning: 1.2.1
    • tqdm: 4.41.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.7.10
    • version: #1 SMP Thu Jul 23 08:00:38 PDT 2020

Additional context

  1. This was a pretty frustrating bug to track down: it broke training on my model in a seemingly unrelated way, and I had to git bisect both my repo and pytorch-lightning's repo to find it.

  2. It's scary to me that the bug seems to have gone unnoticed for so many versions -- does no one use auto_lr_find=True? Are there no test cases checking this combination?

@indigoviolet added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Mar 2, 2021
@indigoviolet (Author)

FYI @ananthsub, since you might remember the original code.

@awaelchli (Contributor)

Thanks for reporting this.

import torch
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning import LightningModule

def check_optimizer(trainer, when):
    # The cached LightningOptimizer must wrap the optimizer the Trainer actually uses.
    assert trainer.lightning_optimizers[0].optimizer is trainer.optimizers[0], when


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.lr = 0.  # attribute the lr finder looks for and overwrites

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        self.log("loss", loss)
        return loss

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        return optimizer


class OptimizerSpy(pl.callbacks.Callback):
    def on_fit_start(self, trainer, *args, **kwargs):
        # UNCOMMENT TO FIX BUG
        # trainer._lightning_optimizers = None
        check_optimizer(trainer, "on_fit_start")


if __name__ == "__main__":
    model = BoringModel()
    num_samples = 10000
    train = RandomDataset(32, num_samples)
    train = DataLoader(train, batch_size=32)

    trainer = pl.Trainer(
        max_epochs=1,
        auto_lr_find=True,
        callbacks=[OptimizerSpy()]
    )
    trainer.tune(model, train)
    check_optimizer(trainer, "after tune")
    trainer.fit(model, train)
    check_optimizer(trainer, "after fit")

Minimal repro based on the code provided by @indigoviolet.
The commented-out line in the on_fit_start callback shows what we need to fix.
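
Until a proper fix lands, a possible stopgap (mirroring that commented line; note that _lightning_optimizers is a private Trainer attribute in 1.2.x, so this is a workaround, not a public API, and the class name below is just illustrative) is a tiny callback that drops the cache so the optimizers get re-wrapped:

import pytorch_lightning as pl


class ResetLightningOptimizerCache(pl.callbacks.Callback):
    # Stopgap for this issue: clear the private cache so the LightningOptimizers
    # are re-created from the optimizers produced after the lr finder ran.
    def on_fit_start(self, trainer, *args, **kwargs):
        trainer._lightning_optimizers = None  # private attribute; remove once the fix is merged

In the script above this plays the same role as uncommenting the line in OptimizerSpy.on_fit_start.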

@awaelchli (Contributor)

@indigoviolet I propose a fix here: #6372. Let me know if that solves your issue.

@indigoviolet (Author)

Thanks for the fix, @awaelchli!
