Add all-gather coalescing for FSDP/ZeRO1 #5950
Conversation
Also allow using reduce-scatter's scale param in FSDP. (revived #4145)
@jeffhataws let me know when you are done addressing comments; I will take another look.
@@ -295,6 +295,7 @@ def __init__(
     sharding_world_size: Optional[int] = None,
     shard_param_on_dim_0: bool = False,
     pin_layout_in_collective_ops: bool = True,
+    coalesce_all_gather_ops: bool = False,
Do you mind explaining the change in this file? I think coalesce_all_gather_ops is always False in our tests; did you run into issues with your own test?
When coalesce_all_gather_ops is True, the parameter shards are collected into a list and gathered with a single coalesced all-gather command at the end, instead of all-gathering one parameter at a time. It is off by default to avoid changing existing behavior. The code is the same as what we are using in our local fork.
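To make the two paths concrete, here is a minimal sketch. It assumes, per this PR's description, that xm.all_gather can take a list of shards and issue one coalesced all-gather over the whole list; the helper gather_full_params is hypothetical and is not the actual FSDP internals.

```python
import torch_xla.core.xla_model as xm

def gather_full_params(param_shards, coalesce_all_gather_ops=False):
    """Gather full parameters from their shards (illustrative only)."""
    if coalesce_all_gather_ops:
        # Shards are collected into one list and gathered with a single
        # coalesced all-gather at the end (assumes xm.all_gather accepts
        # a list of tensors after this change).
        return xm.all_gather(param_shards, dim=0)
    # Default path: one all-gather op per parameter shard.
    return [xm.all_gather(shard, dim=0) for shard in param_shards]
```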
ReduceContext cc_ctx = GetReduceContext(inputs);
std::vector<xla::XlaOp> result(inputs.size());
for (auto& type_ctx : cc_ctx.contexts) {
If you want to assume there is only one type_ctx, let's not use the for loop and GetReduceContext at all. That way we don't need to handle the token per type.
Let me check with others on this.
Mostly LGTM besides the changes in FSDP. If we didn't change the default behavior of all-gather, the tests should pass, right?
I will look into the reduce-scatter one today; let's try to merge these two PRs soon.
Thanks! I think we should test allgather_coalesced using ResNet on GPU to make sure we don't break it in the future. You can refer to the existing test (line 136 in 2c4983d):
PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1
We can do that in a separate PR.
* Add all-gather and reduce-scatter coalescence support for FSDP. Also allow using reduce-scatter's scale param in FSDP. (revived pytorch#4145)
* clang-format-7 and python lint fixes
* Fix "SyntaxError: 'return' outside function" error
* Code/test fixes to get run_tests.sh to run on CPU
* Fix allgather to be compatible with openxla allgather tuple change without token
* Fix reduce-scatter-coalesce to be compatible with openxla reduce-scatter tuple change without token
* Separate out the reduce-scatter-coalesce changes into a separate PR
* Some cleanups
* Add separate BuildAllGatherCoalesced builder and AllGatherCoalesced class
* Use token_handler.GetInput to capture token
* Clean up
* Clean up
* Switch to GetOperandListWithToken naming for func GetOperandList
This PR adds all-gather coalescing support and uses it in FSDP/ZeRO1 (replacing #5624). It is to be used in conjunction with openxla/xla#5740.
A separate, related PR, #5938, adds reduce-scatter coalescing and also enables using reduce-scatter's scale param in FSDP.
This is a revival of #4145; the comments there will need to be addressed.
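For reference, a hypothetical usage sketch of the new flag, assuming only what the diff above shows (a coalesce_all_gather_ops keyword on the XlaFullyShardedDataParallel constructor, defaulting to False):

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

device = xm.xla_device()
model = torch.nn.Linear(1024, 1024).to(device)

# Off by default to preserve existing behavior; opting in gathers
# parameter shards with one coalesced all-gather instead of one
# all-gather op per parameter.
fsdp_model = FSDP(model, coalesce_all_gather_ops=True)
```

Leaving the flag at its default keeps the existing one-all-gather-per-parameter behavior, so current users are unaffected.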