Add required states for resumed ModelCheckpoint GC #10995
Conversation
@ORippler thanks for this fix. To avoid regressing again, we need a test for this.
Do you think the following reflects the issue sufficiently? (If so, feel free to take it and commit it directly to your branch.)
import os

import torch

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint


def test_model_checkpoint_attributes(tmpdir):
    # LogInTwoMethods is the helper model used by the existing checkpoint tests.
    seed_everything()
    model = LogInTwoMethods()
    epochs = 2
    checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1, save_last=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[checkpoint_callback],
        limit_train_batches=10,
        limit_val_batches=10,
        max_epochs=epochs,
        logger=False,
    )
    trainer.fit(model)

    # Compare the callback state persisted in `last.ckpt` with the live callback attributes.
    checkpoint = torch.load(os.path.join(tmpdir, "last.ckpt"))["callbacks"][checkpoint_callback.state_key]
    for k in ("best_k_models", "kth_best_model_path", "kth_value", "last_model_path"):
        assert checkpoint[k] == getattr(checkpoint_callback, k)
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
How would this integrate with the different functionality tests for `ModelCheckpoint`? Is this not something we want to test as well? I am a bit confused here.
Note that we do not yet check for proper loading/reinstantiation of `ModelCheckpoint` based on the ckpt written to disk.
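One way such a reinstantiation check could look (my suggestion, not part of the PR; it reuses the helpers and imports of the existing test module, and assumes `ckpt_path` in `trainer.fit` is the resume API and that resuming with an already-satisfied `max_epochs` restores state without running further epochs):

```python
def test_model_checkpoint_state_restored_on_resume(tmpdir):
    seed_everything()
    model = LogInTwoMethods()

    # First run: produces `last.ckpt` containing the callback state.
    first_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1, save_last=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[first_callback],
        limit_train_batches=10,
        limit_val_batches=10,
        max_epochs=2,
        logger=False,
    )
    trainer.fit(model)

    # Second run: a *fresh* ModelCheckpoint resumes from the saved checkpoint.
    # max_epochs matches the completed epochs, so (assumption) the state is
    # restored but no further training happens to overwrite the attributes.
    resumed_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1, save_last=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[resumed_callback],
        limit_train_batches=10,
        limit_val_batches=10,
        max_epochs=2,
        logger=False,
    )
    trainer.fit(model, ckpt_path=os.path.join(tmpdir, "last.ckpt"))

    for attr in ("best_k_models", "kth_best_model_path", "kth_value", "last_model_path"):
        assert getattr(resumed_callback, attr) == getattr(first_callback, attr)
```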
I added your test and expanded it to check whether a freshly instantiated `ModelCheckpoint` correctly restores its state from the checkpoint. Off-note: do we have a test that compares the results of one continuous training run against those of an interrupted run that is resumed by passing the checkpoint to the `Trainer`?
@ORippler we have https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/trainer/test_trainer.py#L399, which doesn't check the results. However, this is on purpose: the checkpoint does not include any random state, so continuing from the checkpoint doesn't have to yield the exact same results (e.g. different random states when using the global RNG), at least until now (fault tolerance is currently in development). cc @tchaton to add a similar test once fault tolerance is ready
Hey @ORippler, yes, we have multiple tests checking that the weights are the same before and after for Fault Tolerance. Here they are: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/utilities/test_auto_restart.py
`ModelCheckpoint` is configured to save after every epoch, but `trainer.fit` is called with `max_steps=1`. Note there may be a better way of doing this, where `ModelCheckpoint` is triggered after `training_step` (see the sketch below).
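A possible alternative (my assumption, not what the test currently does) would be to make the callback save after every training step instead of relying on the epoch boundary, using the standard `every_n_train_steps` argument:

```python
# Hypothetical alternative configuration: save after each training step so the
# test does not depend on an epoch boundary being reached.
checkpoint_callback = ModelCheckpoint(
    dirpath=tmpdir,
    save_top_k=-1,
    save_last=True,
    every_n_train_steps=1,  # trigger the save after every training step
)
trainer = Trainer(
    default_root_dir=tmpdir,
    callbacks=[checkpoint_callback],
    max_steps=1,
    logger=False,
)
```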
Codecov Report
@@            Coverage Diff            @@
##           master   #10995     +/-   ##
=========================================
- Coverage      92%      88%      -4%
=========================================
  Files         177      177
  Lines       16502    16560      +58
=========================================
- Hits        15173    14604     -569
- Misses       1329     1956     +627
* First save, then load ckpt.
* Instantiate ModelCheckpoint twice.
LGTM
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
"best_k_models": self.best_k_models, | ||
"kth_best_model_path": self.kth_best_model_path, | ||
"kth_value": self.kth_value, | ||
"last_model_path": self.last_model_path, |
I just noticed an issue with doing this.
Since we save each ModelCheckpoint "mode" sequentially, these attributes will not be correct, depending on the order, if more than one mode triggers a save for the same global step:
currently, a "top-k" checkpoint will not include the `last_model_path`, even if it is saved right after for the same global step.
I'm not sure what the best solution would be here. I think we should start recommending multiple `ModelCheckpoint` instances as a best practice (see the sketch below), because these interactions between flags can be unintuitive.
cc @awaelchli @ananthsub @jjenniferdai
Related to #4335 and #11805 (comment)
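A sketch of the multi-instance setup recommended above, splitting the "top-k" and "last" responsibilities so their saved states cannot shadow each other (the paths and the monitored metric name are placeholders):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Hypothetical configuration: one callback per checkpointing "mode".
top_k_checkpoint = ModelCheckpoint(
    dirpath="checkpoints/top_k",
    monitor="val_loss",  # placeholder metric name
    save_top_k=3,
)
last_checkpoint = ModelCheckpoint(
    dirpath="checkpoints/last",
    save_last=True,
    save_top_k=0,  # this instance only maintains the "last" checkpoint
)

trainer = Trainer(callbacks=[top_k_checkpoint, last_checkpoint])
```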
What does this PR do?
Fixes #4911
Related: #5090
Currently, when resuming training, the internal states required for continued ModelCheckpointing are neither saved nor restored. As a result, `k` new checkpoints are always generated due to this check. These new checkpoints are properly GC'd/compared against, but the old ones are not.

Note that this PR does not handle overrides of `monitor`, `dirpath` or `mode`, as also referred to in #4911.

Does your PR introduce any breaking changes? If yes, please list them.
It might be that resuming training now fails where it did not fail before, if the paths were changed in the meantime (refer also to #4911). I did not check/test for this, but I confirmed that resumption of GC now works properly.
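To illustrate the scenario this PR addresses (a usage sketch; `MyModel` stands in for any LightningModule that logs "val_loss", and paths are illustrative):

```python
import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


def make_trainer(root, max_epochs):
    # Keep the 2 best checkpoints and always maintain `last.ckpt`.
    callback = ModelCheckpoint(monitor="val_loss", save_top_k=2, save_last=True, dirpath=root)
    trainer = Trainer(default_root_dir=root, callbacks=[callback], max_epochs=max_epochs)
    return trainer, callback


# Run 1: trains and keeps the 2 best checkpoints plus `last.ckpt`.
trainer, _ = make_trainer("ckpts", max_epochs=5)
trainer.fit(MyModel())

# Run 2: resume from `last.ckpt` with a fresh callback. Before this PR the fresh
# instance started with an empty `best_k_models`, so it wrote 2 *new* checkpoints
# and never garbage-collected the ones from run 1. With the state restored from
# the checkpoint, the old checkpoints participate in the top-k comparison and GC.
trainer, callback = make_trainer("ckpts", max_epochs=10)
trainer.fit(MyModel(), ckpt_path=os.path.join("ckpts", "last.ckpt"))
```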
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the Review guidelines.
Did you have fun?
Make sure you had fun coding 🙃
cc @carmocca @awaelchli @ninginthecloud @jjenniferdai