Checkpointing and Early Stopping fail to work correctly when increasing number of train batches (in some cases) #3789
Comments
Okay, so I tried a few more things:
This leads me to believe the real issue is that the validation methods are perhaps not called properly in some cases. This might be beyond simply a logging or checkpointing problem. In that sense, it might be related to #3666, where the logging and checkpointing seem to be affecting each other.
WIP update: So far I've managed to isolate the issue to this snippet in the training_loop, which indicates why the training loop behavior is affecting the val loop. I logged the value of
This has been run with
More to come.
@chiragraman it would be great to have such a minimal example, which we could also add to the tests...
@edenlightning (regarding the Priority label) I've almost completed following the function calls in the library code, and haven't found anything off yet. Given that it's not really reproducible in the limited example, I am inclined to think that the problem lies in the iteration of the loader, probably in the sampling (in my user-level code). I'll update shortly, but in my humble opinion the dev team's effort might be better saved for other higher-priority tasks.
Figured out the root cause of this, and tl;dr: I don't think it is a bug in the library code. At best the library can help by throwing a warning to aid in debugging user code, because it's been a long, painful journey from symptom to cause; but I don't think that's the library's obligation as such. There might be room to improve the check for whether the evaluation step should be run, but the current issue isn't caused by a bug.

The problem was that I was using a custom sampler to generate batches of sequences bucketed by length (read: https://discuss.pytorch.org/t/tensorflow-esque-bucket-by-sequence-length/41284). The sampler was incorrectly returning the length of the dataset rather than the number of batches, which in this case was 270. Consequently, the iterator would return the last batch at batch_idx 270, but the evaluation check would fail because the

I think issuing a warning when the iterator has ended at a batch_idx less than the expected number of training samples would help in this case.

@Borda I think #1764 describes the problem as a bug in design rather than implementation. I think the code does what it's implemented to do, from what I can see walking through it for the better part of a day; if checkpointing is intended to be done at the end of the epoch, I think a few things would have to change other than
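For context, here is a hypothetical sketch of the kind of bucket-by-length batch sampler described in that comment, showing the buggy and the corrected `__len__`. The class name, `lengths` argument, and bucketing logic are assumptions for illustration, not the actual user code from this issue.

```python
import random
from torch.utils.data import Sampler


class BucketBatchSampler(Sampler):
    """Yields batches of indices whose sequences have similar lengths."""

    def __init__(self, lengths, batch_size):
        self.lengths = lengths  # per-item sequence lengths of the dataset
        # Sort indices by length and chunk them into fixed-size batches.
        sorted_idx = sorted(range(len(lengths)), key=lambda i: lengths[i])
        self.batches = [
            sorted_idx[i:i + batch_size]
            for i in range(0, len(sorted_idx), batch_size)
        ]

    def __iter__(self):
        random.shuffle(self.batches)  # shuffle batch order each epoch
        yield from self.batches

    def __len__(self):
        # Buggy version (the cause of this issue): reporting the dataset size
        # makes the trainer expect far more batches than the iterator yields,
        # so the end-of-epoch / "last batch" checks never line up.
        # return len(self.lengths)

        # Correct version: report the number of batches actually yielded.
        return len(self.batches)
```

Passed as `batch_sampler=BucketBatchSampler(lengths, batch_size)` to the `DataLoader`, the corrected `__len__` makes `len(dataloader)` agree with the number of batches the loader actually produces.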
Agree that it is a design issue rather than an implementation one.
@Borda Forgot to address this; it got a little late. Now that I know what's wrong, I have an idea for the nice-to-have warning I described above. I'll also add an example to the tests and issue a PR. Would you prefer this after 1.0 stable is released? I'll probably need a week anyway, given other priorities.
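For what it's worth, a minimal sketch of the kind of sanity check such a warning could be built on; this is just an illustration of the idea, not Lightning's actual implementation, and the function name and placement are assumptions.

```python
import warnings


def warn_if_loader_exhausted_early(dataloader, num_batches_seen):
    """Warn when a dataloader yields fewer batches than its reported length.

    A mismatch usually means a custom (batch) sampler's __len__ is wrong,
    which silently breaks checks that rely on the expected batch count.
    """
    expected = len(dataloader)
    if num_batches_seen < expected:
        warnings.warn(
            f"Train dataloader reported {expected} batches but only yielded "
            f"{num_batches_seen}; check your sampler's __len__.",
            RuntimeWarning,
        )
```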
🐛 Bug
(Preface: I created a complete minimal example for this bug report that unfortunately didn't end up reproducing the behavior, but I still think it is useful to mention nevertheless.)
The symptom is that when I leave everything else the same but increase the number of training batches from 1000 to 5000, both checkpointing and early stopping completely fail to work correctly. As verified by creating a minimal example with a different, simpler model, it's not so much the number of batches itself; it may instead be related to the time it takes for an epoch to run. Here is a more detailed description:
Setup
In the LightningModule:
By construction, epoch 5 should be the best model, and early stopping should trigger on epoch 8.
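The module code from the original report did not survive extraction, so the following is only a hedged sketch of a LightningModule consistent with the description: the validation loss follows a fixed per-epoch schedule. The values at epochs 0, 3, and 5 (40.0, 10.0, 0.90) are taken from the numbers mentioned below; the remaining values and the model itself are assumptions.

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class ScheduledValLossModel(pl.LightningModule):
    """Toy module whose validation loss follows a fixed per-epoch schedule."""

    # Best loss (0.90) at epoch 5; no improvement afterwards, so early
    # stopping with patience=3 should trigger at epoch 8.
    VAL_LOSS_SCHEDULE = [40.0, 30.0, 20.0, 10.0, 5.0, 0.90, 2.0, 3.0, 4.0, 5.0]

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self(x), y)

    def validation_step(self, batch, batch_idx):
        epoch = min(self.current_epoch, len(self.VAL_LOSS_SCHEDULE) - 1)
        # Log a scripted value so checkpointing / early stopping behavior
        # is fully predictable.
        self.log("val_loss", torch.tensor(self.VAL_LOSS_SCHEDULE[epoch]))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)
```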
Experiment setup:
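Likewise, the exact trainer configuration is not in the extracted text. The sketch below (using the callback API of more recent Lightning releases rather than the 0.9.1 arguments) shows a setup that would produce the checkpoint names and early-stopping behavior described under Behavior; the `patience=3` value, the `{epoch}-{val_loss:.2f}` filename format, and the dataset sizes are assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Dummy data; the real experiment used a much larger training set.
train_data = TensorDataset(torch.randn(200_000, 32), torch.randn(200_000, 1))
val_data = TensorDataset(torch.randn(640, 32), torch.randn(640, 1))

checkpoint_cb = ModelCheckpoint(
    monitor="val_loss",
    filename="{epoch}-{val_loss:.2f}",  # e.g. epoch=5-val_loss=0.90.ckpt
    save_top_k=1,
)
early_stop_cb = EarlyStopping(monitor="val_loss", patience=3)

trainer = Trainer(
    max_epochs=20,
    limit_train_batches=5000,  # 10-1000 behaves correctly; 5000 shows the bug
    callbacks=[checkpoint_cb, early_stop_cb],
)
trainer.fit(
    ScheduledValLossModel(),  # the sketch module from the Setup section above
    DataLoader(train_data, batch_size=32),
    DataLoader(val_data, batch_size=32),
)
```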
Behavior
Run 1 - All correct.
Okay, so with that, if I run with anywhere between 10 and 1000 training batches, things work perfectly:
The checkpoint updates correctly on disk during training, and the last one is correctly named epoch=5-val_loss=0.90.ckpt.
Run 2 - Problematic.
If I increase the number of training batches to 5000, early stopping is triggered after the first 3 epochs, the checkpoints are not created live (only after early stopping has ended), and the last checkpoint has the completely incorrect name epoch=3-val_loss=40.00.ckpt. Note that the epoch 3 val loss should by construction be 10, while the checkpoint has picked up the epoch 0 loss.
To Reproduce / Code Sample
I do have an isolated minimal code sample, but unfortunately it works as expected even with 15000 training batches. It uses a much simpler model, since it's a variation of the LitModel, and I've tried to keep the tensor dimensions similar to my actual problem, so I don't know where the problem lies right now. Here is the minimal code sample nevertheless:
https://gist.github.com/chiragraman/16b1a89787df0c517b8dfffae5c3d591
Expected behavior
The expected behavior in Run 2 above is to match the behavior in Run 1.
Environment
- CUDA:
  - GPU:
    - Quadro P4000
  - available: True
  - version: 10.2
- Packages:
  - numpy: 1.19.1
  - pyTorch_debug: False
  - pyTorch_version: 1.6.0
  - pytorch-lightning: 0.9.1rc4
  - tqdm: 4.48.2
- System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.8.5
  - version: #113-Ubuntu SMP Thu Jul 9 23:41:39 UTC 2020
Additional context
Also, a minor side note: I'm getting a warning about
UserWarning: The validation_epoch_end should not return anything as of 9.1. To log, use self.log(...) or self.write(...) directly in the LightningModule
when I haven't implemented validation_epoch_end at all, and am not returning anything from validation_step.
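For completeness, the logging pattern that warning recommends looks roughly like the sketch below (assuming a module that computes a real validation loss); with self.log in validation_step there is no need to implement validation_epoch_end or return anything.

```python
def validation_step(self, batch, batch_idx):
    x, y = batch
    loss = torch.nn.functional.mse_loss(self(x), y)
    # Log directly; ModelCheckpoint and EarlyStopping can monitor "val_loss".
    self.log("val_loss", loss, prog_bar=True)
```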