
wandb logger problem with on_step log on validation #4980

Closed
andreaRapuzzi opened this issue Dec 4, 2020 · 10 comments · Fixed by #5351
Labels
3rd party (Related to a 3rd-party) · bug (Something isn't working) · help wanted (Open to be worked on) · priority: 1 (Medium priority task) · question (Further information is requested)

Comments

@andreaRapuzzi

🐛 Bug

When logging in validation_step with on_step=True and on_epoch=False, the following happens:

  • wandb warnings are generated about a step-numbering problem (probably because the validation step number, which is cyclical, gets confused with the overall step, which is always increasing)

[screenshot: wandb warnings about step numbering]

  • the wandb training chart (by step) is shrunk along the x dimension (as if the whole training run had fewer steps). We tested two training runs: the first (blue in the image below) with on_step=False and on_epoch=True in validation_step, the second (red in the image below) with on_step=True and on_epoch=False. As you can see, the training chart is affected by this:

[screenshot: training charts for the two runs]

  • an error is issued at the end of the second training run:

[screenshot: error at the end of the second run]

  • two new (unrequested) panels appear at the top of the wandb project (this is the weirdest of the lot :-))

[screenshot: the two extra panels in the wandb project]

Please reproduce using the colab link at the top of this issue.

To Reproduce

Just change the validation_step logging like this:

def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.nll_loss(logits, y)

    # validation metrics
    preds = torch.argmax(logits, dim=1)
    acc = accuracy(preds, y)
    # log per batch (on_step=True) instead of the default per-epoch aggregation
    self.log('val_loss', loss, on_step=True, on_epoch=False, prog_bar=True)
    self.log('val_acc', acc, on_step=True, on_epoch=False, prog_bar=True)
    return loss
@andreaRapuzzi andreaRapuzzi added the bug (Something isn't working) and help wanted (Open to be worked on) labels Dec 4, 2020
@github-actions
Contributor

github-actions bot commented Dec 4, 2020

Hi! Thanks for your contribution, great first issue!

@awaelchli
Contributor

awaelchli commented Dec 6, 2020

This is not surprising. What exactly are you trying to log in each step of validation?
It is not meaningful to log per batch during validation, as there is no real progression and the global step does not increase.
If you need a validation accuracy, you should certainly average it over the whole validation set, which is what self.log does by default when you set on_epoch=True.
Hope this answer helps.
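For reference, a minimal sketch of the pattern described above, assuming the same LightningModule as the snippet in the issue body; with on_epoch=True, self.log accumulates the metric across validation batches and logs one averaged value per epoch:

def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.nll_loss(logits, y)
    acc = accuracy(torch.argmax(logits, dim=1), y)
    # aggregated over the whole validation set; a single point is logged per epoch
    self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val_acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    return loss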

@awaelchli awaelchli added the information needed and question (Further information is requested) labels Dec 6, 2020
@andreaRapuzzi
Author

Hi @awaelchli, thanks for the prompt answer!
While I agree that the per-epoch validation info has the most value, there is still value in logging per-step validation info to monitor how noisy (high-variance) a metric is within a single epoch.
My initial validation logging actually had on_step=True and on_epoch=True, which showed similar problems; to simplify the report I described the simpler config (on_step=True, on_epoch=False).
Logging validation only at the epoch level is a sensible and useful default, but it should not be the only option (otherwise the library interface should constrain this behaviour).
Thanks a lot, again!!

@edenlightning edenlightning added the 3rd party (Related to a 3rd-party) and priority: 1 (Medium priority task) labels and removed the 3rd party and information needed labels Dec 7, 2020
@borisdayma
Contributor

My suggestion in this case would be to gather your validation metrics in an array and log them as a histogram at the end.
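A minimal sketch of this suggestion, assuming the WandbLogger from the colab; the hook name validation_epoch_end (the Lightning hook at the time of this issue) and the key val_loss_hist are illustrative:

import wandb

def validation_step(self, batch, batch_idx):
    x, y = batch
    loss = F.nll_loss(self(x), y)
    return loss

def validation_epoch_end(self, outputs):
    # gather the per-batch losses and log them once, as a histogram
    losses = [loss.item() for loss in outputs]
    self.logger.experiment.log({'val_loss_hist': wandb.Histogram(losses)})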

@andreaRapuzzi
Author

Thanks @borisdayma !!

@collinmccarthy

When training on ImageNet, the validation epoch takes a long time. For this reason I would have liked to output a per-batch validation loss, as well as an aggregated per-epoch loss for early stopping and checkpointing. It wasn't obvious to me from the docs that this would be handled any differently than during training, until I ran into the same issues as @andreaRapuzzi.

However, I understand why this is tricky and why different people would want it handled differently. It seems like trainer.global_step only increments on training batches (which makes sense), and in general the validation loss usually (if not always) makes the most sense as a single statistic aggregated over the entire validation set. Nonetheless, tracking a per-batch val_loss in one chart and a val_avg_loss in a second chart (and for early stopping / checkpointing) would be a useful capability for some.

It doesn't seem like it would require a lot of code to change (namely EvaluationLoop.__log_result_step_metrics()), but from a design perspective it may not be so simple to come up with the best way to handle the different use cases.
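What is described above would look roughly like the following (val_avg_loss is an illustrative name; at the time of this issue, this is exactly the configuration that triggers the reported problems):

def validation_step(self, batch, batch_idx):
    x, y = batch
    loss = F.nll_loss(self(x), y)
    # per-batch curve, useful for monitoring a long validation epoch
    self.log('val_loss', loss, on_step=True, on_epoch=False)
    # per-epoch aggregate, suitable to monitor for EarlyStopping / ModelCheckpoint
    self.log('val_avg_loss', loss, on_step=False, on_epoch=True)
    return loss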

@borisdayma
Contributor

@collinmccarthy
With wandb, you could just use wandb.log({'val_loss': val_loss, 'my_x_axis': batch_idx + step * n_batch}, commit=False).
However, you would not see the results yet, as the values would not be sent to the server until the next commit.

An alternative (a bit hacky) would be to consider how often you log during training.
Let's say you only log every 500 steps; then you have 499 steps in which to log your validation metrics, so you could aggregate your values and call if batch_idx % 500 == 0: wandb.log({'temp_val_loss': avg(my_losses)}).
You can even create your custom x-axis and use it in your panel for that specific chart.
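A minimal sketch of the commit=False approach, assuming direct access to the wandb run; n_val_batches is an assumed constant (the size of the validation loader), and the queued values are only sent with the next committed wandb.log call:

import wandb

n_val_batches = 100  # assumption: number of batches in the validation loader

def validation_step(self, batch, batch_idx):
    x, y = batch
    loss = F.nll_loss(self(x), y)
    # queue the value without advancing wandb's internal step counter
    wandb.log({'val_loss': loss.item(),
               'my_x_axis': batch_idx + self.global_step * n_val_batches},
              commit=False)
    return loss

In the wandb UI, my_x_axis can then be selected as the x-axis for that specific chart.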

@rohitgr7 rohitgr7 linked a pull request Jan 6, 2021 that will close this issue
@stale

stale bot commented Jan 17, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix (This will not be worked on) label Jan 17, 2021
@awaelchli awaelchli removed the won't fix (This will not be worked on) label Jan 17, 2021
@stale

stale bot commented Feb 16, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix (This will not be worked on) label Feb 16, 2021
@awaelchli awaelchli removed the won't fix (This will not be worked on) label Feb 17, 2021
@borisdayma
Contributor

This was fixed with #5931.
You can now log at any time with self.log('my_metric', my_value) and you won't have any dropped values.
Just choose your x-axis appropriately in the UI (either global_step or the auto-incremented step).
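A minimal sketch of the post-fix usage, assuming the colab's LightningModule; the combination that originally triggered the bug now logs without issue:

def validation_step(self, batch, batch_idx):
    x, y = batch
    loss = F.nll_loss(self(x), y)
    # both the per-step and the per-epoch value are kept; pick either
    # global_step or the auto-incremented step as the x-axis in the wandb UI
    self.log('val_loss', loss, on_step=True, on_epoch=True)
    return loss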
