self.automatic_optimization = False prevents checkpoints from being saved #13674

Closed
anicolson opened this issue Jul 15, 2022 · 4 comments
Labels: checkpointing (Related to checkpointing), loops (Related to the Loop API), optimization, question (Further information is requested)

Comments


anicolson commented Jul 15, 2022

🐛 Bug

Hi, I am using Sharpness-Aware Minimization (SAM), as implemented here: https://github.com/davda54/sam. To implement SAM in PTL, self.automatic_optimization = False is needed in the __init__ of the LightningModule, because SAM performs the forward, backward, and optimizer-step sequence twice per batch.

The issue I am facing is that checkpoints are not saved in lightning_logs if I use self.automatic_optimization = False in the __init__ of the LightningModule.

Please let me know if there is something basic that I am missing, as I could not find anything about this in the manual-optimization docs: https://pytorch-lightning.readthedocs.io/en/stable/common/optimization.html#manual-optimization.
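
For context, the two-pass update that SAM requires looks roughly like the following in plain PyTorch (a sketch based on the linked repository's README; model, loss_fn, inputs, and targets are placeholders):

    # Sketch of SAM's two-pass update (placeholders: model, loss_fn, inputs, targets).
    optimizer = SAM(model.parameters(), torch.optim.SGD, lr=0.1)

    # First forward-backward pass: gradients at the current weights w.
    loss_fn(model(inputs), targets).backward()
    optimizer.first_step(zero_grad=True)  # perturb the weights to w + e(w)

    # Second forward-backward pass: gradients at the perturbed weights.
    loss_fn(model(inputs), targets).backward()
    optimizer.second_step(zero_grad=True)  # restore w and apply the base optimizer's update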

To Reproduce

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.automatic_optimization = False

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()

Expected behavior

Checkpoints should be saved in lightning_logs regardless of whether self.automatic_optimization = False is set. Currently, with self.automatic_optimization = False, checkpoints are not saved in lightning_logs; if it is removed, checkpoints are saved.

Environment

  • CUDA:
    • GPU:
      • Tesla P100-SXM2-16GB
      • Tesla P100-SXM2-16GB
      • Tesla P100-SXM2-16GB
      • Tesla P100-SXM2-16GB
    • available: True
    • version: 10.2
  • Packages:
    • numpy: 1.22.4
    • pyTorch_debug: False
    • pyTorch_version: 1.12.0+cu102
    • pytorch-lightning: 1.6.5
    • tqdm: 4.64.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.9.4
    • version: #1 SMP Tue Apr 5 12:47:31 UTC 2022 (d77db66)

Thanks in advance

cc @awaelchli @ananthsub @ninginthecloud @rohitgr7 @otaj @Borda @carmocca @justusschock

@anicolson added the needs triage (Waiting to be triaged by maintainers) label on Jul 15, 2022
@awaelchli (Contributor)

Hi @anicolson
I'm able to reproduce this. However, are you aware that you are not stepping the optimizer?
Here is what you would realistically need to do in manual optimization:

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self(batch).sum()
        loss.backward()
        opt.step()
        opt.zero_grad()
        self.log("train_loss", loss)

Full code

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
                    torch.stack([
                        ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.automatic_optimization = False

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self(batch).sum()
        loss.backward()
        opt.step()
        opt.zero_grad()
        self.log("train_loss", loss)

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return SAM(self.parameters(), torch.optim.SGD, lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()

I included the code for the SAM optimizer to make sure it works with that too.

I haven't found it yet, but there must be a condition somewhere in the loops that checks, for checkpointing, whether the optimizer has stepped. Maybe @carmocca remembers it. I can't remember if there was a good reason for it.
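
To make the suspicion concrete (a rough sketch of the behaviour only, not Lightning's actual code): the global step only advances when the wrapped optimizer's step() is called, so a guard of roughly this shape would always skip saving if training_step never steps the optimizer:

    # Hypothetical sketch, not Lightning's implementation: if training_step never calls
    # opt.step(), trainer.global_step never advances, so this check always skips the save.
    def should_save_checkpoint(trainer, last_saved_global_step):
        return trainer.global_step != last_saved_global_step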

@awaelchli added the checkpointing (Related to checkpointing), optimization, loops (Related to the Loop API), and bug (Something isn't working) labels and removed the needs triage (Waiting to be triaged by maintainers) label on Jul 22, 2022
@anicolson (Author)

Hi @awaelchli, thank you for your reply.

Yes, I am aware that the following should happen in the training_step if self.automatic_optimization = False:

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.compute_loss(batch)
        self.manual_backward(loss)
        opt.step()

However, for SAM the recommended training loop looks something like the following, with two forward-backward passes and, again, self.automatic_optimization = False (based on the README.md of https://github.com/davda54/sam and a few of that repo's issues):

    def training_step(self, batch, batch_idx):
        """
        Training step (the training loss needs to be returned).

        Argument/s:
            batch - mini-batch from the training set DataLoader.
            batch_idx - batch idx of each example in the mini-batch.

        Returns:
            loss - training loss for the mini-batch.
        """

        # Mini-batch of examples
        images, labels = batch

        # Get optimiser
        opt = self.optimizers()

        # First forward-backward pass for SAM
        enable_running_stats(self)  # https://github.com/davda54/sam/issues/30#issuecomment-909712587
        y_hat = self(images)
        loss_1 = self.loss(y_hat['logits'], labels)
        with self.trainer.model.no_sync():  # https://github.com/davda54/sam/issues/38
            self.manual_backward(loss_1)
        opt.first_step(zero_grad=True)

        # Second forward-backward pass for SAM
        disable_running_stats(self)  # https://github.com/davda54/sam/issues/30#issuecomment-909712587
        y_hat = self(images)
        loss_2 = self.loss(y_hat['logits'], labels)
        self.manual_backward(loss_2)
        opt.second_step(zero_grad=True)

        # Log loss
        losses = {'train_loss_step_1': loss_1, 'train_loss_step_2': loss_2}
        self.log_dict(losses, on_step=False, on_epoch=True, batch_size=images.size()[0])

        return loss_1

and

from torch import nn


def disable_running_stats(model):
    def _disable(module):
        if isinstance(module, nn.BatchNorm2d):
            module.backup_momentum = module.momentum
            module.momentum = 0

    model.apply(_disable)


def enable_running_stats(model):
    def _enable(module):
        if isinstance(module, nn.BatchNorm2d) and hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum

    model.apply(_enable)

The optimiser is stepped in second_step of the SAM object:

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

Okay, I just tested the code from my initial comment with opt = self.optimizers() and opt.step() added to training_step, and the checkpoint now saves. I don't know why I didn't try this before...
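
Concretely, the tested training_step ends up looking roughly like the snippet above (with the backward and zero_grad included):

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self(batch).sum()
        self.manual_backward(loss)
        opt.step()  # stepping the optimizer is what allows the checkpoint to be saved
        opt.zero_grad()
        self.log("train_loss", loss)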

It would be great to find that condition so that something like SAM can be used with PTL.

Thanks again for your help.

@carmocca (Contributor)

@anicolson (Author)

Thank you both for your help.

It saves checkpoints with SAM if you wrap one of the forward-backward passes in a closure and pass it to SAM's step method. Since SAM.step runs first_step, then the closure, then second_step, the backward done before opt.step(closure) supplies the first-pass gradients and the closure supplies the second-pass gradients:

    def training_step(self, batch, batch_idx):

        # Get optimiser
        opt = self.optimizers()

        # First forward-backward pass for SAM, at the original weights; these gradients are used by first_step
        enable_running_stats(self)  # https://github.com/davda54/sam/issues/30#issuecomment-909712587
        loss_1 = self(batch).sum()
        with self.trainer.model.no_sync():  # https://github.com/davda54/sam/issues/38
            self.manual_backward(loss_1)
        self.log_dict({'train_loss_step_1': loss_1}, on_step=False, on_epoch=True, batch_size=batch.size()[0])

        # Second forward-backward pass for SAM, wrapped in a closure; SAM.step calls it after perturbing the weights
        def closure():
            disable_running_stats(self)  # https://github.com/davda54/sam/issues/30#issuecomment-909712587
            loss_2 = self(batch).sum()
            self.manual_backward(loss_2)
            self.log_dict({'train_loss_step_2': loss_2}, on_step=False, on_epoch=True, batch_size=batch.size()[0])
            return loss_2

        opt.step(closure)
        opt.zero_grad()

Happy for this issue to be closed.

@awaelchli added the question (Further information is requested) label and removed the bug (Something isn't working) label on Jul 24, 2022