
Early stopping skips batches when min_epochs not reached #6699

Closed
gunthergl opened this issue Mar 27, 2021 · 3 comments · Fixed by #6705
Labels
bug Something isn't working help wanted Open to be worked on priority: 1 Medium priority task

Comments

@gunthergl commented Mar 27, 2021

🐛 Bug

When the EarlyStopping condition is met but min_epochs has not been reached yet, it seems that only the first batch of each epoch is used.

TL;DR: With 100 samples and a batch size of 25 (= 4 batches), training_step() should always be called exactly 4 times per epoch. Once EarlyStopping has triggered before min_epochs is reached, it is only called once per epoch.

Please reproduce using the BoringModel

To Reproduce

I used the original BoringModel template and changed:

  • num_samples = 100
  • batch_size=25 in train, val and test
  • Add print('train, batch=', batch_idx) to self.training_step()
  • Add EarlyStopping and min_epochs when initializing the trainer

Resulting in:
https://colab.research.google.com/drive/11tlIU9NusGPeXJLUKA52ECuhIXOCH_4k?usp=sharing

I added the relevant output when training as image here:
[screenshot of the relevant training output]

Expected behavior

In my BoringModel, self.training_step() should be called 4 times in each epoch as long as min_epochs has not been reached; otherwise I suspect the remaining 3 batches are not used to update the model parameters.
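
As a quick back-of-the-envelope check (a sketch only; the variable names are illustrative and the epoch count of 15 simply mirrors the min_epochs value used in the reproduction script below):

num_samples = 100
batch_size = 25
batches_per_epoch = num_samples // batch_size    # 4 batches per epoch

min_epochs = 15

# Expected: training_step() runs once per batch in every epoch,
# i.e. 4 * 15 = 60 calls up to min_epochs.
expected_calls = batches_per_epoch * min_epochs  # 60

# Observed: after EarlyStopping has triggered, only batch 0 is processed,
# so the actual call count is noticeably lower.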

Environment

  • CUDA:
    • GPU:
      • Tesla T4
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.19.5
    • pyTorch_debug: False
    • pyTorch_version: 1.8.0+cu101
    • pytorch-lightning: 1.2.5
    • tqdm: 4.41.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.7.10
    • version: #1 SMP Thu Jul 23 08:00:38 PDT 2020

Additional context

  • While looking into the source code during my analysis, I found that in pytorch_lightning/trainer/training_loop.py, TrainLoop.run_training_epoch() sees self.trainer.should_stop == True once EarlyStopping has triggered, and therefore breaks out of the epoch after the first batch (a simplified sketch of this control flow follows after this list).

  • As a result, it seems that once EarlyStopping has triggered before min_epochs is reached, training always stops at min_epochs, even if the stopping criterion is no longer met at that point. (I did not build an MWE for that, though.)

  • PS: Is it intended that BoringModel.forward() is not called in the template, but self.layer() instead?
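
A simplified sketch of the epoch-loop behaviour described in the first bullet (this is not the actual pytorch_lightning code; the function name and the batches argument are placeholders for illustration):

def run_training_epoch_sketch(model, trainer, batches):
    # Once EarlyStopping has set trainer.should_stop = True, the batch loop
    # breaks right after batch 0, so the remaining batches of the epoch are
    # skipped, while the trainer still starts new epochs until min_epochs
    # is reached.
    for batch_idx, batch in enumerate(batches):
        model.training_step(batch, batch_idx)
        if trainer.should_stop:
            break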

@gunthergl gunthergl added bug Something isn't working help wanted Open to be worked on labels Mar 27, 2021
@awaelchli awaelchli self-assigned this Mar 28, 2021
@awaelchli (Contributor) commented:

from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.callbacks import EarlyStopping
import torch
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    train_forward_counter = 0

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        print('train, batch=', batch_idx)
        output = self.layer(batch)
        loss = output.sum() * 0.0  # constant zero loss: the monitored metric never improves, so EarlyStopping triggers
        self.log("loss", loss)
        self.train_forward_counter += 1
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    model = BoringModel()
    num_samples = 100
    train = RandomDataset(32, num_samples)
    train = DataLoader(train, batch_size=25)

    early_stop_callback = EarlyStopping(monitor='loss', patience=1)
    trainer = Trainer(
        min_epochs=15,
        max_epochs=500,
        callbacks=[early_stop_callback]
    )
    trainer.fit(model, train)
    print(len(train))
    print(model.train_forward_counter)


if __name__ == '__main__':
    run()
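
For a quick check of the call count, one could add something like the following after trainer.fit() (a hedged helper, not part of the original script; num_epochs_run has to be supplied by whoever runs it, e.g. read off the progress bar):

def check_call_count(model, train_loader, num_epochs_run):
    # With the fix applied, training_step() should have run once per batch
    # in every epoch, so the counter should equal len(loader) * epochs.
    expected = len(train_loader) * num_epochs_run
    print("expected training_step calls:", expected)
    print("actual training_step calls:  ", model.train_forward_counter)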

@awaelchli (Contributor) commented:

@gunthergl Does this fix your issue: #6705?

@tchaton tchaton added the priority: 1 Medium priority task label Mar 29, 2021
@gunthergl (Author) commented:

Hi @awaelchli, yes this seems to fix it. Thank you for the fast fix!
