ZeRO1: Add bucketting logic to control the size of tensors for all-gather/reduce-scatter #6025
Test crashed at |
@alanwaketan can you review this one as well, since you also reviewed the gradient bucketing one?
overall LGTM
mostly lgtm, minor comments.
@jeffhataws is this ready for another round of review?
I have a set of cleanups coming in an hour or so. I noticed that we have this code, which was unique to all-gather. I will remove it and add a separate bucket_cap_mb for all-gather and for reduce-scatter in ZeRO1.
@JackCaoG It is ready now for another round. Thanks.
@jeffhataws Thanks for the refactoring work!
    assert res.cpu().allclose(expected)

    xm.rendezvous(
        'test_reduce_scatter_list_input_output_bucketized, zero bucket size')
Hi @JackCaoG, does rendezvous allow commas and spaces in the rendezvous key? How come this didn't error out?
If this is not a concern, we can merge this PR.
Looking at the implementation of xla_rendezvous, I think the tag is ignored, so it doesn't really matter.
xla/torch_xla/core/xla_model.py
Lines 1110 to 1115 in 782f05d
    def xla_rendezvous(payload: bytes = b'',
                       ordinals: Optional[List[int]] = None,
                       tag: Optional[str] = None) -> List[bytes]:
      """Share `payload` with all replicas in `ordinals`.
      `tag` is ignored except for logging.
…ther/reduce-scatter (#6025) Co-authored-by: Rahul Solanki <rhsoln@amazon.com> Co-authored-by: guangtai <guangtai@amazon.com> Co-authored-by: Amithrajith Mamidala <amithrm@amazon.com>
This PR updates the XLA ZeRO1 implementation to use coalesced all-gather and coalesced reduce-scatter.
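For readers unfamiliar with the idea in the PR title, here is a minimal sketch of size-capped bucketing, written in plain Python over byte counts rather than real tensors. It is not the torch_xla implementation: the helper name `bucketize_by_size`, the greedy flush rule, and the treatment of a zero cap (each tensor ends up in its own bucket) are all assumptions made for illustration.

```python
from typing import List, Sequence


def bucketize_by_size(tensor_bytes: Sequence[int],
                      bucket_cap_mb: float) -> List[List[int]]:
    """Group tensor indices into buckets limited by a size cap.

    Hypothetical helper for illustration only; the real ZeRO1 code
    may use a different flush policy.
    """
    cap_bytes = int(bucket_cap_mb * 1024 * 1024)
    buckets: List[List[int]] = []
    current: List[int] = []
    current_bytes = 0
    for idx, nbytes in enumerate(tensor_bytes):
        # Flush the current bucket if adding this tensor would exceed the cap.
        # A tensor larger than the cap still gets a bucket of its own.
        if current and current_bytes + nbytes > cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(idx)
        current_bytes += nbytes
    if current:
        buckets.append(current)
    return buckets


MIB = 1024 * 1024
sizes = [4 * MIB, 4 * MIB, 4 * MIB, 10 * MIB, 1 * MIB]
capped = bucketize_by_size(sizes, bucket_cap_mb=8)
uncapped = bucketize_by_size(sizes, bucket_cap_mb=0)
```

In a scheme like this, each bucket would be handed to one coalesced collective call, so the cap bounds how much data a single all-gather or reduce-scatter moves at once.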