Unexpected Behaviour with Model Checkpointing and val_check_interval #1764
Comments
Hi! Thanks for your contribution, great first issue!
@yakobyd mind sending a PR with the test case? 🐰
This bug is just driving me crazy.
pytorch-lightning is great. I kindly suggest that the dev team slow down on implementing new features and first fix frustrating bugs like this one.
Which version? And can you put up a Colab that replicates this?
We still have limited resources to fix everything for a particular 0.x release... also, since this is still a 0-based version rather than a stable 1.x, some rough edges are to be expected... 🐰
By the way, we are weeks away from 1.0.0. Some of these bugs are being caused by refactors that we're making for that, but this needs to be tested to make sure we have no regressions in 1.0.
Thank you for the response. I am using version 0.7.6. The code:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# hparams is the experiment's hyperparameter namespace, defined elsewhere
checkpoint_callback = ModelCheckpoint(
    filepath=hparams.exp_home + '{epoch}-{val_loss:.2f}',
    save_top_k=-1,
    verbose=True,
    monitor='val_loss',
    mode='min',
    prefix=hparams.exp_name
)

# configure trainer
trainer = Trainer(
    # fast_dev_run=True,
    gpus=hparams.gpus if hparams.gpus is not None and len(hparams.gpus) > 0 else None,
    log_gpu_memory='all',
    accumulate_grad_batches=hparams.accu_grad,
    default_root_dir=hparams.exp_home,
    checkpoint_callback=checkpoint_callback,
    val_check_interval=5000,
    max_nb_epochs=hparams.max_nb_epochs,
)

This code is expected to dump a checkpoint every 5000 steps, but currently it only saves at the 5000th step of each epoch, same as @yakobyd described. I am looking into this.
The bug comes from here: if we set val_check_interval to save several checkpoints in an epoch, ...
@magic282 can you check if the bug persists on master?
@edenlightning @magic282 it is still there, and it is there on purpose...
This is still happening in:
Relevant yaml config:
See how after 0.50505 none of the subsequent better losses are recorded:
🐛 Bug
Currently, if we set val_check_interval in the Trainer flags, model checkpointing happens in the middle of an epoch. We expect it to always happen at the end of an epoch, as explained below.

To Reproduce
When running the code sample below, we are presented with the following output:
Here, we have val_check_interval == 0.1. As one can see, the checkpointing happens at the beginning of an epoch, after the first validation check. The 9 validation checks that follow do not trigger model checkpointing. We point to the places from which this behavior emerges in the section "Additional context". Moreover, we use the section "Expected behavior" to explain why model checkpointing should happen at the end of an epoch (at least, as the default behavior).

Code sample
The Model is not important here; we simply chose a minimal one. Please focus on the trainer flags.
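Since the exact code sample is not shown here, the following is a minimal sketch of the kind of reproduction described, assuming pytorch-lightning ~0.7.x APIs; the model, the random data, and the concrete values are illustrative placeholders, not the original script:

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

class Model(pl.LightningModule):
    # Minimal regression model; the architecture is irrelevant to the bug.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': torch.nn.functional.mse_loss(self(x), y)}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {'val_loss': torch.nn.functional.mse_loss(self(x), y)}

    def validation_epoch_end(self, outputs):
        # Expose 'val_loss' so the ModelCheckpoint callback can monitor it.
        return {'val_loss': torch.stack([o['val_loss'] for o in outputs]).mean()}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

def make_loader():
    x, y = torch.randn(100, 8), torch.randn(100, 1)
    return DataLoader(TensorDataset(x, y), batch_size=10)

checkpoint_callback = ModelCheckpoint(monitor='val_loss', save_top_k=-1, verbose=True)

trainer = pl.Trainer(
    max_epochs=3,
    val_check_interval=0.1,  # run validation 10 times per epoch
    checkpoint_callback=checkpoint_callback,
)
trainer.fit(Model(), train_dataloader=make_loader(), val_dataloaders=make_loader())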
Expected behavior

Assume the following scenario. Our model has val_check_interval == 0.5. We have a dataset with 100 samples that can have one of 2 balanced labels (50 "0"s, 50 "1"s). After the first epoch, our model has "seen" each label 50 times. Now, assume that model checkpointing is triggered for the second epoch. In the worst case (we shuffle the samples), all 50 train samples that we saw in this epoch (recall that checkpointing occurs mid-epoch, before seeing the remaining 50 samples of this epoch) had the label "0". Namely, we save a model that saw 100 "0"s and 50 "1"s.

Currently, in my research, we study biased sub-datasets and we rely on the fact that our original dataset has balanced labels. Consequently, we create biased and unbiased models whose performance is compared on each of the labels. The behavior described above does not ensure that we have a balanced model when training on a balanced dataset, hurting our research assumptions and affecting our use-case.
Some may argue that this issue represents an enhancement proposal rather than a bug. However, I believe that this issue should be classified as a "bug", for several reasons. Similar to the paragraph above, the described behavior also affects users who rely on "balanced features" or users trying to estimate "feature importance", thus affecting a wider range of researchers. In addition, I have not encountered this behavior in the documentation, and it may be overlooked by users who run their checkpoint callback with verbose == False.

In summary, following the reasoning given above, I expect model checkpointing to always occur at the end of an epoch.
Additional context
I believe that I have located the source of this behavior:
https://github.com/PyTorchLightning/pytorch-lightning/blob/3a642601e84c3abf1f1b438f9acc932a1f150f7f/pytorch_lightning/trainer/training_loop.py#L436-L453
For example, if val_check_interval == 0.1, the variable should_check_val is True 10 times during an epoch and self.call_checkpoint_callback() is called.

Then, the condition in the following if-statement is not met at the first validation run of each epoch, so the checkpoint is saved. Afterwards, the 9 validation runs that follow enter this if-statement and skip saving the model.
https://github.com/PyTorchLightning/pytorch-lightning/blob/25bbd059df68abc1b0ffa77ad2480af183d61b05/pytorch_lightning/callbacks/model_checkpoint.py#L210-L212
(As a side note, removing this line would result in the behavior requested in #1758, but obviously that belongs to a completely different discussion.)
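For readers without the permalinks at hand, the interaction described above boils down to logic of roughly the following shape. This is a paraphrased sketch, not the actual library source; the attribute and method names are approximations:

class ModelCheckpointSketch:
    # Paraphrase of the guard discussed above (model_checkpoint.py#L210-L212).
    def __init__(self, period=1):
        self.period = period
        self.epoch_last_check = None  # epoch in which a checkpoint was last saved

    def on_validation_end(self, epoch, logs):
        # First validation run of an epoch: the guard is not triggered, so we save.
        # The later validation runs of the same epoch satisfy
        # (epoch - epoch_last_check) < period and return early, skipping the save.
        if self.epoch_last_check is not None and (epoch - self.epoch_last_check) < self.period:
            return
        self.epoch_last_check = epoch
        self._save_checkpoint(epoch, logs)

    def _save_checkpoint(self, epoch, logs):
        print(f"saving checkpoint for epoch {epoch}: {logs}")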
Environment
PyTorch version: 1.4.0
Installation: conda
pytorch-lightning version: 0.7.5
Python version: 3.6.10