WandB dropping items when logging LR or val_loss with accumulate_grad_batches > 1 #5469
Comments
Also, if it helps: this occurs in every version down to (and including) 1.1.0, but does not occur in 1.0.8.
Do you have any call to `wandb.log` yourself? You will be able to bypass these issues with #5194 by using `sync_step=False`.
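For reference, a sketch of what that bypass might look like (the `sync_step` argument name is taken from the discussion of #5194 below, so treat this as an assumption about the eventual API rather than a confirmed fix):

```python
from pytorch_lightning.loggers import WandbLogger

# Sketch only: with sync_step=False the metrics are logged against wandb's own
# step rather than being forced onto the trainer step, which is what makes wandb
# drop out-of-order items.
wandb_logger = WandbLogger(project="my-project", sync_step=False)
```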
Any update on when #5194 (the `sync_step=False` functionality) will be merged? Additionally, I'm not sure that `sync_step=False` fully addresses this. For example, here's a use-case: I set […] Is there a way to fix this more cleanly? Pinging @SeanNaren because I've discussed this problem with him via the Lightning Slack.
Any update on this?
The wandb workaround (#5194) has been merged and will be available in PL v1.2.
Thanks, I can confirm this is the case in the BoringModel notebook.
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
This was fixed with #5931.
I am still running into this problem, but weirdly enough only at accumulate_grad_batches > 8. Here is a minimal example:

```python
import os

import lightning.pytorch as L
from lightning.pytorch.loggers import WandbLogger
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam


class SimpleDataset(Dataset):
    def __init__(self, size=640, input_dim=10, output_dim=1):
        self.X = torch.randn(size, input_dim)
        self.y = torch.randn(size, output_dim)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


class SimpleModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.mse_loss(y_hat, y)
        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        print(f"Step: {self.global_step}")
        return loss

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)


# - if the number of grad. acc. steps is < 16, it properly logs all the
#   train_loss_step values, but if it is 16, it only logs one value, although it
#   should log e.g. 40 values for 16 grad. acc. steps in this script
# - for 32 grad. acc. steps, it logs no train_loss_step at all
GRAD_ACCUMULATION_STEPS = 32

# this is irrespective of batch size; the problem persists at higher batch sizes too
BATCH_SIZE = 1

# Create the dataset and dataloader
dataset = SimpleDataset()
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)

# Create the LightningModule and Trainer
os.environ["WANDB_API_KEY"] = "YOUR_WANDB_API_KEY"
wandb_logger = WandbLogger(project="simple-example", entity="YOUR_WANDB_USERNAME")
model = SimpleModel()
trainer = L.Trainer(
    max_epochs=2, logger=wandb_logger, accumulate_grad_batches=GRAD_ACCUMULATION_STEPS
)

# Train the model
trainer.fit(model, dataloader)
```

In the example, if I set GRAD_ACCUMULATION_STEPS to a value below 16, all the train_loss_step values are logged as expected.
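For what it's worth, one interaction that could produce exactly this pattern (an assumption on my part, not something confirmed in the thread) is the Trainer's default `log_every_n_steps=50`: it is counted in optimizer steps, and with 32 accumulation steps this script only takes about 40 optimizer steps in total, so on_step logging may simply be filtered out. A quick way to test that hypothesis, reusing the names from the example above:

```python
# Assumption: lower log_every_n_steps (default 50) so that on_step logging is not
# filtered out when gradient accumulation reduces the number of optimizer steps.
trainer = L.Trainer(
    max_epochs=2,
    logger=wandb_logger,
    accumulate_grad_batches=GRAD_ACCUMULATION_STEPS,
    log_every_n_steps=1,
)
trainer.fit(model, dataloader)
```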
🐛 Bug
As you can see in the BoringModel, I get warnings from the WandB logger about items being dropped.
This occurs when I add the following to the basic BoringModel:
- `accumulate_grad_batches > 1`
- a `LearningRateMonitor` callback, or logging validation loss with `self.log()` (or both, as in the colab)

If any of these things is removed, the error doesn't occur (a minimal sketch of this combination is included below).
The end result is that the LR metrics are not being logged at all. Worse than that, the validation loss (and any other metrics there would be!) does not get logged either, making the logger useless.
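For completeness, a minimal sketch of the combination described above (the dataset, model, and project names are illustrative stand-ins, not the exact colab):

```python
import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger


class RandomDataset(Dataset):
    def __init__(self, n=64, dim=32):
        self.data = torch.randn(n, dim)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # validation loss logged with self.log(), as described above
        self.log("val_loss", self.layer(batch).sum())

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]


trainer = pl.Trainer(
    max_epochs=2,
    accumulate_grad_batches=2,  # > 1, as described above
    logger=WandbLogger(project="repro"),
    callbacks=[LearningRateMonitor(logging_interval="step")],
)
trainer.fit(BoringModel(), DataLoader(RandomDataset()), DataLoader(RandomDataset()))
```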