Latest Lightning does not support multiple callbacks that stop #6194

Closed
jlperla opened this issue Feb 25, 2021 · 1 comment · Fixed by #6197
Assignees: SeanNaren
Labels: bug (Something isn't working), help wanted (Open to be worked on), priority: 0 (High priority task)

Comments

jlperla commented Feb 25, 2021

🐛 Bug

In the latest version of Lightning, it does not seem possible to have multiple callbacks that can stop training.

Please reproduce using the BoringModel

  1. If you have multiple callbacks that can do early stopping, only the last one can be active.
  2. Create a callback with early stopping, MyStoppingCallback(). Add it, then EarlyStoppingCallback(), to the callbacks argument of the trainer, e.g. callbacks = [MyStoppingCallback(), EarlyStoppingCallback('val_loss')]
  • The callback is triggered and calculates that it needs to stop, but training continues.
  • On the other hand, if you change the order (e.g. callbacks = [EarlyStoppingCallback('val_loss'), MyStoppingCallback()]) it will stop with MyStoppingCallback but probably doesn't trigger the EarlyStoppingCallback. See the minimal sketch right below this list.
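
To make the root cause concrete, here is a minimal, self-contained sketch of what I believe is happening (the classes below are made up for illustration, they are not Lightning APIs): whichever stopping callback runs last overwrites trainer.should_stop, so an earlier callback's request to stop is silently discarded.

class CustomStop:
    """Stand-in for MyStoppingCallback: it decides to stop."""

    def on_validation_end(self, trainer):
        trainer.should_stop = True  # asks the trainer to stop


class OverwritingEarlyStop:
    """Stand-in for an early-stopping callback whose monitor has not triggered yet."""

    def on_validation_end(self, trainer):
        should_stop = False                # its own criterion says "keep going"
        trainer.should_stop = should_stop  # overwrites the earlier decision


class FakeTrainer:
    should_stop = False


trainer = FakeTrainer()
for cb in [CustomStop(), OverwritingEarlyStop()]:  # same order as callbacks=[...]
    cb.on_validation_end(trainer)

print(trainer.should_stop)  # False -> training keeps going; the first request is lost

The full reproduction script follows.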
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# --------------------------------------------
# --------------------------------------------
# --------------------------------------------
# USE THIS MODEL TO REPRODUCE A BUG YOU REPORT
# --------------------------------------------
# --------------------------------------------
# --------------------------------------------
import os

import torch
from torch.utils.data import Dataset

from pl_examples import cli_lightning_logo
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback, EarlyStopping

class RandomDataset(Dataset):
    """
    >>> RandomDataset(size=10, length=20)  # doctest: +ELLIPSIS
    <...bug_report_model.RandomDataset object at ...>
    """

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    """
    >>> BoringModel()  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    BoringModel(
      (layer): Linear(...)
    )
    """

    def __init__(self):
        """
        Testing PL Module

        Use as follows:
        - subclass
        - modify the behavior for what you want

        class TestModel(BaseTestModel):
            def training_step(...):
                # do your own thing

        or:

        model = BaseTestModel()
        model.training_epoch_end = None

        """
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def step(self, x):
        x = self.layer(x)
        out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
        return out

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        self.log('val_loss', loss)
        return {"x": loss}

    def validation_epoch_end(self, outputs) -> None:
        torch.stack([x['x'] for x in outputs]).mean()

    def test_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"y": loss}

    def test_epoch_end(self, outputs) -> None:
        torch.stack([x["y"] for x in outputs]).mean()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


#  NOTE: If you are using a cmd line to run your script,
#  provide the cmd line as below.
#  opt = "--max_epochs 1 --limit_train_batches 1".split(" ")
#  parser = ArgumentParser()
#  args = parser.parse_args(opt)

class EarlyStoppingExample(Callback):
    """Custom callback that asks the trainer to stop once the epoch counter passes 5."""

    def on_validation_end(self, trainer, pl_module):
        # decide locally whether training should stop
        should_stop = trainer.current_epoch > 5

        if should_stop:
            print("\nSTOPPING!!!!!!!!!!!!!!!!!!!!\n")
            self.stopped_epoch = trainer.current_epoch
            trainer.should_stop = True

        # stop every ddp process if any world process decides to stop
        should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(should_stop)
        trainer.should_stop = should_stop

def test_run():

    class TestModel(BoringModel):

        def on_train_epoch_start(self) -> None:
            pass

    # fake data
    train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    val_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    test_data = torch.utils.data.DataLoader(RandomDataset(32, 64))

    # model

    early_stopping = EarlyStopping('val_loss', patience=50)

    model = TestModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=100,
        weights_summary=None,
        callbacks=[
            EarlyStoppingExample(),
            early_stopping,
            ]
    )
    trainer.fit(model, train_data, val_data)
    trainer.test(test_dataloaders=test_data)


if __name__ == '__main__':
    #cli_lightning_logo()
    test_run()

To Reproduce

See the BoringModel script above.

Expected behavior

Training stops as soon as any callback sets trainer.should_stop, regardless of the order of the callbacks in the list.

Environment

  • PyTorch Version (e.g., 1.0): 1.7
  • OS (e.g., Linux): Windows
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.7.4
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:
jlperla added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Feb 25, 2021
SeanNaren added the priority: 0 (High priority task) label on Feb 25, 2021
SeanNaren self-assigned this on Feb 25, 2021
@SeanNaren (Contributor)

hey @jlperla thanks so much for the issue! I think this PR should fix it, and I've added a test to make sure we don't break this again: #6197
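
For anyone landing here later, a sketch of the idea behind such a fix (an assumption on my part, not the actual diff in #6197): each stopping callback should OR its decision into trainer.should_stop instead of assigning over it, so a stop requested by any callback survives the callbacks that run after it.

class FakeTrainer:
    should_stop = False


def request_stop(trainer, should_stop):
    # accumulate the decision instead of overwriting it
    trainer.should_stop = trainer.should_stop or should_stop


trainer = FakeTrainer()
request_stop(trainer, True)    # e.g. the custom callback asks to stop
request_stop(trainer, False)   # e.g. EarlyStopping('val_loss') has not triggered
print(trainer.should_stop)     # True -> the earlier request is preserved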
