
Load callback states while testing. #5542

Open
rohitgr7 opened this issue Jan 16, 2021 · 21 comments
Labels
checkpointing, feature, help wanted, priority: 1, trainer: test, trainer: validate
Comments

@rohitgr7
Contributor

rohitgr7 commented Jan 16, 2021

🚀 Feature

Load callback states while testing.

Motivation

#5161 (comment)

Pitch

Two possible API changes:

with an additional argument restore_states:

test(ckpt_path, restore_states=True/False)  # give an option whether to load states or not
test(model, ckpt_path, restore_states=True/False)  # same as above but will just load checkpoint states and not the model

# raise an error
test(ckpt_path=None, restore_states=True)

or without any additional argument:

test(ckpt_path)  # always load states
test(ckpt_path=None)  # don't load any states.
test(model, ckpt_path)  # reload checkpoint states only from ckpt_path
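
For illustration, here is a rough sketch of how the second option could behave inside test (purely hypothetical: the trainer internals and the _load_callback_states helper are assumptions, not the actual API):

import torch

# Hypothetical sketch only -- not the actual Trainer implementation.
def test(self, model=None, ckpt_path=None):
    if ckpt_path is not None:
        checkpoint = torch.load(ckpt_path)
        if model is None:
            # no model passed: restore the weights from the checkpoint
            self.lightning_module.load_state_dict(checkpoint["state_dict"])
        # with or without a model, restore the saved callback states
        self._load_callback_states(checkpoint)  # hypothetical helper
    # ckpt_path=None: no states are loaded; proceed with the in-memory model
    ...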

Alternatives

Alternatively, one can reload the checkpoint manually, call on_load_checkpoint for each callback, and then test.

PS: There may be a better solution. Open to suggestions :)
cc: @ananthsub

cc @Borda @awaelchli @ananthsub @ninginthecloud @rohitgr7 @tchaton @akihironitta

@rohitgr7 rohitgr7 added the feature and help wanted labels Jan 16, 2021
@awaelchli
Contributor

@rohitgr7 what's the difference between your proposal and the existing

trainer = Trainer(resume_from_checkpoint=x)
trainer.test()

?

@rohitgr7
Contributor Author

trainer = Trainer(resume_from_checkpoint=x)
trainer.test()

In the recent PR, reloading from resume_from_checkpoint was disabled completely, since it was reloading the model state too. Also, resume_from_checkpoint is meant to resume training, I guess. If we use it, then when testing with different checkpoints we need to pass the checkpoint path twice: in Trainer(resume_from_checkpoint=ckpt) and in trainer.test(ckpt_path=ckpt). I think resume_from_checkpoint would be easier to handle; it's just that the argument name sounds a bit misleading to me in the case of testing.

@ananthsub
Contributor

ananthsub commented Jan 24, 2021

In the recent PR, reloading from resume_from_checkpoint was disabled completely, since it was reloading the model state too

Isn't this a breaking API change? I commented on #5388 about this too. What happens to callbacks whose states also depend on the model?

@ananthsub
Contributor

@rohitgr7 could we still call the checkpoint connector from setup_training inside run_evaluation, but add logic inside restore to check trainer.testing? If it's testing, we can load the model and callback states only and ignore the trainer states. This way we don't have to expose anything new in the trainer API. What do you think?
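
A rough sketch of what that check inside the checkpoint connector's restore could look like (the class shape and helper names here are illustrative, not the actual private API):

import torch

class CheckpointConnector:
    def __init__(self, trainer):
        self.trainer = trainer

    def restore(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # always restore the model weights and the saved callback states
        self.trainer.lightning_module.load_state_dict(checkpoint["state_dict"])
        self.trainer.on_load_checkpoint(checkpoint)
        if not self.trainer.testing:
            # optimizer/scheduler/loop state only matters when resuming training
            self._restore_training_state(checkpoint)  # hypothetical helper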

@stale

stale bot commented Feb 27, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix label Feb 27, 2021
@stale stale bot closed this as completed Mar 7, 2021
@awaelchli awaelchli reopened this Mar 7, 2021
@stale stale bot removed the won't fix label Mar 7, 2021
@rohitgr7
Contributor Author

rohitgr7 commented Mar 7, 2021

For this, I'd suggest reloading the callback states as well, along with the model state, using the ckpt_path passed to .test. Open to suggestions!

@stale

stale bot commented Apr 7, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix label Apr 7, 2021
@awaelchli awaelchli removed the won't fix label Apr 8, 2021
@stale

stale bot commented May 8, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix label May 8, 2021
@edenlightning edenlightning added the design label May 9, 2021
@stale stale bot removed the won't fix label May 9, 2021
@edenlightning
Contributor

@PyTorchLightning/core-contributors thoughts?

@edenlightning edenlightning added the discussion label May 9, 2021
@carmocca
Contributor

carmocca commented May 10, 2021

Does trainer.test(ckpt_path) not reload callback states if the ckpt_path checkpoint includes the states for callbacks?

Do we ever want to reload callback states but not the model state?

@rohitgr7
Contributor Author

rohitgr7 commented May 10, 2021

Does trainer.test(ckpt_path) not reload callback states if the ckpt_path checkpoint includes the states for callbacks?

It does not.

I guess one can simply use a hook for this.

import torch

def on_test_start(self):
    # reload the checkpoint ourselves and let the trainer dispatch the callback states
    ckpt = torch.load(self.trainer.tested_ckpt_path)
    self.trainer.on_load_checkpoint(ckpt)

@carmocca
Contributor

@rohitgr7 but I think the Callback on_load_checkpoint also does not get called during fit:

from pytorch_lightning import Callback, Trainer
from tests.helpers import BoringModel  # PL's test-helper model (location varies by version)

class MyCallback(Callback):
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        return {"foo": True}

    def on_load_checkpoint(self, trainer, pl_module, callback_state):
        # DOES NOT GET CALLED
        print(callback_state)

def test_bug(tmpdir):
    model = BoringModel()

    trainer = Trainer(max_epochs=1, callbacks=[MyCallback()])
    trainer.fit(model)

    ckpt = str(tmpdir / "test.ckpt")
    trainer.save_checkpoint(ckpt)

    trainer = Trainer(resume_from_checkpoint=ckpt, max_epochs=2)
    trainer.fit(model)

I don't think we need to add a flag for this. If the state was saved, it should be reloaded.

@rohitgr7
Contributor Author

@carmocca to make it load the callback state you need to pass the callbacks in the re-run.

trainer = pl.Trainer(resume_from_checkpoint=ckpt, max_epochs=2, callbacks=[MyCallback()])

on_load_checkpoint is called only if you set resume_from_checkpoint and then call trainer.fit. During testing it is not called, because we don't load any state except the model's state_dict when a ckpt is passed to trainer.test.

To load the callback states, one can manually call this trainer method, although it's just a workaround:

import torch

def on_test_start(self):
    # reload the checkpoint ourselves and let the trainer dispatch the callback states
    ckpt = torch.load(self.trainer.tested_ckpt_path)
    self.trainer.on_load_checkpoint(ckpt)

@carmocca
Contributor

carmocca commented May 12, 2021

to make it load the callback state you need to pass the callbacks in the re-run.

Thanks. This could use an info message
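
Something along these lines, perhaps (a sketch: rank_zero_info is Lightning's logging utility, while the surrounding restore logic and the checkpoint/trainer variables are assumed to be in scope):

from pytorch_lightning.utilities import rank_zero_info

# Sketch: while restoring, point out callback states in the checkpoint that
# have no matching callback configured on the new Trainer.
saved = set(checkpoint.get("callbacks", {}))
configured = {type(cb).__name__ for cb in trainer.callbacks}
missing = saved - configured
if missing:
    rank_zero_info(
        f"The checkpoint contains states for callbacks {missing}, but no matching "
        "callbacks are configured on this Trainer; those states will not be reloaded."
    )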

although it's just a workaround.

I get it, but this should work as in fit

@rohitgr7
Contributor Author

to make it load the callback state you need to pass the callbacks in the re-run.

Thanks. This could use an info message

although it's just a workaround.

I get it, but this should work as in fit

Yes, it should. Although I'm just wondering about this case:

test(model, ckpt_path)

Should we load the callback states here, since ckpt_path is passed, or simply ignore it, since the model is passed explicitly?

@carmocca
Contributor

Should we load the callback states here, since ckpt_path is passed, or simply ignore it, since the model is passed explicitly?

ckpt_path does nothing if the model is passed

https://github.com/PyTorchLightning/pytorch-lightning/blob/20f63377f81f4771d3f128f979b3a0f9b8d219a7/pytorch_lightning/trainer/trainer.py#L569-L570

@stale

stale bot commented Jun 16, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix label Jun 16, 2021
@carmocca carmocca added this to the v1.5 milestone Jun 17, 2021
@stale stale bot removed the won't fix label Jun 17, 2021
@tchaton
Contributor

tchaton commented Oct 6, 2021

Hey @rohitgr7,

quick question: Do we have a use case where we need to reload the callback states for testing?

Best,
T.C

@tchaton tchaton added the priority: 1 label Oct 6, 2021
@rohitgr7
Contributor Author

rohitgr7 commented Oct 6, 2021

@tchaton I don't have one currently, but here's one from @ananthsub
#5161 (comment)

@tchaton
Contributor

tchaton commented Oct 6, 2021

Dear @ananthsub,

Any chance the Exponential Moving Average callback could be contributed, to provide a proper use case for this refactor? Otherwise, I believe it is a bad pattern to force a refactor without any open-source application.

Best,
T.C

@awaelchli awaelchli modified the milestones: v1.5, v1.6 Nov 4, 2021
@hal-314

hal-314 commented Jan 31, 2022

Hi @tchaton

I'm facing the same limitation as @ananthsub, as I'm training with EMA. I expected on_load_checkpoint to be called with trainer.validate or trainer.predict. The Lightning docs imply that it is always called when loading from a checkpoint. I believe the docs should mention that Callback.on_load_checkpoint is only called with trainer.fit.

You can find a working EMA callback here (apply these changes for multi-GPU support).
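
For context, a minimal sketch of such an EMA callback (not the linked implementation) shows what breaks: the averaged weights live only in the callback state, so if on_load_checkpoint is never called during testing, they are silently lost:

import copy

import torch
from pytorch_lightning import Callback

class EMACallback(Callback):
    """Minimal EMA sketch; the linked implementation is more complete."""

    def __init__(self, decay=0.999):
        self.decay = decay
        self.ema_state = None

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        current = pl_module.state_dict()
        if self.ema_state is None:
            self.ema_state = copy.deepcopy(current)
            return
        with torch.no_grad():
            for name, param in current.items():
                if param.dtype.is_floating_point:
                    # running average of the floating-point weights
                    self.ema_state[name].mul_(self.decay).add_(param, alpha=1 - self.decay)

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        return {"ema_state": self.ema_state}

    def on_load_checkpoint(self, trainer, pl_module, callback_state):
        # the hook this issue is about: never reached from trainer.test(ckpt_path=...)
        self.ema_state = callback_state["ema_state"]

    def on_test_start(self, trainer, pl_module):
        # evaluate with the averaged weights; on a fresh Trainer + trainer.test,
        # ema_state is still None and the EMA weights are lost
        if self.ema_state is not None:
            pl_module.load_state_dict(self.ema_state)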

@carmocca carmocca added the checkpointing, trainer: test and trainer: validate labels and removed the discussion and design labels Feb 1, 2022
@carmocca carmocca modified the milestones: 1.6, future Feb 28, 2022