
Duplicate epochs when calling .fit() twice #5007

Closed
carmocca opened this issue Dec 7, 2020 · 11 comments · Fixed by #8578
Labels: breaking change, bug, help wanted, priority: 0

carmocca (Contributor) commented Dec 7, 2020

🐛 Bug

To Reproduce

from pytorch_lightning import Trainer
# BoringModel is Lightning's minimal test model; its import path has moved
# between releases (tests.helpers.boring_model on master at the time).
from tests.helpers.boring_model import BoringModel


def test_bug(tmpdir):
    epochs = []

    class TestModel(BoringModel):
        def on_epoch_end(self):
            # Record every epoch index the trainer reports.
            epochs.append(self.current_epoch)

    trainer = Trainer(
        max_epochs=2,
        limit_train_batches=1,
        limit_val_batches=1,
        default_root_dir=tmpdir,
        checkpoint_callback=False,
        logger=False,
        weights_summary=None,
        progress_bar_refresh_rate=0,
    )
    trainer.fit(TestModel())

    # Raise the epoch budget and fit again: training should continue at epoch 2.
    trainer.max_epochs = 4
    trainer.fit(TestModel())

    assert epochs == list(range(4))
    # AssertionError: [0, 1, 1, 2, 3] != [0, 1, 2, 3]

Expected behavior

Assertion does not fail

Environment

Current master

cc @tchaton @Borda

carmocca added the bug and help wanted labels Dec 7, 2020
carmocca (Contributor, Author) commented Dec 8, 2020

The epoch number is generated here:

https://github.com/PyTorchLightning/pytorch-lightning/blob/239347435029c0a02b305201ebbfa39d62746ca8/pytorch_lightning/trainer/trainer.py#L511

which assumes that the epoch in self.current_epoch has not run yet. That assumption is wrong when fit is run twice, because the epoch counter is not incremented after training ends.

When using Trainer(resume_from_checkpoint=...) this issue does not appear, thanks to this piece:

https://github.com/PyTorchLightning/pytorch-lightning/blob/239347435029c0a02b305201ebbfa39d62746ca8/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L273-L279

So the solution would be to increment the epoch counter at the end of training_loop.on_train_end()? (This would break backwards compatibility.)
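
Roughly, a sketch of that idea (untested; it assumes the counter lives on the trainer and is bumped from the training loop's on_train_end hook):

class TrainLoop:
    def on_train_end(self):
        ...  # existing teardown logic
        # current_epoch still points at the last epoch that ran; advance it
        # so a subsequent fit() starts at the next epoch instead of
        # repeating this one.
        self.trainer.current_epoch += 1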

ananthsub (Contributor) commented

@carmocca or should we reset these parameters in the trainer teardown?

carmocca (Contributor, Author) commented

I'd say it's more natural to do it in on_train_end: it has nothing to do with testing, and teardown runs for both fit and test.

pierresegonne commented

Any opinion on what the recommended way of resetting current_epoch would be, in order to call fit twice?

Is something along these lines

trainer.fit(model, datamodule)
model.trainer.current_epoch = 0
trainer.fit(model, datamodule)

safe?

carmocca (Contributor, Author) commented
Yes, it should be safe.
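
For the continuing-training case in the reproduction above, the same trick should also work by pointing the counter past the epochs that already ran (a sketch, untested; model and datamodule are whatever was passed to the first fit):

trainer.fit(model, datamodule)  # with max_epochs=2, runs epochs 0 and 1

# The counter was left at the last epoch that ran (1). Point it at the
# first epoch that has NOT run yet, then raise the budget:
model.trainer.current_epoch = 2
trainer.max_epochs = 4
trainer.fit(model, datamodule)  # should now run epochs 2 and 3, no duplicate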

edenlightning (Contributor) commented
@carmocca what is left TODO here?

carmocca (Contributor, Author) commented Feb 17, 2021

Everything, the bug is not fixed 🙂

There is a reproduction test at the top. We just need to make up our minds about the best solution. Context here: #5007 (comment)

carmocca (Contributor, Author) commented Jul 6, 2021

Status update: WIP; tackling other related issues first. This is needed for fault tolerance.

edenlightning modified the milestones: v1.3.x, v1.4.x Jul 6, 2021
carmocca (Contributor, Author) commented
Status update: Blocked on merging #8477 and on enabling restoration of the checkpoint progress-tracking state by default.

Borda removed the with code label Aug 19, 2021
awaelchli modified the milestones: v1.5, 1.5.x Nov 4, 2021
tchaton added the priority: 0 label and removed the priority: 1 label Nov 29, 2021
fschiffers commented

> Any opinion on what the recommended way of resetting current_epoch would be, in order to call fit twice?
>
> Is something along these lines
>
> trainer.fit(model, datamodule)
> model.trainer.current_epoch = 0
> trainer.fit(model, datamodule)
>
> safe?

I think current_epoch can no longer be set on the trainer; it must be set on the fit loop itself.

carmocca (Contributor, Author) commented
Correct. You'll need to do trainer.fit_loop.current_epoch = 0
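
So on recent versions, the workaround from earlier in the thread becomes (sketch; same caveats as before):

trainer.fit(model, datamodule)

# The epoch counter now lives on the loop objects, so reset it on the
# fit loop rather than on the trainer before fitting again:
trainer.fit_loop.current_epoch = 0
trainer.fit(model, datamodule)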
