Add image for better explanation to FSDP tutorial #2644
Conversation
@@ -46,6 +46,15 @@ At a high level FSDP works as follow:
* Run reduce_scatter to sync gradients
* Discard parameters.

The key insight behind full parameter sharding is that we can decompose the all-reduce operations in DDP into separate reduce-scatter and all-gather operations.
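A minimal, runnable sketch (not part of the tutorial or this PR) of the decomposition that sentence relies on: an all-reduce produces the same result as a reduce-scatter followed by an all-gather. It assumes PyTorch >= 1.13, the NCCL backend with one GPU per rank (Gloo does not implement reduce-scatter), and a tensor whose length divides evenly by the world size.

```python
# Sketch: all-reduce == reduce-scatter followed by all-gather.
# Launch with: torchrun --nproc_per_node=2 decompose_all_reduce.py
# Assumes NCCL with one GPU per rank; Gloo does not support reduce-scatter.
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

# Stand-in for a flat gradient tensor; its length must divide evenly by world_size.
grad = torch.arange(8, dtype=torch.float32, device="cuda") + rank

# Path 1: DDP-style all-reduce -- every rank ends up with the fully reduced tensor.
all_reduced = grad.clone()
dist.all_reduce(all_reduced, op=dist.ReduceOp.SUM)

# Path 2a: reduce-scatter -- each rank keeps only its shard of the reduced tensor.
shard = torch.empty(grad.numel() // world_size, device="cuda")
dist.reduce_scatter_tensor(shard, grad, op=dist.ReduceOp.SUM)

# Path 2b: all-gather -- the reduced shards are reassembled on every rank.
gathered = torch.empty_like(grad)
dist.all_gather_into_tensor(gathered, shard)

assert torch.equal(all_reduced, gathered)  # both paths give the same result
dist.destroy_process_group()
```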
I am not sure that this is the correct statement.
Even though an all-reduce can be decomposed as a reduce-scatter and all-gather, the current phrasing might suggest that DDP's gradient all-reduce is being decomposed into a gradient reduce-scatter and gradient all-gather. However, FSDP actually all-gathers parameters.
Whether or not this decomposition of all-reduce into reduce-scatter and all-gather is the key insight is not obvious to me. If we show this decomposition, we probably want more exposition.
I agree. Adding this picture would demand a clear explanation.
I am unsure what to write. If you can suggest something or direct me to where I can read about this topic, that'd be very helpful.
Maybe something like the following:
One way to view FSDP's sharding is to decompose the DDP gradient all-reduce into reduce-scatter and all-gather. In particular, FSDP reduce-scatters gradients such that each rank has a shard of the gradients in backward, updates the corresponding shard of the parameters in the optimizer step, and all-gathers them in the next forward.
< Figure >
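To make the suggested wording concrete, here is a hypothetical, heavily simplified sketch (not FSDP's actual implementation) of one iteration for a single flat parameter shard, written with raw torch.distributed collectives. The function name `fsdp_step`, the `compute_loss` callback, and the plain-SGD update are illustrative assumptions; an initialized NCCL process group and evenly divisible shards are also assumed.

```python
# Hypothetical sketch of the flow in the suggested text, NOT FSDP's real internals:
# all-gather parameters in forward, reduce-scatter gradients in backward,
# update only the local shard in the optimizer step, then discard the full parameter.
import torch
import torch.distributed as dist

def fsdp_step(flat_param_shard: torch.Tensor, compute_loss, lr: float = 0.01):
    world_size = dist.get_world_size()

    # Forward: all-gather the shards so this rank temporarily holds the full parameter.
    full_param = torch.empty(flat_param_shard.numel() * world_size,
                             dtype=flat_param_shard.dtype,
                             device=flat_param_shard.device)
    dist.all_gather_into_tensor(full_param, flat_param_shard)
    full_param.requires_grad_(True)

    # Run the local forward and backward with the unsharded parameter.
    loss = compute_loss(full_param)
    loss.backward()

    # Backward: reduce-scatter gradients so each rank keeps only its gradient shard.
    grad_shard = torch.empty_like(flat_param_shard)
    dist.reduce_scatter_tensor(grad_shard, full_param.grad, op=dist.ReduceOp.SUM)
    grad_shard /= world_size  # average, matching DDP's gradient averaging

    # Optimizer step: each rank updates only its own parameter shard (plain SGD here).
    with torch.no_grad():
        flat_param_shard -= lr * grad_shard

    # Discard the full parameter; it is re-all-gathered in the next forward.
    del full_param
    return loss.detach()
```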
Sounds good. Thanks
Sounds good to me!
Co-authored-by: Andrew Gu <31054793+awgu@users.noreply.github.com>
Fixes #2613
Description
The tutorial lacked an explanation of what is going on behind parameter sharding.
cc @wconstab @osalpekar @H-Huang @kwen2501 @sekyondaMeta @svekars @carljparker @NicolasHug @kit1980 @subramen