[refactor] Move save_function to accelerator 1/n [DeepSpeed] #6689
Conversation
Hello @tchaton! Thanks for updating this PR. There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2021-03-29 13:54:43 UTC
This is also motivated by the need for the training type plugin to handle saving of checkpoints in sharded environments, where the model has been sharded onto multiple processes. Eventually the same should happen for loading of checkpoints if we're using …
Codecov Report
@@            Coverage Diff            @@
##           master   #6689     +/-   ##
=========================================
- Coverage      91%     82%      -9%
=========================================
  Files         192     192
  Lines       12238   13106     +868
=========================================
- Hits        11152   10759     -393
- Misses       1086    2347    +1261
Why don't we keep the current logic and just call, in dump_checkpoint:
model = self.trainer.lightning_module
checkpoint = {
'epoch': current_epoch,
'global_step': global_step,
'pytorch-lightning_version': pytorch_lightning.__version__,
'state_dict': self.trainer.accelerator.model_state(model),
}
Just as we do for the optimizer?
I'm probably missing some pieces here, but in the DeepSpeed plugin we need to save the checkpoints differently from the standard saving of one file from rank 0. Each process needs to save its weights to a directory, so this logic needs to be managed by either the training type plugin or a separate class!
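As a rough illustration of the kind of per-rank saving being described (a minimal sketch with hypothetical names such as save_sharded_checkpoint, not the actual DeepSpeed plugin API):

import os
import torch
import torch.distributed as dist

def save_sharded_checkpoint(model, optimizer, dirpath):
    """Every rank writes its own shard into a checkpoint directory,
    instead of rank 0 writing a single file."""
    os.makedirs(dirpath, exist_ok=True)
    rank = dist.get_rank() if dist.is_initialized() else 0
    # each process persists only the states it owns
    torch.save(model.state_dict(), os.path.join(dirpath, f"rank_{rank}_model_states.pt"))
    torch.save(optimizer.state_dict(), os.path.join(dirpath, f"rank_{rank}_optimizer_states.pt"))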
We should then make sure the requirements are clear so we can refactor things properly. If it's just about that, we could extract the model state outside of dump_checkpoint:
model_state = self.trainer.accelerator.model_state(model)
# dump states as a checkpoint dictionary object
checkpoint = self.dump_checkpoint(model_state, weights_only)
if self.trainer.is_global_zero:
    checkpoint = self.on_save(checkpoint)
    try:
        atomic_save(checkpoint, filepath)
        ...
But I'm sure there are other requirements, so we should state them clearly.
The code you shared doesn't adapt so well to DeepSpeed and FSDP, I guess.
- We should spec out what the path for loading will be too. I think it's going to be much trickier, especially with resharding, to determine what the execution order should look like. Since we're saving to a directory, what does last.ckpt mean anymore? Do we have last-{rank}.ckpt? How will this be specified in the trainer params? Do we now take a directory instead of a specific path? What happens if that directory contains other checkpoint files that are unrelated to the one at hand? Would we accidentally load those too?
- The main difference with DeepSpeed and FSDP is that the model and optimizer states can vary across ranks. Should the other pieces of the checkpoint dict (trainer progress, callback states) remain inside of the checkpoint connector, or should everything be part of the training type plugin?
- Not specific to this PR, but this is an existing inefficiency: we potentially call this multiple times inside the checkpoint callback (once for save top K, once for save last). If the existing saved file is available, we should copy it over to last.ckpt instead of going through the whole dump again (see the sketch after this list).
- Could we get rid of this and rely directly on the training type plugin? https://github.com/PyTorchLightning/pytorch-lightning/blob/21fc5eb21e6db07bcc222afa4204b3d5fb5be323/pytorch_lightning/callbacks/model_checkpoint.py#L217
  I think this was there for mocks before, but I don't think it's needed now.
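As a small illustration of that last-checkpoint point, a minimal sketch (hypothetical helper; assumes the save-top-K file has already been written to disk):

import shutil
from pathlib import Path

def update_last_checkpoint(saved_path: Path, dirpath: Path) -> Path:
    """Reuse the file already written for save_top_k instead of
    serializing the whole checkpoint a second time for save_last."""
    last_path = dirpath / "last.ckpt"
    shutil.copyfile(saved_path, last_path)  # cheap file copy, no second dump
    return last_path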
I'll make an RFC for us to properly track the changes here and explain the motivations, since it seems this is too low-level a place for us to start at.
Hey @ananthsub,
Currently in https://github.com/PyTorchLightning/pytorch-lightning/pull/6546/files
In the Feat/ds update PR: https://github.com/PyTorchLightning/pytorch-lightning/blob/924d9e2c40a7a0af7766d1131b1b963c95b721a3/pytorch_lightning/plugins/training_type/deepspeed.py#L458
IMO, we will clean up this API once FSDP is at the same stage as DeepSpeed around saving / reloading. Best,
To track the conversation on what the API should look like, I made an RFC here: #6691
Yes.
Approving after speaking with @SeanNaren
The current goal is to keep it simple: rather than going for just one state dict per process, save the entire checkpoint per process.
The people who are using this know what they're doing (they've already had to read the documentation to use the model parallel hook). I get the eventual goal of reducing duplication, but I think regardless of the decision you'll have X amount of state dicts saved to disk, so those are just optimizations IMO. If we wanted to be really specific, we could do:
pl.Trainer(
plugin=DeepSpeedPlugin(stage=3),
callbacks=[ShardedCheckpoint(...)]
)
Ideally, the end result for me is:
last/
    trainer_state.pt                 # contains the basic checkpoint data, saved by the checkpoint connector
    rank_{rank}_model_states.pt      # contains the model state, saved by the training_type_plugin
    rank_{rank}_optimizer_states.pt  # contains the optimizer state, saved by the training_type_plugin
And a ShardedCheckpoint would be used to read this special checkpoint, where we may convert any present ModelCheckpoint into a ShardedCheckpoint.
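For illustration, a minimal sketch of what reading that layout back might look like (hypothetical helper; only the file names are taken from the layout above):

from pathlib import Path
import torch

def load_sharded_checkpoint(dirpath: Path, rank: int) -> dict:
    """Each rank reads the shared trainer state plus its own model/optimizer shards."""
    return {
        "trainer_state": torch.load(dirpath / "trainer_state.pt"),
        "model_state": torch.load(dirpath / f"rank_{rank}_model_states.pt"),
        "optimizer_state": torch.load(dirpath / f"rank_{rank}_optimizer_states.pt"),
    }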
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
"""Save model/training states as a checkpoint file through state-dump and file-write. | ||
|
||
Args: | ||
trainer: PyTorch Lightning Trainer |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
trainer: PyTorch Lightning Trainer | |
checkpoint: dict containing model and trainer state |
Done!
What does this PR do?
This PR refactors save_checkpoint so that the responsibility flows trainer -> checkpoint_connector -> accelerator -> training type plugin.
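A minimal sketch of that delegation chain (simplified, hypothetical class shapes rather than the exact Lightning internals):

import torch

class TrainingTypePlugin:
    """Default single-file behaviour; sharded plugins (e.g. DeepSpeed) override this."""
    def save_checkpoint(self, checkpoint: dict, filepath: str) -> None:
        torch.save(checkpoint, filepath)

class Accelerator:
    def __init__(self, training_type_plugin: TrainingTypePlugin) -> None:
        self.training_type_plugin = training_type_plugin

    def save_checkpoint(self, checkpoint: dict, filepath: str) -> None:
        # the accelerator only forwards the call to its training type plugin
        self.training_type_plugin.save_checkpoint(checkpoint, filepath)

class CheckpointConnector:
    def __init__(self, accelerator: Accelerator) -> None:
        self.accelerator = accelerator

    def save_checkpoint(self, checkpoint: dict, filepath: str) -> None:
        # the connector builds the checkpoint dict, then delegates the write
        self.accelerator.save_checkpoint(checkpoint, filepath)

# usage: the trainer asks its connector to save; the write ends up in the plugin
connector = CheckpointConnector(Accelerator(TrainingTypePlugin()))
connector.save_checkpoint({"state_dict": {}, "epoch": 0}, "example.ckpt")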
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:
Did you have fun?
Make sure you had fun coding 🙃