Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZeRO1: Add bucketting logic to control the size of tensors for all-gather/reduce-scatter #6025

Merged
merged 15 commits into from
Mar 22, 2024

Conversation

jeffhataws
Copy link
Collaborator

This PR updates XLA ZeRO1 implementation to use allgather coalesed and reduce-scatter coalesced.

@jeffhataws jeffhataws force-pushed the jeffhataws_zero1_fixes2 branch 2 times, most recently from 7c3d92d to 84a509d Compare December 7, 2023 22:01
@jeffhataws jeffhataws force-pushed the jeffhataws_zero1_fixes2 branch from 84a509d to 285a766 Compare December 10, 2023 17:04
@jeffhataws jeffhataws force-pushed the jeffhataws_zero1_fixes2 branch from a453257 to 6022c91 Compare March 7, 2024 21:38
@JackCaoG
Copy link
Collaborator

Test crashed at torch_xla::tensor_methods::all_gather_coalesced_out(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, which seems to be a real issue.

@jeffhataws jeffhataws force-pushed the jeffhataws_zero1_fixes2 branch from a8f050e to 77b2ad1 Compare March 15, 2024 21:16
@JackCaoG
Copy link
Collaborator

@alanwaketan can you review this one as well since you also review the gradient bucketing one?

@jeffhataws jeffhataws force-pushed the jeffhataws_zero1_fixes2 branch from 173ef47 to 13965fd Compare March 19, 2024 23:06
@jeffhataws jeffhataws force-pushed the jeffhataws_zero1_fixes2 branch from 13965fd to 8586370 Compare March 19, 2024 23:08
@jeffhataws jeffhataws force-pushed the jeffhataws_zero1_fixes2 branch from ec4b1e0 to 675e7a1 Compare March 20, 2024 16:10
@jeffhataws jeffhataws requested a review from hgt312 March 20, 2024 16:16
Copy link
Collaborator

@hgt312 hgt312 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overall LGTM

torch_xla/distributed/zero_redundancy_optimizer.py Outdated Show resolved Hide resolved
torch_xla/core/xla_model.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@JackCaoG JackCaoG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mostly lgtm, minor comments.

@JackCaoG
Copy link
Collaborator

@jeffhataws is this ready for another round of review?

@jeffhataws
Copy link
Collaborator Author

@jeffhataws is this ready for another round of review?

I have a set of cleanup coming in an hour or so. I noticed that we have this code which was unique in all-gather. I will remove it, and have separate bucket_cap_mb for allgather and reduce-scatter in ZeRO1.

    if groups:
      divisor = len(groups[0]) if type(groups[0]) == list else len(groups)
    else:
      divisor = xrt_world_size()
    self._bucket_cap = self._bucket_cap / divisor

@jeffhataws
Copy link
Collaborator Author

jeffhataws commented Mar 21, 2024

@jeffhataws is this ready for another round of review?

@JackCaoG It is ready now for another round. Thanks.

@JackCaoG
Copy link
Collaborator

@jeffhataws Thanks for the refactoring work!

assert res.cpu().allclose(expected)

xm.rendezvous(
'test_reduce_scatter_list_input_output_bucketized, zero bucket size')
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @JackCaoG , does rendezvous allow comma and space in the rendezvous key? How come this didn't error out?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is not a concern, we can merge this PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the implementation of xla_rendezvous, I think tag got ignored so it doesn't really matter.

def xla_rendezvous(payload: bytes = b'',
ordinals: Optional[List[int]] = None,
tag: Optional[str] = None) -> List[bytes]:
"""Share `payload` with all replicas in `ordinals`.
`tag` is ignored except for logging.

@JackCaoG JackCaoG merged commit e75677f into master Mar 22, 2024
18 checks passed
JackCaoG pushed a commit that referenced this pull request Mar 22, 2024
…ther/reduce-scatter (#6025)

Co-authored-by: Rahul Solanki <rhsoln@amazon.com>
Co-authored-by: guangtai <guangtai@amazon.com>
Co-authored-by: Amithrajith Mamidala <amithrm@amazon.com>
lsy323 pushed a commit that referenced this pull request Mar 25, 2024
…or all-gather/reduce-scatter (#6025) (#6806)

Co-authored-by: jeffhataws <jthuynh@amazon.com>
Co-authored-by: Rahul Solanki <rhsoln@amazon.com>
Co-authored-by: guangtai <guangtai@amazon.com>
Co-authored-by: Amithrajith Mamidala <amithrm@amazon.com>
@jeffhataws jeffhataws deleted the jeffhataws_zero1_fixes2 branch November 22, 2024 23:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants