
Unexpected Behaviour with Model Checkpointing and val_check_interval #1764

Closed

yakobyd opened this issue May 9, 2020 · 12 comments · Fixed by #3807 or #3852

Labels: bug (Something isn't working), checkpointing (Related to checkpointing), help wanted (Open to be worked on), priority: 0 (High priority task)

Comments

yakobyd commented May 9, 2020

🐛 Bug

Currently, if we set val_check_interval in the Trainer flags so that validation runs more than once per epoch, model checkpointing happens in the middle of an epoch. We expect it to always happen at the end of an epoch, as explained below.

To Reproduce

When running the code sample below, we are presented with the following output:

Epoch 1:  10%|█████▊                                                    | 500/5005 [00:02<00:22, 200.49it/s, loss=1.331]
INFO:lightning:                                                                                                         
Epoch 00000: val_loss reached 1.39875 (best 1.39875), saving model to checkpoints/_ckpt_epoch_0.ckpt as top 1
Epoch 2:  10%|█████▊                                                    | 500/5005 [00:02<00:22, 202.59it/s, loss=1.279]
INFO:lightning:                                                                                                         
Epoch 00001: val_loss reached 1.34045 (best 1.34045), saving model to checkpoints/_ckpt_epoch_1.ckpt as top 1
Epoch 3:  10%|█████▊                                                    | 500/5005 [00:02<00:21, 207.04it/s, loss=1.275]
INFO:lightning:                                                                                                         
Epoch 00002: val_loss reached 1.32092 (best 1.32092), saving model to checkpoints/_ckpt_epoch_2.ckpt as top 1
Epoch 4:  10%|█████▊                                                    | 500/5005 [00:02<00:26, 170.37it/s, loss=1.281]
INFO:lightning:                                                                                                         
Epoch 00003: val_loss  was not in top 1
Epoch 5:  10%|█████▊                                                    | 500/5005 [00:02<00:21, 206.15it/s, loss=1.285]
INFO:lightning:                                                                                                         
Epoch 00004: val_loss  was not in top 1

Here, we have val_check_interval == 0.1. As one can see, checkpointing happens at the beginning of an epoch, after the first validation check. The 9 validation checks that follow do not trigger model checkpointing. We point to the places from which this behavior emerges in the section "Additional context". Moreover, we use the section "Expected behavior" to explain why model checkpointing should happen at the end of an epoch (at least as the default behavior).
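
As a rough sanity check of the "500/5005" positions in the log (assuming the progress bar total counts both training and validation batches, with batch_size=32 on MNIST), the numbers are consistent with ten validation checks per epoch:

# Back-of-the-envelope check, not Lightning code; assumptions stated above.
train_batches = 60000 // 32              # 1875 MNIST training batches
val_batches = -(-10000 // 32)            # 313 MNIST validation batches (ceiling division)
total_bar = train_batches + 10 * val_batches
first_check = int(train_batches * 0.1) + val_batches
print(total_bar, first_check)            # 5005 500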

Code sample

The model is not important here; we simply chose a minimal one. Please focus on the Trainer flags.

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import pytorch_lightning as pl


class Model(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': val_loss_mean}

    def train_dataloader(self):
        return DataLoader(datasets.MNIST('mnist/', train=True, download=True,
                          transform=transforms.ToTensor()), batch_size=32)

    def val_dataloader(self):
        return DataLoader(datasets.MNIST('mnist/', train=False, download=True,
                          transform=transforms.ToTensor()), batch_size=32)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


if __name__ == '__main__':

    checkpoint_callback = pl.callbacks.ModelCheckpoint('checkpoints/', verbose=True)

    trainer = pl.Trainer(gpus=[0], val_check_interval=0.1,
                         checkpoint_callback=checkpoint_callback)
    model = Model()

    trainer.fit(model)

Expected behavior

Assume the following scenario: val_check_interval == 0.5, and a dataset of 100 samples with two balanced labels (50 "0"s and 50 "1"s). After the first epoch, our model has "seen" each label 50 times. Now assume that model checkpointing is triggered during the second epoch. In the worst case (the samples are shuffled), all 50 training samples seen so far in this epoch (recall that the checkpoint occurs mid-epoch, before the remaining 50 samples of this epoch are seen) had the label "0". In other words, we save a model that has seen 100 "0"s but only 50 "1"s.
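
To make the counting concrete, here is a toy sketch of that worst case (purely illustrative Python, using the numbers assumed in the scenario above):

# 100 samples, balanced labels, val_check_interval == 0.5
epoch_1 = {'0': 50, '1': 50}             # full first epoch: every sample seen once
epoch_2_first_half = {'0': 50, '1': 0}   # worst-case shuffle: all "0"s come first

# label counts at the mid-epoch checkpoint, before the remaining 50 samples are seen
seen_at_checkpoint = {label: epoch_1[label] + epoch_2_first_half[label]
                      for label in epoch_1}
print(seen_at_checkpoint)                # {'0': 100, '1': 50}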

In my current research, we study biased sub-datasets and rely on the fact that the original dataset has balanced labels. We create biased and unbiased models and compare their performance on each label. The behavior described above does not guarantee a balanced model when training on a balanced dataset, which undermines our research assumptions and affects our use case.

Some may argue that this issue represents an enhancement proposal rather than a bug. However, I believe it should be classified as a bug, for several reasons. As in the paragraph above, the described behavior also affects users who rely on "balanced features" or who try to estimate "feature importance", so it touches a wider range of researchers. In addition, I have not encountered this behavior in the documentation, and it may be overlooked by users who run their checkpoint callback with verbose == False.

In summary, following the reasoning given above, I expect model checkpointing to always occur at the end of an epoch.

Additional context

I believe that I have located the source of this behavior:

https://github.com/PyTorchLightning/pytorch-lightning/blob/3a642601e84c3abf1f1b438f9acc932a1f150f7f/pytorch_lightning/trainer/training_loop.py#L436-L453

For example, if val_check_interval == 0.1, the variable should_check_val is True 10 times during an epoch, and self.call_checkpoint_callback() is called each time.

The condition in the following if-statement is False on the first validation run of each epoch, so the checkpoint is saved. On the 9 validation runs that follow, the condition is True, the callback returns early, and the model is not saved.

https://github.com/PyTorchLightning/pytorch-lightning/blob/25bbd059df68abc1b0ffa77ad2480af183d61b05/pytorch_lightning/callbacks/model_checkpoint.py#L210-L212
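
To illustrate, here is a small self-contained simulation of that guard (a paraphrase of the linked logic, not the actual Lightning source):

class CheckpointGuard:
    """Mimics the period / epoch_last_check check linked above."""

    def __init__(self, period=1):
        self.period = period
        self.epoch_last_check = None

    def should_save(self, epoch):
        # Skip if a checkpoint was already considered less than `period` epochs ago.
        if self.epoch_last_check is not None and (epoch - self.epoch_last_check) < self.period:
            return False
        self.epoch_last_check = epoch
        return True


guard = CheckpointGuard(period=1)
# Ten validation checks in epoch 0, then ten in epoch 1 (val_check_interval == 0.1):
print([guard.should_save(0) for _ in range(10)])  # [True, False, False, ..., False]
print([guard.should_save(1) for _ in range(10)])  # [True, False, False, ..., False]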

(As a side note, removing this line would result in the behavior requested in #1758, but that obviously belongs to a completely different discussion.)

Environment

  • OS: Linux
  • PyTorch Version: 1.4.0
  • How you installed PyTorch: conda
  • PyTorch Lightning Version: 0.7.5
  • Python version: 3.6.10
yakobyd added the bug (Something isn't working) and help wanted (Open to be worked on) labels on May 9, 2020

github-actions bot commented May 9, 2020

Hi! Thanks for your contribution, great first issue!

magic282 commented Jul 2, 2020

This bug is just making me crazy.

magic282 commented Jul 2, 2020

pytorch-lightning is great. I kindly advise that the dev team slow down on implementing new features and first fix frustrating bugs like this one.

@williamFalcon

Which version? And can you put up a Colab that replicates this?

Borda commented Jul 2, 2020

> pytorch-lightning is great. I kindly advise that the dev team slow down on implementing new features and first fix frustrating bugs like this one.

We still have limited resources to fix everything for a particular 0.X release... also, as this is still a 0.x version and not yet the stable 1.0, we hope this can be understood... 🐰
@magic282 mind taking this over and submitting a fix for this issue?

@williamFalcon

BTW, we are weeks away from 1.0.0. Some of these bugs are being caused by refactors we're making for that, but this needs to be tested to make sure we have no regressions in 1.0.

magic282 commented Jul 2, 2020

Thank you for the response. I am using version 0.7.6.

The code:

checkpoint_callback = ModelCheckpoint(
    filepath=hparams.exp_home + '{epoch}-{val_loss:.2f}',
    save_top_k=-1,
    verbose=True,
    monitor='val_loss',
    mode='min',
    prefix=hparams.exp_name
)

# configure trainer
trainer = Trainer(
    # fast_dev_run=True,
    gpus=hparams.gpus if hparams.gpus is not None and len(hparams.gpus) > 0 else None,
    log_gpu_memory='all',
    accumulate_grad_batches=hparams.accu_grad,
    default_root_dir=hparams.exp_home,
    checkpoint_callback=checkpoint_callback,
    val_check_interval=5000,
    max_nb_epochs=hparams.max_nb_epochs,
)

The code is expected to dump a checkpoint every 5000 steps, but currently it only saves at the 5000th step of each epoch, like

expepoch=0-val_loss=0.16.ckpt
expepoch=1-val_loss=0.13.ckpt

which is the same behavior @yakobyd described. I am looking into this.

magic282 commented Aug 3, 2020

The bug comes from here:
https://github.com/PyTorchLightning/pytorch-lightning/blob/471f2b80af9a9aeb11d5092ddbc08faef7901552/pytorch_lightning/callbacks/model_checkpoint.py#L299

If we set val_check_interval so that several checkpoints should be saved within an epoch, epoch and epoch_last_check will always be equal, so the function returns early. One workaround is to set period to a number <= 0. I guess we could set period in
https://github.com/PyTorchLightning/pytorch-lightning/blob/471f2b80af9a9aeb11d5092ddbc08faef7901552/pytorch_lightning/callbacks/model_checkpoint.py#L237
if the user does not set period manually, so that val_check_interval and period match.
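
A minimal sketch of that workaround, assuming the 0.7.x-era ModelCheckpoint signature in which period is a constructor argument (argument names differ in later releases):

from pytorch_lightning.callbacks import ModelCheckpoint

# With period=0 the "(epoch - epoch_last_check) < period" guard can no longer
# skip validation runs that happen later in the same epoch, so the callback is
# free to save on every validation check.
checkpoint_callback = ModelCheckpoint(
    filepath='checkpoints/{epoch}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=-1,
    verbose=True,
    period=0,  # workaround: allow saving more than once per epoch
)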

@edenlightning

@magic282 can you check if the bug persists on master?

jamarju commented Jun 28, 2021

This is still happening in:

>>> pl.__version__
'1.3.7post0'

Relevant yaml config:

trainer:
  val_check_interval: 0.2
  callbacks:
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        monitor: 'val_loss'
        verbose: true
        mode: 'min'
        filename: '04_{epoch}_{val_loss:.6f}'
        save_weights_only: true

See how after 0.50505 none of the subsequent better losses are recorded:

Epoch 0:  20%|████████▍                                 | 22/109 [00:06<00:27,  3.19it/s, loss=1.05, v_num=88, val_loss=1.030, lr_trf=7.38e-7, lr_clf=7.38e-7Epoch 0, global step 17: val_loss reached 1.03065 (best 1.03065), saving model to "/space/ml/commonlit/lightning_logs/version_88/checkpoints/04_epoch=0_val_loss=1.030650.ckpt" as top 1
Epoch 0:  40%|████████████████▌                        | 44/109 [00:14<00:20,  3.13it/s, loss=0.984, v_num=88, val_loss=0.961, lr_trf=1.78e-6, lr_clf=1.78e-6Epoch 0, global step 35: val_loss reached 0.96075 (best 0.96075), saving model to "/space/ml/commonlit/lightning_logs/version_88/checkpoints/04_epoch=0_val_loss=0.960752.ckpt" as top 1
Epoch 0:  61%|████████████████████████▊                | 66/109 [00:21<00:13,  3.10it/s, loss=0.954, v_num=88, val_loss=0.893, lr_trf=3.36e-6, lr_clf=3.36e-6Epoch 0, global step 53: val_loss reached 0.89280 (best 0.89280), saving model to "/space/ml/commonlit/lightning_logs/version_88/checkpoints/04_epoch=0_val_loss=0.892801.ckpt" as top 1
Epoch 0:  81%|█████████████████████████████████        | 88/109 [00:28<00:06,  3.09it/s, loss=0.724, v_num=88, val_loss=0.694, lr_trf=5.23e-6, lr_clf=5.23e-6Epoch 0, global step 71: val_loss reached 0.69430 (best 0.69430), saving model to "/space/ml/commonlit/lightning_logs/version_88/checkpoints/04_epoch=0_val_loss=0.694305.ckpt" as top 1
Epoch 0: 100%|████████████████████████████████████████| 109/109 [00:35<00:00,  3.06it/s, loss=0.597, v_num=88, val_loss=0.621, lr_trf=7.09e-6, lr_clf=7.09e-6Epoch 0, global step 89: val_loss reached 0.62113 (best 0.62113), saving model to "/space/ml/commonlit/lightning_logs/version_88/checkpoints/04_epoch=0_val_loss=0.621130.ckpt" as top 1
Epoch 1:  20%|████████▎                                | 22/109 [00:06<00:27,  3.20it/s, loss=0.537, v_num=88, val_loss=0.644, lr_trf=8.94e-6, lr_clf=8.94e-6Epoch 1, global step 111: val_loss was not in top 1
Epoch 1:  40%|████████████████▉                         | 44/109 [00:13<00:19,  3.29it/s, loss=0.52, v_num=88, val_loss=0.577, lr_trf=9.82e-6, lr_clf=9.82e-6Epoch 1, global step 129: val_loss reached 0.57692 (best 0.57692), saving model to "/space/ml/commonlit/lightning_logs/version_88/checkpoints/04_epoch=1_val_loss=0.576923.ckpt" as top 1
Epoch 1:  61%|████████████████████████▊                | 66/109 [00:20<00:13,  3.20it/s, loss=0.524, v_num=88, val_loss=0.621, lr_trf=9.99e-6, lr_clf=9.99e-6Epoch 1, global step 147: val_loss was not in top 1
Epoch 1:  81%|█████████████████████████████████        | 88/109 [00:27<00:06,  3.25it/s, loss=0.562, v_num=88, val_loss=0.594, lr_trf=9.88e-6, lr_clf=9.88e-6Epoch 1, global step 165: val_loss was not in top 1
Epoch 1: 100%|████████████████████████████████████████| 109/109 [00:33<00:00,  3.24it/s, loss=0.455, v_num=88, val_loss=0.545, lr_trf=9.62e-6, lr_clf=9.62e-6Epoch 1, global step 183: val_loss reached 0.54479 (best 0.54479), saving model to "/space/ml/commonlit/lightning_logs/version_88/checkpoints/04_epoch=1_val_loss=0.544790.ckpt" as top 1
Epoch 2:  20%|████████▎                                | 22/109 [00:06<00:27,  3.19it/s, loss=0.434, v_num=88, val_loss=0.592, lr_trf=9.13e-6, lr_clf=9.13e-6Epoch 2, global step 205: val_loss was not in top 1
Epoch 2:  40%|████████████████▌                        | 44/109 [00:13<00:19,  3.28it/s, loss=0.461, v_num=88, val_loss=0.531, lr_trf=8.59e-6, lr_clf=8.59e-6Epoch 2, global step 223: val_loss reached 0.53115 (best 0.53115), saving model to "/space/ml/commonlit/lightning_logs/version_88/checkpoints/04_epoch=2_val_loss=0.531149.ckpt" as top 1
Epoch 2:  61%|████████████████████████▊                | 66/109 [00:20<00:13,  3.20it/s, loss=0.389, v_num=88, val_loss=0.524, lr_trf=7.95e-6, lr_clf=7.95e-6Epoch 2, global step 241: val_loss reached 0.52437 (best 0.52437), saving model to "/space/ml/commonlit/lightning_logs/version_88/checkpoints/04_epoch=2_val_loss=0.524372.ckpt" as top 1
Epoch 2:  81%|█████████████████████████████████        | 88/109 [00:27<00:06,  3.17it/s, loss=0.392, v_num=88, val_loss=0.517, lr_trf=7.22e-6, lr_clf=7.22e-6Epoch 2, global step 259: val_loss was not in top 1
Epoch 2: 100%|████████████████████████████████████████| 109/109 [00:34<00:00,  3.18it/s, loss=0.437, v_num=88, val_loss=0.535, lr_trf=6.43e-6, lr_clf=6.43e-6Epoch 2, global step 277: val_loss was not in top 1
Epoch 3:  20%|████████▎                                | 22/109 [00:06<00:27,  3.19it/s, loss=0.317, v_num=88, val_loss=0.518, lr_trf=5.41e-6, lr_clf=5.41e-6Epoch 3, global step 299: val_loss reached 0.51775 (best 0.51775), saving model to "/space/ml/commonlit/lightning_logs/version_88/checkpoints/04_epoch=3_val_loss=0.517754.ckpt" as top 1
Epoch 3:  40%|████████████████▉                         | 44/109 [00:14<00:20,  3.12it/s, loss=0.36, v_num=88, val_loss=0.505, lr_trf=4.56e-6, lr_clf=4.56e-6Epoch 3, global step 317: val_loss reached 0.50505 (best 0.50505), saving model to "/space/ml/commonlit/lightning_logs/version_88/checkpoints/04_epoch=3_val_loss=0.505046.ckpt" as top 1
Epoch 3:  61%|████████████████████████▊                | 66/109 [00:21<00:13,  3.10it/s, loss=0.313, v_num=88, val_loss=0.520, lr_trf=3.73e-6, lr_clf=3.73e-6Epoch 3, global step 335: val_loss was not in top 1
Epoch 3:  81%|█████████████████████████████████        | 88/109 [00:27<00:06,  3.17it/s, loss=0.315, v_num=88, val_loss=0.517, lr_trf=2.93e-6, lr_clf=2.93e-6Epoch 3, global step 353: val_loss was not in top 1
Epoch 3: 100%|████████████████████████████████████████| 109/109 [00:34<00:00,  3.18it/s, loss=0.336, v_num=88, val_loss=0.516, lr_trf=2.19e-6, lr_clf=2.19e-6Epoch 3, global step 371: val_loss was not in top 1
Epoch 4:  20%|████████▎                                | 22/109 [00:06<00:27,  3.19it/s, loss=0.272, v_num=88, val_loss=0.504, lr_trf=1.39e-6, lr_clf=1.39e-6Epoch 4, global step 393: val_loss was not in top 1
Epoch 4:  40%|█████████████████▎                         | 44/109 [00:13<00:19,  3.28it/s, loss=0.262, v_num=88, val_loss=0.503, lr_trf=8.6e-7, lr_clf=8.6e-7Epoch 4, global step 411: val_loss was not in top 1
Epoch 4:  61%|████████████████████████▊                | 66/109 [00:19<00:12,  3.31it/s, loss=0.287, v_num=88, val_loss=0.504, lr_trf=4.45e-7, lr_clf=4.45e-7Epoch 4, global step 429: val_loss was not in top 1
Epoch 4:  81%|█████████████████████████████████        | 88/109 [00:26<00:06,  3.33it/s, loss=0.267, v_num=88, val_loss=0.504, lr_trf=1.62e-7, lr_clf=1.62e-7Epoch 4, global step 447: val_loss was not in top 1
Epoch 4: 100%|████████████████████████████████████████| 109/109 [00:32<00:00,  3.31it/s, loss=0.291, v_num=88, val_loss=0.503, lr_trf=1.81e-8, lr_clf=1.81e-8Epoch 4, global step 465: val_loss was not in top 1
Epoch 4: 100%|████████████████████████████████████████| 109/109 [00:34<00:00,  3.18it/s, loss=0.283, v_num=88, val_loss=0.503, lr_trf=8.07e-9, lr_clf=8.07e-9]Epoch 4, global step 469: val_loss was not in top 1
Epoch 4: 100%|████████████████████████████████████████| 109/109 [00:34<00:00,  3.18it/s, loss=0.283, v_num=88, val_loss=0.503, lr_trf=8.07e-9, lr_clf=8.07e-9]
