
Using the LBFGS optimizer in PyTorch Lightning, the model does not converge, compared to native PyTorch + LBFGS #4083

Closed
ghost opened this issue Oct 11, 2020 · 25 comments · Fixed by #6147
Labels
bug Something isn't working help wanted Open to be worked on priority: 1 Medium priority task


@ghost

ghost commented Oct 11, 2020

Common bugs:

Comparing the results of LBFGS + PyTorch Lightning to native PyTorch + LBFGS, PyTorch Lightning is not able to update the weights and the model does not converge. Some points to note:

  1. Adam + PyTorch Lightning on MNIST works fine; however, LBFGS + PyTorch Lightning does not work as expected.
  2. LBFGS + native PyTorch works very well; however, LBFGS + PyTorch Lightning does not work as expected.

🐛 Bug

LBFGS + PyTorch Lightning has trouble converging, and the weights are not updating the way they do with Adam + PyTorch Lightning.

Code sample

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms,datasets
from torch.utils.data import DataLoader,random_split
import pytorch_lightning as pl 
from IPython.display import clear_output

class LightningMNISTClassifier(pl.LightningModule):
  def __init__(self):
    super(LightningMNISTClassifier,self).__init__()
    self.layer_1 = nn.Linear(28 * 28, 128)
    self.layer_2 = nn.Linear(128, 256)
    self.layer_3 = nn.Linear(256, 10)
    
  def forward(self, x):
    batch_size, channels, width, height = x.size()
    x=x.view(batch_size,-1)
    # layer 1
    x = self.layer_1(x)
    x = torch.relu(x)
    # layer 2
    x = self.layer_2(x)
    x = torch.relu(x) 
    # layer 3
    x = self.layer_3(x)
    # probability distribution over labels
    x = torch.log_softmax(x, dim=1)  
    return x 
  def prepare_data(self):
    transform=transforms.Compose([transforms.ToTensor(), 
                                  transforms.Normalize((0.1307,), (0.3081,))])
    # prepare transforms standard to MNIST
    mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
    mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)  
    self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])

  def train_dataloader(self):
    return DataLoader(self.mnist_train,batch_size=1024)
 
  # def val_dataloader(self):
  #   return DataLoader(self.mnist_val,batch_size=1024)
  # def test_dataloader(self):
  #   return DataLoader(self.mnist_test,batch_size=1024)


  def configure_optimizers(self):
    # optimizer=optim.Adam(self.parameters(),lr=1e-3)
    optimizer = optim.LBFGS(self.parameters(), lr=1e-2)
    return optimizer

  # def backward(self, trainer, loss, optimizer):
  #   loss.backward(retain_graph=True)


  def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx,
                     second_order_closure, on_tpu=False, using_native_amp=False,
                     using_lbfgs=False):
    # update params
    optimizer.step(second_order_closure)

  def cross_entropy_loss(self,logits,labels):
    return F.nll_loss(logits,labels)

  def training_step(self,train_batch,batch_idx):
    x,y=train_batch
    logits=self.forward(x)
    loss=self.cross_entropy_loss(logits,y)
    return  {'loss':loss}

  def training_epoch_end(self,outputs):
    avg_loss=torch.stack([x['loss'] for x in outputs]).mean()
    print('epoch={}, avg_Train_loss={:.2f}'.format(self.current_epoch,avg_loss.item()))
    # return {'avg_train_loss':avg_loss}

  # def validation_step(self,val_batch,batch_idx):
  #   x,y=val_batch
  #   logits=self.forward(x)
  #   loss=self.cross_entropy_loss(logits,y)
  #   return {'val_loss':loss}
  # def validation_epoch_end(self,outputs):
  #   avg_loss=torch.stack([x['val_loss'] for x in outputs]).mean()
  #   print('epoch={}, avg_Test_loss={:.2f}'.format(self.current_epoch,avg_loss.item()))
  #   return {'avg_val_loss':avg_loss}

model=LightningMNISTClassifier()
#from pytorch_lightning.callbacks import EarlyStopping
trainer=pl.Trainer(max_epochs=400,gpus=1,
                  #  check_val_every_n_epoch=2,
                  #  accumulate_grad_batches=5,
#                   early_stop_callback=early_stop,
                  #  limit_train_batches=50,
#                   val_check_interval=0.25,
                   progress_bar_refresh_rate=0,
#                   num_sanity_val_steps=0,
                   weights_summary=None)
clear_output(wait=True)
trainer.fit(model)

Expected behavior

Environment

Please copy and paste the output from our environment collection script (or fill out the checklist below manually).

You can get the script and run it with:

wget https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/tests/collect_env_details.py
# For security purposes, please check the contents of collect_env_details.py before running it.
python collect_env_details.py

Environment:
- Colab and PyCharm
- PyTorch version: 1.6.0 (CPU and GPU)
- pytorch-lightning==1.0.0rc3

@ghost ghost added bug Something isn't working help wanted Open to be worked on labels Oct 11, 2020
@rohitgr7
Contributor

Do you have the code for native PyTorch + LBFGS for the same model?

@ghost
Author

ghost commented Oct 11, 2020

This is the code, including MNIST and LBFGS, that works fine with native PyTorch:

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms,datasets
from torch.utils.data import DataLoader,random_split


class PytorchMNISTClassifier(nn.Module):
  def __init__(self):
    super(PytorchMNISTClassifier,self).__init__()
    self.layer_1 = nn.Linear(28 * 28, 128)
    self.layer_2 = nn.Linear(128, 256)
    self.layer_3 = nn.Linear(256, 10)
  def forward(self, x):
    batch_size, channels, width, height = x.size()
    x=x.view(batch_size,-1)
    # layer 1
    x = self.layer_1(x)
    x = torch.relu(x)
    # layer 2
    x = self.layer_2(x)
    x = torch.relu(x) 
    # layer 3
    x = self.layer_3(x)
    # probability distribution over labels
    x = torch.log_softmax(x, dim=1)  
    return x 

def cross_entropy_loss(logits,labels):
  return F.nll_loss(logits,labels)

if __name__ == '__main__':

  if torch.cuda.is_available():
    device=torch.device('cuda:0')
  else:
    device=torch.device('cpu')

  model=PytorchMNISTClassifier()
  model=model.to(device)
  # optimizer=optim.Adam(model.parameters(),lr=1e-3)
  optimizer = optim.LBFGS(model.parameters(),lr=0.01)

  transform=transforms.Compose([transforms.ToTensor(), 
                                  transforms.Normalize((0.1307,), (0.3081,))])
  # prepare transforms standard to MNIST
  mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
  mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)  
  mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

  data=DataLoader(mnist_train,batch_size=1024)

  for Epoch in range(10):
    loss_total=0.
    for i,(x,y) in enumerate(data):
      x=x.to(device)
      y=y.to(device)
      def closure():
        logits=model(x)
        optimizer.zero_grad()
        loss=cross_entropy_loss(logits,y)
        loss.backward(retain_graph=True)
        return loss
      # take one LBFGS step per batch
      loss_out = optimizer.step(closure)
      loss_total+=loss_out.item()
    print('total_loss--->', loss_total)

@williamFalcon
Contributor

You don't need to override optimizer_step. You're only doing it to pass in the second_order_closure, but that's exactly what the default implementation does:

        if on_tpu:
            xm.optimizer_step(optimizer)
        elif using_native_amp:
            self.trainer.scaler.step(optimizer)
        elif using_lbfgs:
            optimizer.step(second_order_closure)
        else:
            optimizer.step()

@rohitgr7
Contributor

@williamFalcon it should still converge, right? Even if the overridden method is doing the same update. Maybe there's a bug here if it's not converging in PL. Will check this.

@ghost
Author

ghost commented Oct 12, 2020

@williamFalcon we modified the code by removing optimizer_step; however, it does not solve the issue.

@rohitgr7
Contributor

OK, found something. Not sure if it's correct or not since I haven't used LBFGS before.

I checked that optim.LBFGS calls the closure 20 times for each step. In this example, the code never calls step or .backward() explicitly; it relies on optimizer.step(closure) to do that. Also, the underlying loss is different in each of those 20 calls.

But PL also calls training_step explicitly, outside the closure, so it is called 21 times, and an extra explicit loss.backward() is always performed.

These are my observations. Can anyone with prior experience with the LBFGS optimizer confirm the right way to do this?

@williamFalcon
Contributor

williamFalcon commented Oct 12, 2020

How many times does it get called with PyTorch?

LBFGS is a quasi-Newton method, which means it does not compute the Hessian directly but approximates it instead.

I assume PyTorch calls step multiple times to do this approximation?
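To make the "approximates the Hessian" point concrete, here is a sketch of the standard L-BFGS two-loop recursion. It is a textbook illustration, not PyTorch's source, although torch.optim.LBFGS keeps a bounded history (history_size) of the same kind of (s, y) pairs. The function name lbfgs_direction is hypothetical. Each pair needs a fresh gradient, which is one reason the closure is evaluated repeatedly inside a single step.

```python
import torch


def lbfgs_direction(grad, s_hist, y_hist):
    """Approximate -H^{-1} @ grad with the L-BFGS two-loop recursion (sketch).

    s_hist[i] = x_{i+1} - x_i and y_hist[i] = g_{i+1} - g_i are stored
    curvature pairs, oldest first. No Hessian matrix is ever formed.
    """
    q = grad.clone()
    rhos = [1.0 / torch.dot(y, s) for s, y in zip(s_hist, y_hist)]
    alphas = []
    # First loop: newest pair to oldest.
    for s, y, rho in reversed(list(zip(s_hist, y_hist, rhos))):
        alpha = rho * torch.dot(s, q)
        q = q - alpha * y
        alphas.append(alpha)
    # Initial inverse-Hessian guess: a scalar multiple of the identity.
    if s_hist:
        gamma = torch.dot(s_hist[-1], y_hist[-1]) / torch.dot(y_hist[-1], y_hist[-1])
    else:
        gamma = torch.tensor(1.0)
    r = gamma * q
    # Second loop: oldest pair to newest (alphas were stored newest first).
    for (s, y, rho), alpha in zip(zip(s_hist, y_hist, rhos), reversed(alphas)):
        beta = rho * torch.dot(y, r)
        r = r + s * (alpha - beta)
    return -r


if __name__ == "__main__":
    # f(x) = 2 x^2, moving from x = 1.0 to x = 0.5: s = -0.5, y = 2.0 - 4.0 = -2.0.
    # The exact Newton step from x = 0.5 is -g / f'' = -2.0 / 4.0 = -0.5.
    d = lbfgs_direction(torch.tensor([2.0]), [torch.tensor([-0.5])], [torch.tensor([-2.0])])
    print(d)
```

On a 1-D quadratic a single pair already recovers the exact Newton step. In higher dimensions the result is only an approximation built from the stored pairs.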

@rohitgr7
Contributor

rohitgr7 commented Oct 12, 2020

how many times does it get called with pytorch?

The given example calls it 20 times. I think it always calls it 20 times; I checked a few examples.

@ghost
Author

ghost commented Oct 12, 2020

The default value for the number of iterations is 20, according to the PyTorch documentation:

torch.optim.LBFGS(params, lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn=None)
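The count discussed above can be checked by wrapping the closure in a counter. This is a sketch on a hypothetical toy least-squares problem. It assumes the default fixed-step mode (line_search_fn=None). In that mode I would expect at most max_iter closure evaluations per optimizer.step(), fewer if a tolerance stops the step early, and exactly one with max_iter=1.

```python
import torch
import torch.nn.functional as F


def count_closure_calls(max_iter=20, seed=0):
    """Run a single LBFGS.step() on a toy problem and count closure evaluations."""
    torch.manual_seed(seed)
    model = torch.nn.Linear(4, 1)  # hypothetical toy model
    x, y = torch.randn(64, 4), torch.randn(64, 1)
    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.01, max_iter=max_iter)
    calls = 0

    def closure():
        nonlocal calls
        calls += 1
        optimizer.zero_grad()
        loss = F.mse_loss(model(x), y)
        loss.backward()
        return loss

    optimizer.step(closure)  # one "step" re-evaluates the closure repeatedly
    return calls


if __name__ == "__main__":
    print("max_iter=20 ->", count_closure_calls(20))
    print("max_iter=1  ->", count_closure_calls(1))
```

Any extra forward/backward that happens outside this closure is invisible to LBFGS's own bookkeeping. That is why the number of training_step and backward calls matters in the comparison with Lightning.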

@ghost
Author

ghost commented Oct 15, 2020

@williamFalcon We are developing code that requires the LBFGS optimizer, and I'd like to use the PyTorch Lightning platform for it. Do you think the LBFGS issue can be resolved any time soon in a later version?

@edenlightning edenlightning added this to the 1.0.3 milestone Oct 19, 2020
@justusschock
Member

@rohitgr7 unfortunately it does not seem to be fixed by #4190, even though the number of backward calls is now correct (there is a test for that). The loss is still not decreasing, though (I haven't investigated further).

@rohitgr7
Contributor

rohitgr7 commented Oct 21, 2020

ok will check this if I get some time :)

@ghost
Author

ghost commented Oct 27, 2020

@williamFalcon it seems that the LBFGS optimizer in the latest version of PyTorch Lightning has the same issue as previous versions. Is there a way to work around this issue temporarily until the bug gets fixed?

@edenlightning edenlightning modified the milestones: 1.0.x, 1.0.7 Nov 10, 2020
@Borda Borda modified the milestones: 1.0.7, 1.0.x Nov 11, 2020
@edenlightning edenlightning modified the milestones: 1.0.x, 1.0.7 Nov 13, 2020
@Borda Borda modified the milestones: 1.0.7, 1.0.x Nov 13, 2020
@edenlightning edenlightning added the priority: 2 Low priority task label Nov 17, 2020
@Borda Borda modified the milestones: 1.0.x, 1.1 Nov 18, 2020
@edenlightning edenlightning removed this from the 1.1 milestone Nov 18, 2020
@ghost
Author

ghost commented Nov 27, 2020

@Borda, @edenlightning, the LBFGS issue does not seem to be fixed in the latest version of PyTorch Lightning. Can we hope that this issue will be fixed in the near future? We started a project using PyTorch Lightning and got stuck because we cannot use the LBFGS optimizer. If it is not fixed yet, would it be possible to expedite resolving this issue?

@stale

stale bot commented Dec 27, 2020

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Dec 27, 2020
@Bajo1994

Bajo1994 commented Jan 1, 2021

@williamFalcon @Borda @edenlightning Since this thread will be closed automatically within the next 48 hours, I decided to mention you guys with the hope that the bug gets fixed in a meaningful period. I also appreciate @justusschock for his efforts to fix the issue. Ignoring a bug will not fix it, and it dramatically stops the research activities of people who trusted lightning. Please help us with fixing the bug.

@stale stale bot removed the won't fix This will not be worked on label Jan 1, 2021
@Bajo1994

@carmocca I would be very thankful if you could take a look at the discussion here to see whether you can help us fix the issue. The LBFGS bug in Lightning has dramatically impacted an important project that I am working on.

@tchaton tchaton added the priority: 1 Medium priority task label Jan 13, 2021
@carmocca carmocca removed the priority: 2 Low priority task label Jan 13, 2021
@edenlightning
Contributor

edenlightning commented Jan 14, 2021

Apologies for the delay! We try our best to look at every issue with the resources we have. We bumped the priority of this one and will try to prioritize it in the next sprints!

@Bajo1994

@edenlightning I greatly appreciate your help on this subject.

@akihironitta
Contributor

akihironitta commented Jan 15, 2021

As shown by the tests @justusschock added in #4190, and as I confirmed locally with cProfile, the number of backward passes (the number of times the closure was called) in PL is 20, the same as in native PyTorch, so this part should be fine.

PL code example (originally from @peymanpoozesh)
import os
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

import pytorch_lightning as pl

warnings.filterwarnings("ignore")
pl.seed_everything(42)


class LightningMNISTClassifier(pl.LightningModule):
    def __init__(self):
        super(LightningMNISTClassifier, self).__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        x = torch.relu(x)
        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)
        return x

    def prepare_data(self):
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
        self.mnist_train, self.mnist_val = random_split(
            mnist_train, [55000, 5000], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        dl = DataLoader(self.mnist_train, batch_size=1024, num_workers=0)
        return dl

    def configure_optimizers(self):
        # optimizer = optim.Adam(self.parameters(), lr=1e-3)
        optimizer = optim.LBFGS(self.parameters(), lr=0.01, max_iter=20)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        return {"loss": loss}

    def training_step_end(self, outputs):
        print("closure_loss:", outputs["loss"].item())
        return outputs


def main():
    model = LightningMNISTClassifier()
    trainer = pl.Trainer(
        max_epochs=30,
        progress_bar_refresh_rate=0,
        weights_summary=None,
        # fast_dev_run=20,
    )
    trainer.fit(model)


if __name__ == "__main__":
    main()
native PyTorch code example (originally from @peymanpoozesh)
import os
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

from pytorch_lightning import seed_everything

warnings.filterwarnings("ignore")
seed_everything(42)


class PytorchMNISTClassifier(nn.Module):
    def __init__(self):
        super(PytorchMNISTClassifier, self).__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        x = torch.relu(x)
        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)
        return x


def main():
    device = torch.device("cpu")
    model = PytorchMNISTClassifier().to(device)

    # optimizer=optim.Adam(model.parameters(),lr=1e-3)
    optimizer = optim.LBFGS(model.parameters(), lr=0.01, max_iter=20)

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
    mnist_train, mnist_val = random_split(
        mnist_train, [55000, 5000], generator=torch.Generator().manual_seed(42)
    )

    dl = DataLoader(mnist_train, batch_size=1024, num_workers=0)

    for epoch in range(30):
        for i, (x, y) in enumerate(dl):
            x = x.to(device)
            y = y.to(device)

            def closure():
                logits = model(x)
                optimizer.zero_grad()
                loss = F.nll_loss(logits, y)
                loss.backward(retain_graph=True)
                print("closure_loss:", loss.item())
                return loss

            loss_out = optimizer.step(closure=closure)


if __name__ == "__main__":
    main()
my env
$ wget https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/tests/collect_env_details.py
$ python collect_env_details.py
* CUDA:
	- GPU:
	- available:         False
	- version:           None
* Packages:
	- numpy:             1.19.5
	- pyTorch_debug:     False
	- pyTorch_version:   1.7.1+cpu
	- pytorch-lightning: 1.1.4
	- tqdm:              4.56.0
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		- ELF
	- processor:         
	- python:            3.8.5
	- version:           #1 SMP Debian 4.19.160-2 (2020-11-28)

I have no idea how I could investigate this further. @carmocca @rohitgr7 Could you help here if you have time...?


EDIT (Jan 28, 2021): Not sure how this helps us debug, but I realised that if we change torch.optim.LBFGS(..., max_iter=20) from 20 (the default) to 1 or 2, PL and native PyTorch behave exactly the same, which I confirmed with my example code above. (Neither converges, though.)

@akihironitta
Contributor

akihironitta commented Feb 21, 2021

@peymanpoozesh @Bajo1994 Sorry for the delay. I haven't figured out why LBFGS behaves differently between Lightning and native PyTorch, but I found an easy workaround, so let me share it here.

The workaround is to use the manual optimization instead of the default automatic optimization (See my notebook linked below for the complete code using BoringModel):

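import pytorch_lightning as pl
import torch


class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # drive LBFGS ourselves instead of Lightning's automatic optimization
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 2)  # same layer as BoringModel

    def loss(self, batch, prediction):
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        optimizer = self.optimizers()

        def closure():
            output = self.layer(batch)
            loss = self.loss(batch, output)
            optimizer.zero_grad()  # LBFGS re-evaluates this; clear old grads first
            self.manual_backward(loss)
            return loss

        optimizer.step(closure=closure)

    def configure_optimizers(self):
        return torch.optim.LBFGS(self.parameters(), lr=0.01, max_iter=20)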

See also:
Full code of the workaround - Google Colab
Manual optimization - PyTorch Lightning Docs

@akihironitta
Copy link
Contributor

akihironitta commented Feb 22, 2021

Here are the minimal code examples using BoringModel.

Lightning code
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset
pl.seed_everything(42)

class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)
    def __getitem__(self, index):
        return self.data[index]
    def __len__(self):
        return self.len

class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
    def forward(self, x):
        return self.layer(x)
    def loss(self, batch, prediction):
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}
    def training_step_end(self, training_step_outputs):
        loss = training_step_outputs["loss"]
        print("loss:", loss.item())
        return training_step_outputs
    def configure_optimizers(self):
        # optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        optimizer = torch.optim.LBFGS(self.parameters(), lr=0.01, max_iter=20)
        return optimizer

def main():
    ds = RandomDataset(32, 100000)
    dl = DataLoader(ds, batch_size=1024)
    model = BoringModel()
    trainer = pl.Trainer(
        progress_bar_refresh_rate=0,
        fast_dev_run=1,
    )
    trainer.fit(model, dl)

if __name__ == "__main__":
    main()
Pure PyTorch code
import torch
import torch.nn as nn
from pytorch_lightning import seed_everything
from torch.utils.data import DataLoader, Dataset
seed_everything(42)

class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)
    def __getitem__(self, index):
        return self.data[index]
    def __len__(self):
        return self.len

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layer = torch.nn.Linear(32, 2)
    def forward(self, x):
        return self.layer(x)

def main():
    ds = RandomDataset(32, 100000)
    dl = DataLoader(ds, batch_size=1024)
    model = Model()
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.01, max_iter=20)
    for epoch in range(3):
        for i, x in enumerate(dl):
            def closure():
                prediction = model(x)
                loss = torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
                optimizer.zero_grad()  # removing this line causes the same bug as in Lightning script
                loss.backward()
                print("loss:", loss.item())
                return loss
            loss_out = optimizer.step(closure=closure)

if __name__ == '__main__':
    main()

cc: @carmocca @tchaton
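The comment on optimizer.zero_grad() in the pure PyTorch script points at one plausible mechanism. LBFGS reads .grad after each closure evaluation, so any gradient left over from an earlier backward (a previous closure call, or an extra backward outside the closure) is silently added in. That corrupts both the search direction and the stored curvature pairs. The sketch below uses a hypothetical tiny layer and evaluates the same closure twice without stepping, to show the accumulation in isolation.

```python
import torch
import torch.nn.functional as F


def grad_after_two_closure_calls(zero_grad: bool) -> torch.Tensor:
    """Evaluate one closure twice (LBFGS does this many times per step) and return the weight grad."""
    torch.manual_seed(0)
    layer = torch.nn.Linear(3, 1, bias=False)  # hypothetical toy layer
    x = torch.randn(8, 3)

    def closure():
        if zero_grad:
            layer.zero_grad()
        loss = F.mse_loss(layer(x), torch.ones(8, 1))
        loss.backward()
        return loss

    closure()
    closure()
    return layer.weight.grad.clone()


with_zero = grad_after_two_closure_calls(zero_grad=True)
without_zero = grad_after_two_closure_calls(zero_grad=False)
# Without zero_grad, the second backward adds onto the first, doubling the gradient.
print(torch.allclose(without_zero, 2 * with_zero))  # prints True
```

Inside a real LBFGS step the parameters also move between evaluations, so the leftover gradient belongs to a different point. This makes the problem worse than a simple doubling.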

@justusschock
Member

justusschock commented Feb 22, 2021

@akihironitta Why doesn't optimizer.step(closure=closure) work? Why do you have to unwrap it?
Because without unwrapping you also get all the precision support from lightning :)

@akihironitta
Contributor

@justusschock Fixed! (It was just for print debugging from another script because LightningOptimizer doesn't return the output of closure())

@Bajo1994

Bajo1994 commented Feb 28, 2021

@akihironitta @carmocca I am very thankful for your great effort on this bug. I am looking forward to resuming my project as soon as you update the PL package. In my code, I'd like to switch between the LBFGS and Adam optimizers: use LBFGS while the loss is large, then switch to Adam. I hope switching between these two optimizers will be smooth in PL (I had difficulties switching between them in native PyTorch). I will keep you posted if there is any problem.

9 participants