Add all-gather coalescing for FSDP/ZeRO1 #5624
Conversation
@alanwaketan can you take a look?
@jeffhataws Can you double check the CI failures?
Thanks. Since it depends on openxla/xla#5740 I will need to take care of merging that first. So let's leave this open for now.
Force-pushed 79676e7 to 78fc6d3
Imported from GitHub PR openxla/xla#5740

This PR adds tuple input support to all-gather and reduce-scatter. It is a revival of part of tensorflow/tensorflow#58377 and is to be used in conjunction with pytorch/xla#5624. In FSDP, different layers' weights need to be all-gathered/reduce-scattered during training. If some layers are small, multiple layers' weights can be aggregated for more efficient data transfer (the same concept as bucket_cap_mb in DDP). With the existing all-gather and reduce-scatter in PyTorch/XLA, you would have to do the bucketing and decomposing outside of the operation. This PR enables multiple different tensors to be all-gathered/reduce-scattered while keeping the original tensor shapes, so the bucketing and decomposing optimizations can happen inside the operation. The original PR had token support, like the token used for all-reduce to ensure ordering between collective ops; that will be a separate PR if needed.

Copybara import of the project:
-- 7ea1159 by Junmin Hao <junminh@amazon.com>: Add Tuple input and token support to all-gather and reduce-scatter. Committer: Junmin Hao <junminh@amazon.com>
-- cdb873e by Junmin Hao <junminh@amazon.com>: lint fix
-- aad3521 by Jeffrey Huynh <jthuynh@amazon.com>: Fix hlo_verifier_test failure due to changed expectation
-- 32e8145 by Jeffrey Huynh <jthuynh@amazon.com>: Separate the token change out into a separate PR with RFC.
-- b301c2a by Jeffrey Huynh <jthuynh@amazon.com>: Change *WithToken tests to *WithTuple
-- 5890278 by Jeffrey Huynh <jthuynh@amazon.com>: Fix missing parenthesis

Merging this change closes openxla/xla#5740
COPYBARA_INTEGRATE_REVIEW=openxla/xla#5740 from jeffhataws:ag_rs_coalesce_revived 14e09f0
PiperOrigin-RevId: 573976449
Force-pushed 78fc6d3 to 8f45cae
Force-pushed 2e861ff to 76a2f0f
@alanwaketan what's the best way to check this against openxla with openxla/xla#5740 merged?
I'm working on a pin update. Will loop you in once that PR is up.
@alanwaketan just want to check how things are going with the pin update. Also, how do I make the automatic checks run on this PR?
@alanwaketan @JackCaoG will you help ensure this is in 2.2?
@jeffhataws can you rebase this PR? Then we can start reviewing it. Thanks!
The pin update is completed. So, please rebase and then I will take a look.
Force-pushed 76a2f0f to 1d8671f
Thanks, alanwaketan. I have rebased. Please take a look.
@alanwaketan @JackCaoG I am not sure how to reproduce/debug the above errors.
When I run
It wasn't hanging. After a while, it shows:
You didn't add any cpp test, and our cpp tests don't test distributed, so that should be fine? I do see a bunch of python tests failing in the CI with
in https://github.com/pytorch/xla/actions/runs/6896526398/job/18766614593?pr=5624 If I have to guess, the build passes, but when we try to import the XLAC we find that some functions are declared in the header but missing from the cpp file? This one seems to be
@alanwaketan is out next week; I will work with you to try to land this change before the branch cut (or we will cherry-pick).
Also allow using reduce-scatter's scale param in FSDP. (revived pytorch#4145)
Force-pushed 1d8671f to 57e6b36
Fixed. I had removed the previous all_gather method by mistake.
Fixed the lint error.
Yeah, the code was ported from an old version of torch/xla, so there were some merge errors. Plus, the final version of the openxla change openxla/xla#5740 doesn't have token support, so I need to make the corresponding change here.
Hmm, this error seems real:
…ter tuple change without token
One of the CPU workflows failed with this:
Hmm, seems like the VM OOMed when building pytorch/xla...
Let's ignore the CPU failure and focus on GPU for now. The GPU failures seem real.
All tests are passing with the reduce-scatter change separated out into #5938.
xla::XlaOp all_gather_result;
if (pin_layout) {
  all_gather_result = xla::AllGather(
      xla::Tuple(inputs[0].builder(), type_ctx.second.ops), dim,
I think this means that even if there is a single element in the all-gather, we will wrap it inside the tuple. I need to check with the XLA team whether this has any speed implications.
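For illustration, a minimal sketch of one way to avoid wrapping a single operand. The helper name, signature, and the assumption that an unwrapped single-tensor all-gather is preferable are mine, not this PR's code; the quoted diff continues below.

#include "xla/client/xla_builder.h"  // header path as of this PR's era; may differ in newer XLA

// Hypothetical helper (not part of this PR): wrap the operands in a tuple only
// when coalescing more than one tensor, so a single-tensor all-gather keeps
// emitting the same non-tuple HLO as before.
xla::XlaOp AllGatherMaybeTuple(absl::Span<const xla::XlaOp> ops, int64_t dim,
                               int64_t shard_count,
                               absl::Span<const xla::ReplicaGroup> groups) {
  if (ops.size() == 1) {
    return xla::AllGather(ops[0], dim, shard_count, groups);
  }
  // Multiple operands: one coalesced all-gather over a tuple of inputs.
  return xla::AllGather(xla::Tuple(ops[0].builder(), ops), dim, shard_count,
                        groups);
}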
}
} else {
result[0] = all_gather_result;
}
I don't see how the token is being used here; in all_reduce, we manually append it to the end of ops:
for (auto& type_ctx : redux.contexts) {
  xla::XlaOp token_op = MaybeConvertTo(chained_token, type_ctx.first);
  type_ctx.second.ops.push_back(token_op);
  type_ctx.second.operand_shapes.push_back(
      ShapeHelper::ShapeOfXlaOp(token_op));
I saw the GetOperandList below, but I don't think this guarantees that, when you have multiple types, each type has a token.
Fixed the usage of the token.
For the multiple-types case, I think for now we should ensure all tensors in the list have the same type.
}
  return {result, torch::lazy::Value(node, inputs.size())};
}

XLATensorPtr all_gather(const XLATensorPtr& input, int64_t dim,
I am not sure why it compiles when the code below still passes a single IR variable to torch::lazy::MakeNode<AllGather> while you changed the constructor to take an ArrayRef; maybe ArrayRef has a default constructor.
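A likely explanation, though this is my assumption and not verified against this PR's code: c10::ArrayRef also has an implicit single-element converting constructor, so passing one value where an ArrayRef is expected still compiles. A minimal standalone illustration:

#include <c10/util/ArrayRef.h>
#include <iostream>

// Toy function standing in for the AllGather node constructor; the real
// constructor and its parameter types are not shown here.
void TakesArrayRef(c10::ArrayRef<int> values) {
  std::cout << "size=" << values.size() << "\n";
}

int main() {
  int single = 42;
  TakesArrayRef(single);     // implicit one-element ArrayRef -> prints size=1
  TakesArrayRef({1, 2, 3});  // ArrayRef from an initializer list -> size=3
  return 0;
}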
I left a few comments. Overall I think it is better to implement a new BuildAllGatherCoalesced instead of modifying the existing BuildAllGather. The way you do it today will change the HLO generated for the single-tensor all-gather. Given that the branch cut is this Friday, I don't think we have time to do the performance benchmarking to make sure this is regression-free; it is safer to add the new feature without touching the existing logic.
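To make that suggestion concrete, here is a rough sketch of a separate coalesced builder that leaves BuildAllGather untouched. The function name comes from the comment above; the signature and parameter list are my assumptions, not the code that was eventually merged.

#include <vector>
#include "xla/client/xla_builder.h"  // header path as of this PR's era; may differ in newer XLA

// Hypothetical sketch of a separate builder for the coalesced path, so the
// existing single-tensor BuildAllGather (and its HLO) stays unchanged.
std::vector<xla::XlaOp> BuildAllGatherCoalesced(
    absl::Span<const xla::XlaOp> inputs, int64_t dim, int64_t shard_count,
    absl::Span<const xla::ReplicaGroup> replica_groups) {
  // One all-gather over a tuple of all inputs; each operand keeps its shape.
  xla::XlaOp coalesced = xla::AllGather(
      xla::Tuple(inputs[0].builder(), inputs), dim, shard_count,
      replica_groups);
  // Unpack the result tuple so callers get one output per input tensor.
  std::vector<xla::XlaOp> results;
  results.reserve(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    results.push_back(xla::GetTupleElement(coalesced, i));
  }
  return results;
}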
Oh, I know what's going on. Can you create a branch on pytorch/xla directly instead of from a fork? A fork cannot use our cache, so compilation takes much longer and is more likely to fail. I already gave you write access to the project, so you should be able to create a new branch from our repo directly.
Thanks. Changed to a PR from a branch on pytorch/xla: #5950.
(Replaced by #5950)
This PR adds all-gather coalescing support and uses it in FSDP/ZeRO1. It is to be used in conjunction with openxla/xla#5740.
A separate, related PR for reduce-scatter coalescing, which also enables using reduce-scatter's scale param in FSDP, is #5938.
This is a revival of #4145; the comments there will need to be addressed.
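For context, here is a minimal sketch of the manual bucketing callers would otherwise do outside the operation. The function name, the flatten-to-1D assumption, and the shapes are illustrative only; this is not code from this PR or from the FSDP implementation.

#include "xla/client/xla_builder.h"  // header path as of this PR's era; may differ in newer XLA

// Hypothetical "manual bucketing": flatten and concatenate several small
// per-layer tensors into one buffer and run a single all-gather over it,
// instead of one collective per layer. The coalesced (tuple-input) all-gather
// added by openxla/xla#5740 moves this bookkeeping inside the op and keeps
// each tensor's original shape.
xla::XlaOp ManualBucketedAllGather(xla::XlaBuilder* builder,
                                   absl::Span<const xla::XlaOp> flat_inputs,
                                   int64_t shard_count) {
  // Assumes every input has already been reshaped to 1-D.
  xla::XlaOp bucket = xla::ConcatInDim(builder, flat_inputs, /*dimension=*/0);
  // One collective for the whole bucket; the caller must slice the result
  // back into per-layer tensors afterwards.
  return xla::AllGather(bucket, /*all_gather_dimension=*/0, shard_count);
}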