
Add API to assemble CPU shards to a sharded tensor #5630

Merged: 4 commits merged into master from jonbolin-assemble-shards on Oct 5, 2023

Conversation

jonb377 (Collaborator) commented Sep 20, 2023

Currently, converting CPU shards to an XLAShardedTensor requires the XLAShardedTensor::load_local_shards_ API, which loads CPU shards in-place into an existing sharded tensor. For use cases outside of distributed checkpointing, a more convenient method is to assemble the shards directly into a global tensor on device.

This PR adds a private _XLAC API _get_global_tensor_from_cpu_shards to allow directly creating a sharded tensor from a list of CPU shards and an OpSharding generated by Mesh::get_op_sharding.

I currently plan to use this new API in a few places:

  • tpu.py::discover_master_worker_ip_address: place each host's IP into a global tensor and pull out the zeroth entry for worker 0. The existing IP discovery API currently doesn't work with SPMD, but it's needed for distributed checkpointing.
  • Distributed data loading: A per-device, sharding-aware dataloader can use this API to efficiently create device data from the loaded CPU shards.
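
A minimal usage sketch of the new API, assuming it surfaces in Python as XLAShardedTensor.from_cpu_shards; the import paths below are assumptions for illustration and may differ from the actual module layout:

```python
import numpy as np
import torch
import torch_xla.runtime as xr
from torch_xla.experimental.xla_sharding import Mesh
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor

# One CPU shard per device; here each device contributes a single row.
num_devices = xr.global_runtime_device_count()
shards = [torch.full((1, 4), float(i)) for i in range(num_devices)]

# Describe how the shards tile the global tensor: shard dim 0 across all
# devices, leave dim 1 unsharded.
mesh = Mesh(np.arange(num_devices), (num_devices, 1))
op_sharding = mesh.get_op_sharding((0, 1))

# Assemble the CPU shards directly into a sharded tensor on device. The
# global shape is inferred as (num_devices, 4) when not provided.
global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding)
```

Under the hood this wraps the private _XLAC binding _get_global_tensor_from_cpu_shards described above.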

jonb377 marked this pull request as ready for review September 20, 2023 20:41

torch_xla/runtime.py (outdated review thread, resolved)

jonb377 self-assigned this Sep 21, 2023
Review thread on the default global-shape inference:

// Set a default value for the global shape based on the sharding
// type.
if (sharding.type() == xla::OpSharding::OTHER) {
  // Infer the global shape to be the shard shape scaled by the tiling
Contributor:

This may not hold true in the case of uneven tiling. Let's make a note on this.

Contributor:

Or, actually, for the right-most shard in each dim we can add the sizes, since the padding is always on the last dims.

Contributor:

OK, we can't do that, since the input shards must already include padding ("Input shard shape must include padding: " << shard.sizes()).

jonb377 (Collaborator, author):

Let me know if I should revisit this. I decided on this approach because padding can cross over multiple devices, e.g. sharding a tensor with shape (1, 2) on the mesh (1, 4) will produce shards with shape (1, 1) and no real data on the last two devices.

Collaborator:

ping @yeounoh

Contributor:

I think for the most part, the user should specify the global shape. I'm thinking we should actually make it explicit, not optional... but if not, inferring a default shape this way is good.

jonb377 (Collaborator, author):

That's a good point... I chose to let it be inferred for convenience: for e.g. distributed data loading, the shards are provided on CPU with the correct padded local shape, and deriving the global shape could be difficult for more sophisticated shardings.

I'll go ahead and land with it optional for now, since we're keeping the API private. Thanks, Yeounoh!
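
For reference, a rough Python sketch of the inference being discussed (the helper below is hypothetical, not code from this PR): for a tiled sharding, the default global shape is the padded shard shape scaled by the number of tiles along each dimension, so when the shards were padded the inferred shape includes that padding.

```python
def infer_padded_global_shape(shard_shape, tiles_per_dim):
    """Hypothetical helper mirroring the default-shape inference above.

    Every input shard must already include padding, so the inferred shape
    includes it too: a (1, 2) tensor sharded over a (1, 4) mesh produces
    four (1, 1) shards, and the inferred global shape is (1, 4) rather
    than (1, 2). Passing an explicit global_shape trims the padding.
    """
    return tuple(s * t for s, t in zip(shard_shape, tiles_per_dim))

assert infer_padded_global_shape((1, 1), (1, 4)) == (1, 4)
```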

Review thread on Mesh.get_op_sharding:

@@ -87,6 +87,14 @@ def get_op_sharding(self,
    Return the OpSharding for the given partition spec. This is an expensive
    operation as the mesh grows, so the value is cached for reuse.
    """
    partition_spec = _translate_named_partition_spec(self, partition_spec)
Contributor:

Thanks for refactoring 👍
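
For context, a small sketch of the named-partition-spec path exercised by this hunk; the mesh shape and axis names below are made up for illustration, and the import path is an assumption:

```python
import numpy as np
from torch_xla.experimental.xla_sharding import Mesh

# A 2D logical mesh over four devices with named axes.
mesh = Mesh(np.arange(4), (2, 2), ('data', 'model'))

# Named specs are translated to mesh-axis indices before the OpSharding is
# built, and the result is cached since the computation grows with the mesh.
op_sharding = mesh.get_op_sharding(('data', None))
```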

yeounoh (Contributor) left a comment:

Left some comments.

torch_xla/csrc/init_python_bindings.cpp (two outdated review threads, resolved)
Review thread on the new from_cpu_shards signature:

@staticmethod
def from_cpu_shards(shards: List[torch.Tensor],
                    sharding: torch_xla._XLAC.OpSharding,
                    global_shape: torch.Size = None):
Collaborator:

What's the benefit of providing a global_shape?

jonb377 (Collaborator, author) Sep 25, 2023:

If the shards are padded, the global_shape can be used to remove padding from the global tensor.
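
A sketch of that case, reusing the (1, 2)-on-(1, 4) example from the earlier thread (import paths assumed as before):

```python
import numpy as np
import torch
from torch_xla.experimental.xla_sharding import Mesh
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor

# A (1, 2) tensor sharded over a (1, 4) mesh: every device receives a padded
# (1, 1) shard, and the last two shards hold only padding.
mesh = Mesh(np.arange(4), (1, 4))
op_sharding = mesh.get_op_sharding((0, 1))
shards = [torch.tensor([[1.0]]), torch.tensor([[2.0]]),
          torch.zeros(1, 1), torch.zeros(1, 1)]

# Without global_shape, the assembled tensor keeps the inferred (1, 4) shape,
# padding included.
padded = XLAShardedTensor.from_cpu_shards(shards, op_sharding)

# Passing the true global shape trims the result back to (1, 2).
unpadded = XLAShardedTensor.from_cpu_shards(
    shards, op_sharding, global_shape=torch.Size((1, 2)))
```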

yeounoh (Contributor) left a comment:

LGTM

jonb377 (Collaborator, author) commented Oct 5, 2023:

Thanks @yeounoh and @alanwaketan for the review. I'll merge after a TPU CI run.

alanwaketan (Collaborator) left a comment:

LGTM.

jonb377 merged commit 3913a77 into master on Oct 5, 2023. 18 checks passed.

jonb377 deleted the jonbolin-assemble-shards branch October 5, 2023 21:23.
The referencing commits all carry the same change summary:

* Add API to assemble CPU shards to a sharded tensor
* Handle replicated sharding
* Move validations into get_op_sharding
* Improve tests and error handling

jonb377 added two commits referencing this pull request (Oct 5, 2023). Commits referencing it were later pushed by qihqi (Oct 10, 2023), zpcore (Oct 19, 2023), ghpvnist (ghpvnist/xla, Oct 31, 2023), chunnienc (chunnienc/xla, Dec 14, 2023), golechwierowicz (Jan 12, 2024), and bhavya01 (Apr 22, 2024).