# `self.automatic_optimization = False` prevents checkpoints from being saved #13674
Hi @anicolson

```python
def training_step(self, batch, batch_idx):
    opt = self.optimizers()
    loss = self(batch).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
    self.log("train_loss", loss)
```

Full code:

```python
import os

from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer

import torch


class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
            torch.stack([
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]),
            p=2
        )
        return norm

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.automatic_optimization = False

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self(batch).sum()
        loss.backward()
        opt.step()
        opt.zero_grad()
        self.log("train_loss", loss)

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return SAM(self.parameters(), torch.optim.SGD, lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()
```

I included the code from the SAM optimizer to make sure it works with it too. I couldn't find it yet, but there must be a condition somewhere in the loops that checks whether the optimizer has stepped or not for checkpointing. Maybe @carmocca remembers it. I can't remember if there was a good reason for this.
Hi @awaelchli, thank you for your reply. Yes, I am aware that the following should happen in the `training_step`:

```python
def training_step(self, batch, batch_idx):
    opt = self.optimizers()
    opt.zero_grad()
    loss = self.compute_loss(batch)
    self.manual_backward(loss)
    opt.step()
```

However, for SAM, my `training_step` looks like this:

```python
def training_step(self, batch, batch_idx):
    """
    Training step (the training loss needs to be returned).

    Argument/s:
        batch - mini-batch from the training set DataLoader.
        batch_idx - batch idx of each example in the mini-batch.

    Returns:
        loss - training loss for the mini-batch.
    """
    # Mini-batch of examples
    images, labels = batch

    # Get optimiser
    opt = self.optimizers()

    # First forward-backward pass for SAM
    enable_running_stats(self)  # https://github.com/davda54/sam/issues/30#issuecomment-909712587
    y_hat = self(images)
    loss_1 = self.loss(y_hat['logits'], labels)
    with self.trainer.model.no_sync():  # https://github.com/davda54/sam/issues/38
        self.manual_backward(loss_1)
    opt.first_step(zero_grad=True)

    # Second forward-backward pass for SAM
    disable_running_stats(self)  # https://github.com/davda54/sam/issues/30#issuecomment-909712587
    y_hat = self(images)
    loss_2 = self.loss(y_hat['logits'], labels)
    self.manual_backward(loss_2)
    opt.second_step(zero_grad=True)

    # Log loss
    losses = {'train_loss_step_1': loss_1, 'train_loss_step_2': loss_2}
    self.log_dict(losses, on_step=False, on_epoch=True, batch_size=images.size()[0])

    return loss_1
```

and

```python
def disable_running_stats(model):
    def _disable(module):
        if isinstance(module, nn.BatchNorm2d):
            module.backup_momentum = module.momentum
            module.momentum = 0

    model.apply(_disable)


def enable_running_stats(model):
    def _enable(module):
        if isinstance(module, nn.BatchNorm2d) and hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum

    model.apply(_enable)
```

The optimiser is stepped in `opt.first_step()` and `opt.second_step()`, not through `opt.step()`.
Okay, I just tested my code from the initial comment with the addition of

It would be great to find that condition so that something like SAM can be used with PTL. Thanks again for your help.
It's been like this since #3852. When would you want to save a checkpoint if the weights have not been updated?
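(For context, the condition being referred to is, as far as I can tell, tied to the global step: in manual optimization the global step only advances when `optimizer.step()` is called through the `LightningOptimizer` wrapper returned by `self.optimizers()`. The sketch below only illustrates the shape of that guard; it is not the actual Lightning source, and the class and method names are invented.)

```python
# Rough illustration of the condition being described -- NOT the actual
# Lightning source; class and attribute names here are invented.
class CheckpointGuardSketch:
    def __init__(self):
        # global step recorded at the last successful save
        self._last_global_step_saved = -1

    def _should_skip(self, trainer) -> bool:
        # No optimizer step registered since the last save means the weights
        # have not changed, so there is nothing new worth checkpointing.
        return trainer.global_step == self._last_global_step_saved

    def maybe_save(self, trainer, filepath):
        if self._should_skip(trainer):
            return
        trainer.save_checkpoint(filepath)
        self._last_global_step_saved = trainer.global_step
```

Because the SAM `training_step` above only calls `opt.first_step()` and `opt.second_step()`, the wrapper's `step()` is never invoked, the global step never advances, and every save is skipped.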
Thank you both for your help. It saves checkpoints with SAM if you wrap the first forward-backward pass in a closure and pass it to the optimiser's `step()`.

Happy for this issue to be closed.
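A minimal sketch of that closure-based `training_step`, using the `BoringModel` and SAM optimizer from the full code above (the exact structure is an assumption; the fixed code was not posted in the thread):

```python
def training_step(self, batch, batch_idx):
    opt = self.optimizers()  # LightningOptimizer wrapping the SAM optimizer

    def closure():
        # Full forward-backward pass; SAM calls this a second time after
        # perturbing the weights in first_step().
        loss = self(batch).sum()
        self.manual_backward(loss)
        return loss

    # First forward-backward pass, done explicitly so the loss can be logged.
    loss = closure()

    # SAM's step() runs first_step(), re-evaluates the closure at the
    # perturbed weights, then second_step(); calling it through the wrapper
    # also lets Lightning register that an optimizer step happened.
    opt.step(closure=closure)
    opt.zero_grad()

    self.log("train_loss", loss)
```

Passing the closure through the wrapper's `step()` lets SAM run its two-step update internally while Lightning still counts an optimizer step, so the global step advances and checkpoints are written again.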
## 🐛 Bug
Hi, I am using Sharpness-Aware Minimization (SAM) (as implemented here: https://github.com/davda54/sam). To implement SAM in PTL, `self.automatic_optimization = False` is needed in the init of the LightningModule, as the steps described here are performed twice as part of SAM.

The issue that I am facing is that checkpoints are not saved in `lightning_logs` if I use `self.automatic_optimization = False` in the init of the LightningModule.

Please let me know if there is something basic that I am missing, as I could not find anything regarding this issue here: https://pytorch-lightning.readthedocs.io/en/stable/common/optimization.html#manual-optimization.
## To Reproduce

## Expected behavior

If `self.automatic_optimization = False`, checkpoints are not saved in `lightning_logs`. If it is removed, checkpoints are saved in `lightning_logs`.

## Environment
Thanks in advance
cc @awaelchli @ananthsub @ninginthecloud @rohitgr7 @otaj @Borda @carmocca @justusschock