FullyShardedDataParallel wrapped models not being unwrapped, leading to incorrect checkpoints. #13500
I should note this issue occurs when applying the FSDP strategy to an otherwise normal model, not a model that is manually sharded at initialization time.
We are experiencing lower accuracy of models trained with fsdp as opposed to ddp. Do you think that could be the same underlying issue?
See if switching to ddp_sharded significantly improves things. That strategy does not have the same bug but should otherwise be very similar, from my understanding; it's basically the previous version of fsdp.
The symptom to look for is poor-quality checkpoints. I saw loss curves that looked the same or similar during training, but when I continued a run from the last checkpoint, the loss basically started from where it was at step 0 rather than at a later step. Epoch and other state would reload properly, though. ddp_sharded didn't have that behavior with otherwise identical code. Also, manual inspection of results from saved checkpoints looked bad even though the models appeared to be training.
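As an aside for anyone else debugging this, one rough way to inspect a suspect checkpoint directly is to look at the raw `state_dict` on disk. This is illustrative only: the `last.ckpt` path is a placeholder, and the wrapper-prefix naming (e.g. `_fsdp_wrapped_module.`) is an assumption about how FSDP names parameters:

```python
# Rough checkpoint inspection (illustrative; "last.ckpt" is a placeholder).
import torch

ckpt = torch.load("last.ckpt", map_location="cpu")
for name, tensor in ckpt["state_dict"].items():
    # Red flags: keys still carrying a wrapper prefix such as
    # "_fsdp_wrapped_module." (assumed FSDP naming), or weight norms that
    # look like a fresh random init rather than trained values.
    print(f"{name}: norm={tensor.float().norm():.4f}")
```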
Thanks for the information, I'll try ddp_sharded. So far we have definitely seen a big difference between fsdp and ddp; however, I believe this difference was also visible in the training curves, so it might be a different issue.
This solution looks good, hope it fixes the problem! #13502
Did that fix your issue @zlenyk, out of curiosity?
Sorry for the late response, a queue of experiments and so on...
Hey @jstjohn |
🐛 Bug
When training in parallel with the `fsdp` strategy, the saved checkpoints are somehow messed up. When I try to resume training from those, the epoch number is properly resumed, but the loss spikes dramatically, as if the model went back to an initial/random state. When I do the same train/checkpoint/resume loop with `ddp_sharded`, I do not have this issue and the checkpoint resumes with a loss similar to where it left off. I further saw that when I point a model with strategy `fsdp` at a checkpoint saved with `ddp_sharded`, it also resumes with a reasonable loss that is roughly at the previous level. This suggests that `fsdp` loads a checkpoint fine, but there is something wrong with how it saves checkpoints in parallel. Conversely, when I resume using `ddp_sharded` from an `fsdp`-saved checkpoint, the loss is dramatically worse, as if the weights were randomly initialized, further suggesting that the issue is with how weights are saved in `fsdp`. Knowing all of this, I am able to just switch to using `ddp_sharded`, but this seems like a really nasty bug that could cause other people headaches, so I wanted to make sure it was known.

The fix seems to be to make sure to unwrap the `FullyShardedDataParallel` wrapper. One key difference between the `fsdp` strategy implementation and the `ddp_sharded` strategy implementation is that `ddp_sharded` overrides `self.lightning_module` and calls a custom `unwrap_...` function which unwraps the `ShardedDataParallel` layer prior to calling the shared `unwrap_lightning_module(...)` function. `fsdp` does none of this; it defaults to the method implemented in `ParallelStrategy.lightning_module`, which only calls the `unwrap_lightning_module(...)` function.

I am going to open a PR and link it here which makes `unwrap_lightning_module(...)` aware of `FullyShardedDataParallel` (both flavors) as well as `ShardedDataParallel`, so that all of the strategies that use one of those wrappers benefit. Hopefully that will also make this a noticeable piece of code that needs to be modified as new wrappers are added.
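For reference, here is a rough sketch of what an FSDP-aware `unwrap_lightning_module(...)` could look like. This is an illustration rather than the actual patch; the FairScale/native-PyTorch import paths and the internal `_LightningModuleWrapperBase` are assumptions based on Lightning around v1.6:

```python
# Hypothetical sketch of an FSDP-aware unwrap helper; not the actual patch.
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as NativeFSDP  # PyTorch >= 1.11
from fairscale.nn import FullyShardedDataParallel as FairscaleFSDP
from fairscale.nn.data_parallel import ShardedDataParallel
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase  # internal, assumption


def unwrap_lightning_module(wrapped_model: nn.Module) -> nn.Module:
    model = wrapped_model
    # Peel off the classic data-parallel wrappers.
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        model = unwrap_lightning_module(model.module)
    # The missing piece: also peel off both FSDP flavors and ShardedDataParallel.
    if isinstance(model, (NativeFSDP, FairscaleFSDP, ShardedDataParallel)):
        model = unwrap_lightning_module(model.module)
    # Finally unwrap Lightning's own forward-redirection wrapper.
    if isinstance(model, _LightningModuleWrapperBase):
        model = unwrap_lightning_module(model.module)
    return model
```

The point is just that every wrapper in play gets stripped before the bare LightningModule is handed to checkpointing; which exact isinstance checks belong there is what the PR settles.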
To Reproduce

1. Train a model with `--strategy fsdp`. Note the loss at the beginning and make sure it drops. Then resume from the saved checkpoint and note that the loss spikes back to its initial value. (The `fsdp native` strategy, whatever that is called, is also broken. Maybe others.)
2. Repeat with `--strategy ddp_sharded` and note that the loss resumes from where it left off. (A minimal script sketch follows below.)
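Something along these lines should surface the difference. This is a sketch rather than a verified repro; the `BoringModel` import path, strategy names, and `best_model_path` usage are assumptions based on Lightning around v1.6/1.7:

```python
# Sketch of the train/checkpoint/resume loop; assumes 2 GPUs and
# Lightning ~1.6/1.7 (BoringModel path and strategy names are assumptions).
import pytorch_lightning as pl
from pytorch_lightning.demos.boring_classes import BoringModel


def train_then_resume(strategy: str) -> None:
    trainer = pl.Trainer(strategy=strategy, accelerator="gpu", devices=2, max_epochs=1)
    trainer.fit(BoringModel())  # the loss should drop during this first run

    # Resume from the checkpoint written above and watch the first loss values.
    resumed = pl.Trainer(strategy=strategy, accelerator="gpu", devices=2, max_epochs=2)
    resumed.fit(BoringModel(), ckpt_path=trainer.checkpoint_callback.best_model_path)


train_then_resume("fsdp")         # buggy: loss jumps back to its initial value
train_then_resume("ddp_sharded")  # ok: loss continues from where it left off
```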
Expected behavior
Model training continues from where it left off when resuming.
Environment
Additional context
cc @SeanNaren @awaelchli @rohitgr7 @akihironitta