Add API to assemble CPU shards to a sharded tensor #5630
Conversation
// Set a default value for the global shape based on the sharding
// type.
if (sharding.type() == xla::OpSharding::OTHER) {
  // Infer the global shape to be the shard shape scaled by the tiling
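A rough Python rendering of the default-shape inference described in the snippet above and debated below; illustrative only, since the actual logic lives in the C++ added by this PR:

```python
# Illustrative sketch only: for a tiled (OpSharding::OTHER) sharding, the
# default global shape is the shard shape scaled elementwise by the tiling.
# `tile_dims` is a stand-in for the tile assignment dimensions per tensor dim.
def infer_global_shape(shard_shape, tile_dims):
    return tuple(s * t for s, t in zip(shard_shape, tile_dims))

# (1, 1) shards tiled over a 1x4 mesh infer a (1, 4) global shape. As the
# comments below point out, this over-counts under uneven tiling, where the
# true global shape may be smaller (e.g. (1, 2)).
assert infer_global_shape((1, 1), (1, 4)) == (1, 4)
```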
This may not hold true in the case of uneven tiling. Let's make a note of this.
Or, actually, for the right-most shard in each dim we can add up the sizes, since the padding is always on the last dims.
Ok, we can't do that, since the input shard shape must include padding ("Input shard shape must include padding: " << shard.sizes()).
Let me know if I should revisit this. I decided on this approach because padding can cross over multiple devices, e.g. sharding a tensor with shape (1, 2) on the mesh (1, 4) will have shards with shape (1, 1) with no real data on the last two devices.
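A minimal sketch in plain PyTorch (nothing XLA-specific) of the example above, showing how the padding ends up occupying entire shards:

```python
import torch
import torch.nn.functional as F

t = torch.tensor([[1.0, 2.0]])                     # global shape (1, 2)
num_devices = 4                                    # mesh shape (1, 4)
padded = F.pad(t, (0, num_devices - t.shape[1]))   # pad last dim 2 -> 4
shards = padded.chunk(num_devices, dim=1)          # four (1, 1) shards

# Devices 2 and 3 hold only padding, so the true global shape (1, 2) cannot
# be recovered from the shard shapes alone.
for i, s in enumerate(shards):
    print(f"device {i}: shape {tuple(s.shape)}, data {s.tolist()}")
```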
ping @yeounoh
I think for the most part, the user should specify the global shape. I am thinking that we should actually make it explicit, not optional... but if not, this way of handling it, inferring the default shape, is good.
That's a good point... I chose to let it be inferred for convenience, since for e.g. distributed data loading, the shards are provided on CPU with the correct padded local shape, and deriving the global shape could be difficult for more sophisticated shardings.
I'll go ahead and land with it optional for now, since we're keeping the API private. Thanks Yeounoh!
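A hedged sketch of the two call patterns discussed here, based on the from_cpu_shards signature shown in the diff further down. The class or module that exposes the staticmethod is not visible in this excerpt, so the bare `from_cpu_shards` call below is a placeholder, and the import path for the sharding helpers may differ across torch_xla versions:

```python
import numpy as np
import torch
import torch_xla.experimental.xla_sharding as xs  # path is version-dependent

mesh = xs.Mesh(np.arange(4), (1, 4))      # 1x4 device mesh
sharding = mesh.get_op_sharding((0, 1))   # tile dim 1 across all four devices

# CPU shards must already include padding, i.e. carry the (1, 1) local shape.
shards = [torch.tensor([[1.0]]), torch.tensor([[2.0]]),
          torch.tensor([[0.0]]), torch.tensor([[0.0]])]

# `from_cpu_shards` below stands in for the staticmethod added in this PR.
# Without global_shape, it is inferred as shard shape * tiling -> (1, 4),
# so the padding is kept in the assembled global tensor.
padded_global = from_cpu_shards(shards, sharding)

# With an explicit global_shape, the padding is trimmed back to (1, 2).
global_tensor = from_cpu_shards(shards, sharding,
                                global_shape=torch.Size((1, 2)))
```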
@@ -87,6 +87,14 @@ def get_op_sharding(self,
    Return the OpSharding for the given partition spec. This is an expensive
    operation as the mesh grows, so the value is cached for reuse.
    """
    partition_spec = _translate_named_partition_spec(self, partition_spec)
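For context on the refactor acknowledged below: the `_translate_named_partition_spec` call shown above maps a partition spec written with mesh axis names onto axis indices before the OpSharding is built. A small sketch of the usage this enables; the exact equivalence is an assumption based on the helper's name:

```python
import numpy as np
import torch_xla.experimental.xla_sharding as xs  # path is version-dependent

# 2x2 mesh with named axes.
mesh = xs.Mesh(np.arange(4), (2, 2), axis_names=('data', 'model'))

# With the translation applied inside get_op_sharding, a named spec and the
# equivalent index-based spec are assumed to yield the same OpSharding.
named = mesh.get_op_sharding(('data', 'model'))
indexed = mesh.get_op_sharding((0, 1))
```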
Thanks for refactoring 👍
Left some comments.
@staticmethod
def from_cpu_shards(shards: List[torch.Tensor],
                    sharding: torch_xla._XLAC.OpSharding,
                    global_shape: torch.Size = None):
What's the benefit of providing a global_shape?
If the shards are padded, the global_shape can be used to remove padding from the global tensor.
LGTM
Thanks @yeounoh and @alanwaketan for the review. I'll merge after a TPU CI run.
LGTM.
* Add API to assemble CPU shards to a sharded tensor
* Handle replicated sharding
* Move validations into get_op_sharding
* Improve tests and error handling
)" (pytorch#5680) This reverts commit 3913a77.
Currently, to convert CPU shards to an XLAShardedTensor, the API `XLAShardedTensor::load_local_shards_` must be used. This loads CPU shards in-place into an existing sharded tensor. A more convenient method for use cases outside of distributed checkpointing is to directly assemble the shards into a global tensor on device.

This PR adds a private _XLAC API `_get_global_tensor_from_cpu_shards` to allow directly creating a sharded tensor from a list of CPU shards and an OpSharding generated by `Mesh::get_op_sharding`.

I currently plan to use this new API in a few places: