Load callback states while testing. #5542
Comments
@rohitgr7 what's the difference between your proposal and the existing

```python
trainer = Trainer(resume_from_checkpoint=x)
trainer.test()
```

? In the recent PR the reload from
Isn't this a breaking API change? I commented on #5388 about this too. What happens to callbacks whose states also depend on the model?
@rohitgr7 could we still call the checkpoint connector from
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
For this, I'd suggest reloading the callback states along with the model state using the
@PyTorchLightning/core-contributors thoughts?
Do we ever want to reload callback states but not the model state?
It does not. I guess one can simply use a hook for this:

```python
def on_test_start(self):
    ckpt = load_checkpoint(self.tested_ckpt_path)
    self.trainer.on_load_checkpoint(ckpt)
```
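To make the dispatch concrete, here is a minimal, self-contained sketch of the mechanism the hook above relies on: a trainer keeps per-callback state dicts in the checkpoint and hands each one back to the matching callback's `on_load_checkpoint`. All names here (`MiniTrainer`, `CounterCallback`, the `"callbacks"` checkpoint key) are illustrative stand-ins, not Lightning's actual internals.

```python
class Callback:
    """Stand-in base class with save/load state hooks."""
    def on_save_checkpoint(self):
        return {}

    def on_load_checkpoint(self, state):
        pass


class CounterCallback(Callback):
    """Toy stateful callback: tracks a single counter."""
    def __init__(self):
        self.count = 0

    def on_save_checkpoint(self):
        return {"count": self.count}

    def on_load_checkpoint(self, state):
        self.count = state["count"]


class MiniTrainer:
    """Sketch of a trainer that saves and restores per-callback state."""
    def __init__(self, callbacks):
        self.callbacks = callbacks

    def dump_checkpoint(self):
        # Map each callback's class name to its saved state dict.
        return {"callbacks": {type(cb).__name__: cb.on_save_checkpoint()
                              for cb in self.callbacks}}

    def on_load_checkpoint(self, checkpoint):
        # Hand each callback its saved state back, keyed by class name.
        for cb in self.callbacks:
            state = checkpoint["callbacks"].get(type(cb).__name__)
            if state is not None:
                cb.on_load_checkpoint(state)


cb = CounterCallback()
cb.count = 5
ckpt = MiniTrainer([cb]).dump_checkpoint()

fresh = CounterCallback()
MiniTrainer([fresh]).on_load_checkpoint(ckpt)
print(fresh.count)  # 5
```

Note that restoring only works if the re-run trainer is constructed with matching callback instances, which is exactly the pitfall discussed later in this thread.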
@rohitgr7 but I think the `Callback`

```python
class MyCallback(Callback):
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        return {"foo": True}

    def on_load_checkpoint(self, trainer, pl_module, callback_state):
        # DOES NOT GET CALLED
        print(callback_state)


def test_bug(tmpdir):
    model = BoringModel()
    trainer = Trainer(max_epochs=1, callbacks=[MyCallback()])
    trainer.fit(model)
    ckpt = str(tmpdir / "test.ckpt")
    trainer.save_checkpoint(ckpt)

    trainer = Trainer(resume_from_checkpoint=ckpt, max_epochs=2)
    trainer.fit(model)
```

I don't think we need to add a flag for this. If the state was saved, it should be reloaded.
@carmocca to make it load the callback state you need to pass the callbacks in the re-run:

```python
trainer = pl.Trainer(resume_from_checkpoint=ckpt, max_epochs=2, callbacks=[MyCallback()])
```

To load the state of callbacks one can just manually call this trainer method, although it's just a workaround:

```python
def on_test_start(self):
    ckpt = load_checkpoint(self.tested_ckpt_path)
    self.trainer.on_load_checkpoint(ckpt)
```
Thanks. This could use an info message.
I get it, but this should work as in
Yes it should. Although I'm just wondering about this case: `test(model, ckpt_path)`. Should we load the callback states here since `ckpt_path` is passed, or simply ignore it since `model` is passed explicitly?
`ckpt_path` does nothing if the model is passed.
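That precedence rule can be sketched as a tiny helper. This is an illustrative sketch only: `resolve_test_model` and `load_fn` are hypothetical names, not part of Lightning's API, but they capture the behavior described above (an explicit model wins, and the checkpoint path is then ignored).

```python
def resolve_test_model(model=None, ckpt_path=None, load_fn=None):
    """Pick the model to test: an explicitly passed model takes priority."""
    if model is not None:
        return model  # ckpt_path, if given, is silently ignored
    if ckpt_path is not None:
        return load_fn(ckpt_path)
    raise ValueError("pass either a model or a ckpt_path")


# Only ckpt_path given: the checkpoint is loaded.
loaded = resolve_test_model(ckpt_path="best.ckpt",
                            load_fn=lambda p: f"model-from-{p}")
print(loaded)  # model-from-best.ckpt

# Both given: the explicit model wins, ckpt_path does nothing.
explicit = resolve_test_model(model="explicit-model", ckpt_path="best.ckpt",
                              load_fn=lambda p: f"model-from-{p}")
print(explicit)  # explicit-model
```

This is also where the thread's open question lives: silently ignoring `ckpt_path` is why an info message was suggested above.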
Hey @rohitgr7, quick question: do we have a use case where we need to reload the callback states for testing? Best,
@tchaton I don't have one currently, but here's one from @ananthsub
Dear @ananthsub, any chance the

Best,
Hi @tchaton, I'm facing the same limitation as @ananthsub, as I'm training with EMA. I expected that You can find a working EMA callback here (apply these changes for MultiGPU support).
🚀 Feature
Load callback states while testing.
Motivation
#5161 (comment)
Pitch
Two possible API changes:

- with an additional argument `restore_states`:
- or without any additional argument:
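The first option can be sketched as follows. This is a hypothetical shape for the proposal, not an actual Lightning API: `MiniTrainer`, `StatefulCallback`, and the `restore_states` parameter are all assumed names, and a plain dict stands in for a real checkpoint.

```python
class StatefulCallback:
    """Toy callback that records whatever state it is handed back."""
    def __init__(self):
        self.state = None

    def on_load_checkpoint(self, state):
        self.state = state


class MiniTrainer:
    """Sketch of test() with a restore_states flag."""
    def __init__(self, callbacks):
        self.callbacks = callbacks

    def test(self, checkpoint=None, restore_states=False):
        if checkpoint is not None and restore_states:
            saved = checkpoint.get("callbacks", {})
            for cb in self.callbacks:
                if type(cb).__name__ in saved:
                    cb.on_load_checkpoint(saved[type(cb).__name__])
        # ... run the actual test loop here ...
        return [type(cb).__name__ for cb in self.callbacks]


cb = StatefulCallback()
trainer = MiniTrainer([cb])
trainer.test(checkpoint={"callbacks": {"StatefulCallback": {"foo": True}}},
             restore_states=True)
print(cb.state)  # {'foo': True}
```

The flagless variant would simply drop `restore_states` and always restore whatever callback state the checkpoint contains, matching the "if the state was saved, it should be reloaded" position taken earlier in the thread.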
Alternatives
Alternatively, one can just reload checkpoints manually, call `on_load_checkpoint` for all the callbacks manually, and test.

PS: There may be a better solution. Open to suggestions :)
cc: @ananthsub
cc @Borda @awaelchli @ananthsub @ninginthecloud @rohitgr7 @tchaton @akihironitta