Add API to assemble CPU shards to a sharded tensor #5630
Changes from all commits: 13cfb7a, 4413bc7, bf329c7, c03a567
This may not hold true in the case of uneven tiling. Let's make a note of this.
Or, actually, for the right-most shard in each dim we could add up the sizes, since the padding is always on the last dims.
Ok, we can't do that, since the input shards are required to already include padding:
"Input shard shape must include padding: " << shard.sizes()
Let me know if I should revisit this. I decided on this approach because padding can cross over multiple devices: e.g., sharding a tensor with shape (1, 2) on a (1, 4) mesh will produce shards with shape (1, 1), with no real data on the last two devices.
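A minimal sketch of that scenario in plain PyTorch, building the padded CPU shards by hand (the shard construction below is illustrative only; the real shards come from torch_xla's SPMD machinery, not this code):

```python
import torch
import torch.nn.functional as F

# Shard a (1, 2) tensor along its second dim across a (1, 4) mesh.
# Each of the 4 devices receives a (1, 1) shard; padding fills the gaps.
global_tensor = torch.tensor([[10, 20]])
num_shards = 4
shard_cols = -(-global_tensor.shape[1] // num_shards)  # ceil(2 / 4) == 1

cpu_shards = []
for i in range(num_shards):
    chunk = global_tensor[:, i * shard_cols:(i + 1) * shard_cols]
    pad = shard_cols - chunk.shape[1]
    cpu_shards.append(F.pad(chunk, (0, pad)))  # zero-pad the last dim

# Devices 0 and 1 hold the real data; devices 2 and 3 hold only padding,
# so the padding spans multiple devices rather than just the last shard.
print([s.tolist() for s in cpu_shards])  # [[[10]], [[20]], [[0]], [[0]]]
```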
ping @yeounoh
I think for the most part the user should specify the global shape. I'm thinking we should actually make it explicit, not optional... but if not, this way of handling it, inferring a default shape, is good.
That's a good point... I chose to let it be inferred for convenience, since for, e.g., distributed data loading, the shards are provided on CPU with the correct padded local shape, and deriving the global shape could be difficult for more sophisticated shardings.
I'll go ahead and land with it optional for now, since we're keeping the API private. Thanks Yeounoh!
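For reference, a hedged sketch of what an even-tiling default inference could look like; the helper name and signature below are hypothetical, not the PR's actual API, which is exactly why passing the global shape explicitly is the safer option:

```python
from typing import List, Sequence

import torch


def infer_global_shape(shard: torch.Tensor,
                       mesh_shape: Sequence[int]) -> List[int]:
    # Hypothetical helper: assume each tensor dim is tiled evenly over the
    # corresponding mesh dim, i.e. global size = padded shard size * mesh size.
    # This over-estimates whenever the true global shape required padding.
    return [dim * n for dim, n in zip(shard.shape, mesh_shape)]


# With the (1, 2)-on-(1, 4) example above, each CPU shard is (1, 1), so the
# inferred global shape would be (1, 4) rather than the true (1, 2).
print(infer_global_shape(torch.zeros(1, 1), (1, 4)))  # [1, 4]
```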
Thanks for refactoring 👍