Checkpointing and Early Stopping fail to work correctly when increasing number of train batches (in some cases) #3789

Closed
chiragraman opened this issue Oct 2, 2020 · 8 comments · Fixed by #3807
Labels
bug (Something isn't working), help wanted (Open to be worked on), priority: 0 (High priority task)

chiragraman (Contributor) commented Oct 2, 2020

🐛 Bug

(Preface: I created a complete minimal example for this bug report that unfortunately didn't end up reproducing the behavior, but I still think it is useful to mention.)

The symptom: when I leave everything else the same but increase the number of training batches from 1000 to 5000, both checkpointing and early stopping fail to work correctly. A minimal example with a different, simpler model suggests it's not so much the number of batches, but perhaps something related to how long an epoch takes to run. Here is a more detailed description:

Setup

In the LightningModule:

    def training_step(self, batch, _) -> Tensor:
        """ Perform a single step in the training loop """
        loss, nll = self.shared_step(batch, self.hparams.teacher_forcing)
        loss_with_reg = loss + self.reg(self.process)
        logs = {"loss_no_reg": loss, "loss_with_reg": loss_with_reg,
                "nll": nll}

        self.log_dict(logs, on_epoch=True)
        return loss_with_reg

    def validation_step(self, batch: types.DataSplit, batch_idx) -> None:
        """ Perform an evaluation step """
        nll = torch.tensor(float(0)).to(batch.context.device)
        # Hard-coded val_loss schedule so the best epoch is known in advance.
        losses = [40, 20, 30, 10, 1, 0.9, 1, 1, 90, 100]
        loss = torch.tensor(float(losses[self.current_epoch])).to(batch.context.device)
        logs = {"val_loss": loss, "val_nll": nll}
        self.log_dict(logs)

By construction, epoch 5 should be the best model (val_loss = 0.9), and with the default EarlyStopping patience of 3 there is no improvement over epochs 6–8, so early stopping should trigger on epoch 8.

Experiment setup:

outroot = Path(args.out_dir)
logger = TestTubeLogger(save_dir=str(outroot / "logs"))

ckpt_filepath = "{}/{{epoch}}-{{val_loss:.2f}}".format(
        str(outroot / "logs" / "checkpoints"))
checkpoint_callback = ModelCheckpoint(
    filepath=ckpt_filepath, save_top_k=1, monitor="val_loss",
    verbose=True
)

early_stop = EarlyStopping(monitor="val_loss", verbose=True)

trainer = Trainer.from_argparse_args(
    args, logger=logger, checkpoint_callback=checkpoint_callback,
    early_stop_callback=early_stop
)

trainer.fit(model, datamodule=dm)

Behavior

Run 1 - All correct.
Okay, so with that, if I run with anywhere between 10 and 1000 training batches, things work perfectly:

❯ python -m run.run_synthetic_social --gpus 1 ... --max_epochs 10  --limit_train_batches 10 --limit_val_batches 5
...
/home/chirag/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: The validation_epoch_end should not return anything as of 9.1. To log, use self.log(...) or self.write(...) directly in the LightningModule
  warnings.warn(*args, **kwargs)
Epoch 0:  73%|██████████████████████        | 11/15 [00:00<00:00, 12.14it/s, loss=443490368.000, v_num=0]Epoch 0: val_loss reached 40.00000 (best 40.00000), saving model to /home/chirag/Projects/mingle/social-processes/artefacts/exp/dev_run/logs/checkpoints/epoch=0-val_loss=40.00.ckpt as top 1
Epoch 1:  80%|████████████████████████      | 12/15 [00:01<00:00, 11.90it/s, loss=372233728.000, v_num=0]Epoch 1: val_loss reached 20.00000 (best 20.00000), saving model to /home/chirag/Projects/mingle/social-processes/artefacts/exp/dev_run/logs/checkpoints/epoch=1-val_loss=20.00.ckpt as top 1
Epoch 2:  80%|████████████████████████      | 12/15 [00:01<00:00, 10.89it/s, loss=197302992.000, v_num=0]Epoch 2: val_loss was not in top 1                                         | 1/5 [00:00<00:00,  7.89it/s]
Epoch 3:  80%|████████████████████████      | 12/15 [00:01<00:00, 11.46it/s, loss=108065304.000, v_num=0]Epoch 3: val_loss reached 10.00000 (best 10.00000), saving model to /home/chirag/Projects/mingle/social-processes/artefacts/exp/dev_run/logs/checkpoints/epoch=3-val_loss=10.00.ckpt as top 1
Epoch 4:  80%|████████████████████████      | 12/15 [00:00<00:00, 12.58it/s, loss=104392560.000, v_num=0]Epoch 4: val_loss reached 1.00000 (best 1.00000), saving model to /home/chirag/Projects/mingle/social-processes/artefacts/exp/dev_run/logs/checkpoints/epoch=4-val_loss=1.00.ckpt as top 1
Epoch 5:  80%|████████████████████████▊      | 12/15 [00:00<00:00, 13.81it/s, loss=64182412.000, v_num=0]Epoch 5: val_loss reached 0.90000 (best 0.90000), saving model to /home/chirag/Projects/mingle/social-processes/artefacts/exp/dev_run/logs/checkpoints/epoch=5-val_loss=0.90.ckpt as top 1
Epoch 6:  80%|████████████████████████▊      | 12/15 [00:01<00:00, 11.82it/s, loss=65260504.000, v_num=0]Epoch 6: val_loss was not in top 1                                         | 1/5 [00:00<00:00,  6.81it/s]
Epoch 7:  80%|████████████████████████      | 12/15 [00:01<00:00, 11.47it/s, loss=105555992.000, v_num=0]Epoch 7: val_loss was not in top 1                                         | 1/5 [00:00<00:00,  7.17it/s]
Epoch 8:  80%|████████████████████████      | 12/15 [00:01<00:00, 11.64it/s, loss=113607824.000, v_num=0]Epoch 8: val_loss was not in top 1                                         | 1/5 [00:00<00:00,  6.88it/s]
Epoch 8: 100%|██████████████████████████████| 15/15 [00:01<00:00, 13.69it/s, loss=113607824.000, v_num=0Epoch 00009: early stopping triggered.                                                                    
Epoch 8: 100%|██████████████████████████████| 15/15 [00:01<00:00, 13.17it/s, loss=113607824.000, v_num=0]

The checkpoint updates correctly on disk during training, and the last one is correctly named epoch=5-val_loss=0.90.ckpt.

Run 2 - Problematic.
If I increase the number of training batches to 5000, early stopping triggers after the first 3 epochs, the checkpoint is not created live (only after early stopping has fired), and it has the completely incorrect name epoch=3-val_loss=40.00.ckpt. Note that the epoch 3 val loss should be 10 by construction, yet it has picked up the epoch 0 loss.

❯ python -m run.run_synthetic_social --gpus 1 ... --max_epochs 10  --limit_train_batches 5000 --limit_val_batches 5
...
/home/chirag/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: The validation_epoch_end should not return anything as of 9.1. To log, use self.log(...) or self.write(...) directly in the LightningModule
  warnings.warn(*args, **kwargs)
Epoch 3:  20%|█████▉                       | 1026/5005 [01:22<05:18, 12.49it/s, loss=341889.875, v_num=0]Epoch 3: val_loss reached 40.00000 (best 40.00000), saving model to /home/chirag/Projects/mingle/social-processes/artefacts/exp/dev_run/logs/checkpoints/epoch=3-val_loss=40.00.ckpt as top 1
Epoch 00004: early stopping triggered.
Epoch 3:  20%|█████▉                       | 1026/5005 [01:22<05:19, 12.44it/s, loss=341889.875, v_num=0]

To Reproduce / Code Sample

I do have an isolated minimal code sample, but unfortunately it works as expected, even with 15000 training batches. It's a much simpler model, since I'm using a variation of the LitModel, and I've tried to keep the tensor dimensions similar to my actual problem, so I don't know where the problem is right now. Here is the minimal code sample nevertheless:
https://gist.github.com/chiragraman/16b1a89787df0c517b8dfffae5c3d591

Expected behavior

The expected behavior in Run 2 above is to match the behavior in Run 1.

Environment

  • CUDA:
    - GPU: Quadro P4000
    - available: True
    - version: 10.2
  • Packages:
    - numpy: 1.19.1
    - pyTorch_debug: False
    - pyTorch_version: 1.6.0
    - pytorch-lightning: 0.9.1rc4
    - tqdm: 4.48.2
  • System:
    - OS: Linux
    - architecture: 64bit, ELF
    - processor: x86_64
    - python: 3.8.5
    - version: #113-Ubuntu SMP Thu Jul 9 23:41:39 UTC 2020

Additional context

Also, a minor side note: I'm getting the warning UserWarning: The validation_epoch_end should not return anything as of 9.1. To log, use self.log(...) or self.write(...) directly in the LightningModule even though I haven't implemented validation_epoch_end at all and am not returning anything from validation_step.

chiragraman added the bug and help wanted labels on Oct 2, 2020
edenlightning (Contributor) commented:
@williamFalcon

chiragraman (Contributor, Author) commented:
Okay, so I tried a few more things:

  1. Changed the checkpoint callback to save every model using save_top_k=-1 (see the sketch after this list). Under Run 1 conditions this works fine. Under Run 2 conditions it doesn't save anything except what appears to be a value from the sanity-check validation. But, similar to what's reported above, the filename has the last epoch and the validation score from epoch 0. TensorBoard does have an entry for the validation loss, but it remains at 0 and does not update.

  2. I went back to the stable release through conda and changed my code back to using the ResultObject. The behavior is the same for both runs. TensorBoard has no entries at all for the validation loss in Run 2.
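
For reference, a minimal sketch of the checkpoint callback from point 1 (same setup as in the issue description, only save_top_k changed):

# Same ModelCheckpoint as in the issue description, but with save_top_k=-1
# so a checkpoint is kept for every epoch instead of only the best one.
checkpoint_callback = ModelCheckpoint(
    filepath=ckpt_filepath, save_top_k=-1, monitor="val_loss",
    verbose=True
)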


This leads me to believe the real issue is that the validation methods are perhaps not called properly in some cases, which might go beyond a simple logging or checkpointing problem. In that sense, it might be related to #3666, where the logging and checkpointing seem to be affecting each other.

chiragraman (Contributor, Author) commented:
WIP Update:

So far I've managed to isolate the issue to this snippet in the training loop, which indicates why the training-loop behavior is affecting the val loop:

https://github.com/PyTorchLightning/pytorch-lightning/blob/ac2b0f0f066967b5896a73591ce7aa25dfb75306/pytorch_lightning/trainer/training_loop.py#L536-L538

I logged the values of is_last_batch and batch_idx and found this weird behavior:

Epoch 0:  53%|███████████████████████████████████████████▌                                      | 268/505 [00:18<00:16, 14.17it/s, loss=18235956.000, v_num=1]IS LAST BATCH:  False batch_idx:  267
Epoch 0:  53%|███████████████████████████████████████████▋                                      | 269/505 [00:18<00:16, 14.18it/s, loss=18057206.000, v_num=1]IS LAST BATCH:  False batch_idx:  268
Epoch 0:  53%|███████████████████████████████████████████▊                                      | 270/505 [00:19<00:16, 14.14it/s, loss=16772282.000, v_num=1]IS LAST BATCH:  True batch_idx:  269
Epoch 1:   0%|▏                                                                                   | 1/505 [00:00<03:03,  2.75it/s, loss=16925516.000, v_num=1]

This was run with limit_train_batches=500, as indicated in the progress bar denominator. Yet is_last_batch is True at batch_idx 269. Consequently, it isn't invoking validation_step at all, which explains all the other errors above.
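
For intuition about how the flag can flip early, here is a rough sketch of a generic look-ahead pattern (not the library's actual code): if is_last is produced by peeking one element ahead in the dataloader iterator, it turns True as soon as the iterator is genuinely exhausted, regardless of how many batches were advertised.

def with_is_last(iterable):
    """Yield (item, is_last) pairs by looking one item ahead (illustrative sketch)."""
    it = iter(iterable)
    prev = next(it)
    for item in it:
        yield prev, False
        prev = item
    # The underlying iterator is exhausted here, whatever length it advertised.
    yield prev, True

# If the loader advertises 500 batches but its iterator only yields 270,
# is_last flips to True at batch_idx 269, exactly as in the trace above.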

More to come.

edenlightning added this to the 0.9.x milestone on Oct 2, 2020
edenlightning added the priority: 0 label on Oct 2, 2020
Borda (Member) commented Oct 2, 2020

@chiragraman it would be great to have such a minimal example, which we could also add to the tests...
also, it seems to be similar to #1764

chiragraman (Contributor, Author) commented Oct 2, 2020

@edenlightning (re: the priority label) I've almost finished following the function calls in the library code and haven't found anything off yet. Given that it isn't really reproducible in the minimal example, I'm inclined to think the problem lies in the iteration of the loader, probably in the sampling (in my user-level code). I'll update shortly, but in my humble opinion the dev team's effort might be better spent on other higher-priority tasks.

chiragraman (Contributor, Author) commented:
Figured out the root cause of this; tl;dr: I don't think it is a bug in the library code. At best the library could throw a warning to help with debugging user code, since it's been a long, painful journey from symptom to cause, but I don't think that's the library's obligation as such. There might be room to improve the check for whether the evaluation step should be run, but the current issue isn't caused by a bug.


The problem was that I was using a custom sampler to generate batches of sequences bucketed by length (see: https://discuss.pytorch.org/t/tensorflow-esque-bucket-by-sequence-length/41284). The sampler was incorrectly returning the length of the dataset rather than the number of batches, which in this case was 270. Consequently, the iterator would return the last batch at batch_idx 269 (the 270th batch), but the evaluation check would fail because num_training_samples reflected the number of sequences in the dataset rather than the number of batches in the loader.
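
A minimal sketch of the kind of bucketing batch sampler involved (illustrative, not my actual code); the bug lives entirely in __len__:

from torch.utils.data import Sampler

class BucketBatchSampler(Sampler):
    """Group indices of similar-length sequences into batches (illustrative sketch)."""

    def __init__(self, lengths, batch_size):
        # Sort indices by sequence length, then chunk them into batches.
        order = sorted(range(len(lengths)), key=lambda i: lengths[i])
        self.batches = [order[i:i + batch_size]
                        for i in range(0, len(order), batch_size)]
        self.num_sequences = len(lengths)

    def __iter__(self):
        return iter(self.batches)

    def __len__(self):
        # Buggy version: return self.num_sequences  (number of sequences in the dataset)
        # Correct version: the number of batches the iterator actually yields.
        return len(self.batches)

When this is passed to the DataLoader as batch_sampler, len(loader) is whatever __len__ returns, which is what the trainer uses to decide how many batches to expect per epoch.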

I think issuing a warning when the iterator has ended at a batch_idx less than the expected number of training samples would help in this case.
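
Something along these lines, a rough sketch of the suggested check (the helper name is hypothetical, not existing Lightning code):

import warnings

def warn_if_loader_exhausted_early(batches_seen: int, expected_batches: int) -> None:
    # Hypothetical helper sketching the suggested warning; not existing Lightning code.
    if batches_seen < expected_batches:
        warnings.warn(
            f"The train dataloader was exhausted after {batches_seen} batches, "
            f"but {expected_batches} were expected. If you use a custom sampler, "
            "check that its __len__ returns the number of batches it yields."
        )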


@Borda I think #1764 describes the problem as a bug in design rather than implementation. From walking through the code for the better part of a day, I think it does what it's implemented to do; if checkpointing is intended to happen at the end of the epoch, I think a few things would have to change other than should_check_val_fx. I'm also not sure #3807 would have much effect on the behavior described in this issue. The best thing to do to address the present issue, if anything at all, is to check whether the training loop is terminating before the expected number of samples has been received and throw a warning if so.

Borda (Member) commented Oct 3, 2020

think #1764 describes the problem as a bug in design rather than implementation

Agree that it is a design issue rather than an implementation one

chiragraman (Contributor, Author) commented:
@chiragraman it would be great to have such a minimal example, which we could also add to the tests...
also, it seems to be similar to #1764

@Borda Forgot to address this; it got a little late. Now that I know what's wrong, I have an idea for the nice-to-have warning I described above; I'll also add an example to the tests and open a PR. Would you prefer this after 1.0 stable is released? I'll probably need a week anyway, given other priorities.
