
Missing cleanup after trainer.fit() and trainer.test() #4385

Closed
ananthsub opened this issue Oct 27, 2020 · 5 comments · Fixed by #8578
Labels: bug (Something isn't working), duplicate (This issue or pull request already exists)

Comments

@ananthsub
Contributor

ananthsub commented Oct 27, 2020

🐛 Bug

The Lightning trainer holds references to the LightningModule/DataModule after fit/test complete. This can lead to different behavior in calls like the following:

Please reproduce using the BoringModel and post here

def test_x(tmpdir):
    # validation checks run regularly since we re-instantiate the trainer inside each loop
    for i in range(2):
        trainer = pl.Trainer(max_epochs=4, check_val_every_n_epoch=2, logger=False, checkpoint_callback=False)
        test_module = BoringModel()
        trainer.fit(test_module)


reuse_trainer = pl.Trainer(max_epochs=4, check_val_every_n_epoch=2, logger=False, checkpoint_callback=False)
def test_reuse(tmpdir):
    # validation checks do not run on the second loop since we don't re-instantiate the trainer inside each loop
    for i in range(2):
        test_module = BoringModel()
        reuse_trainer.fit(test_module)

To Reproduce

Expected behavior

The latter, test_reuse, should re-bind the model and hooks when calling fit again inside each loop.
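
For context, a minimal sketch (not part of the original report) of one way the stale state can be observed on a reused trainer, assuming counters such as trainer.current_epoch are simply left at their end-of-fit values rather than being reset (the missing cleanup described above):

import pytorch_lightning as pl

# BoringModel as defined in the full example further down this thread
trainer = pl.Trainer(max_epochs=4, check_val_every_n_epoch=2, logger=False, checkpoint_callback=False)

# First fit: starts from epoch 0, so validation runs every 2 epochs as expected.
trainer.fit(BoringModel())
print(trainer.current_epoch)  # left at its end-of-training value, not reset to 0

# Second fit on the same trainer: the model is not re-bound and the counters keep
# their old values, consistent with the validation checks being skipped in test_reuse.
trainer.fit(BoringModel())
print(trainer.current_epoch)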

@ananthsub ananthsub added bug Something isn't working help wanted Open to be worked on labels Oct 27, 2020
@edenlightning edenlightning added this to the 1.0.x milestone Oct 29, 2020
@edenlightning edenlightning modified the milestones: 1.0.x, 1.0.7 Nov 10, 2020
@Borda Borda modified the milestones: 1.0.7, 1.0.x Nov 11, 2020
@edenlightning edenlightning removed this from the 1.0.x milestone Nov 13, 2020
@edenlightning edenlightning added the priority: 1 Medium priority task label Nov 17, 2020
@Borda Borda added the good first issue Good for newcomers label Dec 1, 2020
@jabertuhin

I recently ran into this kind of issue.
For a Kaggle competition, I trained my model in two separate notebook cells with the same configuration (same trainer, datamodule, and model) but different variable names, and the train and validation losses differed between the two runs.
I tried to reproduce the results by restarting the kernel and running all the cells.
Both versions of the notebook then produced the same result, so it might be because of the issue mentioned here.

So, I would like to work on this issue.

@edenlightning
Contributor

@jabertuhin would be great if you can contribute! Let us know if you need any help.

@jabertuhin

@edenlightning
At first I thought the issue was with the Trainer. I went through the Trainer code, and it seemed the Trainer had no issue.
Then I suspected the LightningDataModule (in my case); now I think it all comes down to the DataLoader.

I set the random seed in my notebook with this:

from pytorch_lightning import seed_everything
seed_everything(42)

Here is an example: I create two new datamodule objects and get different images for the first batch.

[Screenshot from 2021-03-26 20-58-51: first-batch images from the two new datamodule objects differ]

And then I restarted the kernel and ran them again:

[Screenshot from 2021-03-26 20-56-47: the same comparison after restarting the kernel and re-running the cells]

Is there a better way to debug this issue? Or is it even an issue, or expected behavior?

@edenlightning edenlightning added this to the v1.3.x milestone Jul 1, 2021
@Borda Borda modified the milestones: v1.3.x, v1.4 Jul 6, 2021
@edenlightning edenlightning modified the milestones: v1.4, v1.3.x Jul 6, 2021
@edenlightning edenlightning modified the milestones: v1.3.x, v1.4.x Jul 13, 2021
@awaelchli
Contributor

@jabertuhin what you are showing makes sense. If you want the same output in each cell, you also need to put the seeding in both cells.
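
To make that concrete, a small hedged sketch (the dataset and shuffle setting here are placeholders, not taken from the screenshots above): building and iterating a shuffled DataLoader draws from the global RNG, so a second loader created afterwards sees a different RNG state unless you seed again.

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import seed_everything

dataset = TensorDataset(torch.arange(100).float())

# "Cell 1": seed, then build and iterate a shuffled loader (this consumes the global RNG).
seed_everything(42)
batch_a = next(iter(DataLoader(dataset, batch_size=8, shuffle=True)))[0]

# "Cell 2" without re-seeding: the RNG state has advanced, so the first batch differs.
batch_b = next(iter(DataLoader(dataset, batch_size=8, shuffle=True)))[0]
print(torch.equal(batch_a, batch_b))  # typically False

# "Cell 2" with the seeding repeated: the first batch matches cell 1 again.
seed_everything(42)
batch_c = next(iter(DataLoader(dataset, batch_size=8, shuffle=True)))[0]
print(torch.equal(batch_a, batch_c))  # True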

@awaelchli awaelchli added the duplicate This issue or pull request already exists label Jul 23, 2021
@awaelchli
Contributor

This issue will be fixed by #5007, but it needs to wait until after 1.4, as some of the required changes cannot be made in a backward-compatible way.

Copy-pasting the full example reported by @ananthsub here so I can verify it against #5007:

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def validation_epoch_end(self, outputs):
        print("validation finished")

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)


trainer_args = dict(
    max_epochs=4,
    check_val_every_n_epoch=2,
    logger=False,
    checkpoint_callback=False,
    progress_bar_refresh_rate=0,
    weights_summary=None,
    num_sanity_val_steps=0,
)


def run0():
    # validation checks run regularly since we re-instantiate the trainer inside each loop
    for i in range(2):
        print("iteration ", i)
        trainer = Trainer(**trainer_args)
        test_module = BoringModel()
        trainer.fit(test_module)


def run1():
    reuse_trainer = Trainer(**trainer_args)
    # validation checks do not run on the second loop since we don't re-instantiate the trainer inside each loop
    for i in range(2):
        print("iteration ", i)
        test_module = BoringModel()
        reuse_trainer.fit(test_module)


if __name__ == "__main__":
    # https://github.com/PyTorchLightning/pytorch-lightning/issues/4385

    # run0()
    run1()

@awaelchli awaelchli modified the milestones: v1.4.x, v1.5 Jul 23, 2021
@awaelchli awaelchli removed this from the v1.5 milestone Nov 4, 2021
@awaelchli awaelchli added this to the 1.5.x milestone Nov 4, 2021
@carmocca carmocca removed help wanted Open to be worked on good first issue Good for newcomers refactor priority: 1 Medium priority task labels Jan 12, 2022
@carmocca carmocca self-assigned this Jan 12, 2022
@carmocca carmocca linked a pull request Feb 3, 2022 that will close this issue