
FullyShardedDataParallel wrapped models not being unwrapped, leading to incorrect checkpoints. #13500

Closed · jstjohn opened this issue Jul 1, 2022 · 8 comments · Fixed by #13738

Labels: bug (Something isn't working), strategy: fairscale fsdp (removed) (Fully Sharded Data Parallel)
Milestone: pl:1.6.x

@jstjohn (Contributor) commented Jul 1, 2022

🐛 Bug

When training in parallel with the fsdp strategy, the saved checkpoints come out wrong in some way. When I resume training from one of them, the epoch number is restored correctly, but the loss spikes dramatically, as if the model had gone back to its initial/random state. When I run the same train/checkpoint/resume loop with ddp_sharded, I do not see this issue: the checkpoint resumes with a loss similar to where it left off. I further saw that when I point a model using strategy fsdp at a checkpoint saved with ddp_sharded, it also resumes with a reasonable loss, roughly at the previous level. This suggests that fsdp loads checkpoints fine, but something is wrong with how it saves them in parallel. Conversely, when I resume with ddp_sharded from an fsdp-saved checkpoint, the loss is dramatically worse, as if the weights had been randomly initialized, which further points at how fsdp saves weights. Knowing all of this, I can simply switch to ddp_sharded, but this seems like a really nasty bug that could cause other people headaches, so I wanted to make sure it was known.

The fix seems to be to make sure the FullyShardedDataParallel wrapper is unwrapped. One key difference between the fsdp strategy implementation and the ddp_sharded strategy implementation is that ddp_sharded overrides self.lightning_module and calls a custom unwrap_... function that removes the ShardedDataParallel layer before calling the shared unwrap_lightning_module(...) function. fsdp does none of this; it falls back to the default ParallelStrategy.lightning_module, which only calls unwrap_lightning_module(...).

I am going to open a PR (and link it here) that makes unwrap_lightning_module(...) aware of FullyShardedDataParallel (both flavors) as well as ShardedDataParallel, so that every strategy using one of those wrappers benefits. Hopefully that will also make it the obvious piece of code to update as new wrappers are added.
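For illustration, here is a minimal sketch of that kind of wrapper-aware unwrapping. This is not the actual Lightning implementation and the function name is made up; it assumes each parallel wrapper exposes its inner module via a `.module` attribute, which holds for DistributedDataParallel, fairscale's ShardedDataParallel, and fairscale's FullyShardedDataParallel:

```python
# Sketch only, not pytorch_lightning.overrides.base.unwrap_lightning_module.
# Assumes each parallel wrapper exposes its inner module as `.module`.
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

try:
    from fairscale.nn import FullyShardedDataParallel, ShardedDataParallel

    WRAPPER_TYPES = (DistributedDataParallel, ShardedDataParallel, FullyShardedDataParallel)
except ImportError:  # fairscale not installed
    WRAPPER_TYPES = (DistributedDataParallel,)


def unwrap_parallel_wrappers(module: nn.Module) -> nn.Module:
    """Peel off all known parallel wrappers so checkpointing sees the inner module."""
    while isinstance(module, WRAPPER_TYPES):
        module = module.module
    return module
```

Note that fairscale's FullyShardedDataParallel with flatten_parameters=True also inserts a FlattenParamsWrapper layer, so a complete fix has to account for that as well.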

To Reproduce

  1. Train a model in parallel for a few epochs with --strategy fsdp, saving checkpoints along the way. Note the loss at the beginning and make sure it drops.
  2. Resume the model with any strategy from one of those saved checkpoints and note that the loss is back to roughly where training began. Based on the code, I would guess that the native fsdp strategy (whatever it is called) is also broken, and possibly others.
  3. Repeat steps 1 and 2 with --strategy ddp_sharded and note that the loss resumes from where it left off. (A rough sketch of this loop follows below.)
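
Roughly what that loop looks like in code (a hedged sketch: the LightningModule, DataModule, device count, and checkpoint path are placeholders rather than anything from this report, and it assumes PyTorch Lightning 1.6.x with fairscale installed on a multi-GPU machine):

```python
import pytorch_lightning as pl

model = MyLightningModule()  # hypothetical LightningModule
dm = MyDataModule()          # hypothetical DataModule

# 1) Train for a few epochs with FSDP; the default ModelCheckpoint saves checkpoints.
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="fsdp", max_epochs=5)
trainer.fit(model, datamodule=dm)

# 2) Resume from one of the saved checkpoints and watch the loss: with an
#    fsdp-saved checkpoint it jumps back towards its initial/random value.
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="fsdp", max_epochs=10)
trainer.fit(MyLightningModule(), datamodule=dm, ckpt_path="path/to/saved.ckpt")

# 3) The same loop with strategy="ddp_sharded" resumes at the previous loss level.
```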

Expected behavior

Model training continues from where it left off when resuming from a saved checkpoint.

Environment

  • CUDA:
    • GPU:
      • NVIDIA A100 80GB PCIe
      • NVIDIA A10
      • NVIDIA A10
    • available: True
    • version: 11.5
  • Packages:
    • numpy: 1.22.3
    • pyTorch_debug: False
    • pyTorch_version: 1.11.0
    • pytorch-lightning: 1.6.3
    • tqdm: 4.64.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.9.12
    • version: #100-Ubuntu SMP Fri Sep 24 14:50:10 UTC 2021

Additional context

cc @SeanNaren @awaelchli @rohitgr7 @akihironitta

@jstjohn added the needs triage (Waiting to be triaged by maintainers) label on Jul 1, 2022
@awaelchli added the bug (Something isn't working) and strategy: fairscale fsdp (removed) (Fully Sharded Data Parallel) labels and removed the needs triage (Waiting to be triaged by maintainers) label on Jul 1, 2022
@awaelchli added this to the pl:1.6.x milestone on Jul 1, 2022
@jstjohn (Contributor, Author) commented Jul 1, 2022

I should note that this issue occurs when applying the fsdp strategy to an otherwise normal model, not to a model that is manually sharded at initialization time.
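
To make that distinction concrete, here is a hedged sketch (class names and layer sizes are made up; assumes fairscale is installed): an "otherwise normal" LightningModule that the strategy wraps automatically, versus a model that shards itself at initialization time via fairscale's wrap inside the configure_sharded_model hook:

```python
import torch.nn as nn
import pytorch_lightning as pl
from fairscale.nn import wrap


class PlainModel(pl.LightningModule):
    """Nothing FSDP-specific here; with --strategy fsdp the whole module is
    wrapped in FullyShardedDataParallel by the strategy. This is the case the
    bug report is about."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 32)


class ManuallyShardedModel(pl.LightningModule):
    """Shards its submodules itself while the strategy's wrap context is active."""

    def configure_sharded_model(self):
        self.layer = wrap(nn.Linear(32, 32))
```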

@zlenyk commented Jul 12, 2022

We are experiencing lower accuracy with models trained with fsdp as opposed to ddp; do you think that could be the same underlying issue?

@jstjohn (Contributor, Author) commented Jul 12, 2022 via email

@zlenyk commented Jul 12, 2022

Thanks for the information, I'll try ddp_sharded. So far we have definitely seen a big difference between fsdp and ddp, however I believe this difference was also visible in the training curves, so it might be a different issue.

@jstjohn (Contributor, Author) commented Jul 14, 2022

This solution looks good, hope it fixes the problem! #13502

@jstjohn (Contributor, Author) commented Jul 14, 2022

> Thanks for the information, I'll try ddp_sharded. So far we have definitely seen a big difference between fsdp and ddp, however I believe this difference was also visible in the training curves, so it might be a different issue.

Did that fix your issue @zlenyk, out of curiosity?

@zlenyk commented Jul 19, 2022

Sorry for the late response, a queue of experiments and so on...
ddp_sharded is running out of memory in our experiments, so we need to experiment a little more to squeeze our model into memory. Therefore I don't yet have an answer as to whether model performance is on par.

@awaelchli (Contributor) commented

Hey @jstjohn
Sorry for the delay. In #13738 I've nuked the complicated, error-prone unwrapping logic. I will open it for review soon; right now I am trying to verify that it covers your use case. If you want to try that branch as well and see if it works for you, that would be really cool, even though I suspect you have probably moved on with your own fix already.
