[SPMD] Support manual sharding #6915

Merged
merged 7 commits on Apr 12, 2024

Conversation

@alanwaketan alanwaketan commented Apr 10, 2024

Summary:
This pull request makes SPMD support the manual sharding type via a new private API called _mark_manual_sharding. I don't expect users will need to call this function explicitly.

Besides adding support for the sharding annotation, we also need to define the behavior of the data shards. For data, the current behavior is to error out.

Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test__mark_manual_sharding
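
For context, a minimal sketch assembled from the snippets quoted in the review thread below. The import aliases are assumptions for illustration, and how plain device data is handled is exactly what the thread discusses (per the summary, the current behavior for data is to error out):

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.spmd as xs  # assumed alias for the sharding API

    # Device-data tensor, as in the reviewer's example further down.
    xx = torch.randn(3, 2).to(xm.xla_device())

    # New private API added by this PR: annotate the tensor with the MANUAL
    # sharding type. Users are not expected to call this directly.
    xt = xs._mark_manual_sharding(xx)

    # The annotation surfaces in the lowered HLO; the unit test checks for
    # 'sharding={manual}' in this dump.
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt.global_tensor])
    print(hlo)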

@alanwaketan alanwaketan requested review from yeounoh and jonb377 April 10, 2024 23:08
@alanwaketan alanwaketan self-assigned this Apr 10, 2024
@jonb377 jonb377 (Collaborator) left a comment

Interesting! I've been wondering what manual sharding is for.

    // The returned tensors will be in 1:1 correspondence with the `devices`
    // vector, so the `i`th result will belong on the `i`th device.
    // the `tile_assignment`; MANUAL sharding results in shards where only the
    // first device holds the full data; the returned tensor shards vector is
Collaborator
only the first device holds the full data

Is this by definition of manual sharding?

@yeounoh yeounoh (Contributor) Apr 11, 2024

This is not by definition, but an implementation choice on our side. A better example would be a list of tensors (as in DTensor), where each tensor is an individual full shard.

Collaborator Author

@yeounoh Will that be replicated then?

Contributor

Per our offline discussion, we'll abstain from manual sharding on input data.

    result.reserve(cpu_shards.size() / shards_per_tensor);
    for (int i = 0; i < cpu_shards.size(); i += shards_per_tensor) {
      std::vector<at::Tensor> cpu_shards =
          XlaDataToTensors(WrapXlaData(shard_handles), element_types);
Collaborator

Calling XlaDataToTensors on each tensor individually will slow down d2h transfers for async checkpointing, since PjRt won't be able to fully utilize transfer parallelization.

Do we expect manually-sharded tensors to contain actual device data generally, or will they usually be IR? If just IR, maybe we can add an assertion to prevent access here.
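
Purely to illustrate the batching idea, a hypothetical sketch reusing the names from the hunk above; per_tensor_shard_handles and all_element_types are made-up placeholders, not code from this PR:

    // Hypothetical: gather every shard handle up front, then make a single
    // XlaDataToTensors call so PjRt can overlap the device-to-host copies,
    // instead of transferring per logical tensor inside the loop.
    decltype(shard_handles) all_shard_handles;
    for (const auto& handles : per_tensor_shard_handles) {
      all_shard_handles.insert(all_shard_handles.end(), handles.begin(),
                               handles.end());
    }
    std::vector<at::Tensor> all_cpu_shards =
        XlaDataToTensors(WrapXlaData(all_shard_handles), all_element_types);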

Contributor

I'd rather keep it functional for both cases -- shouldn't it be asynchronous anyway, not blocking the actual training run?

Collaborator Author

This is interesting. I was not aware of this performance optimization...

    } else if ((sharding.type() == xla::OpSharding::MANUAL)) {
      // Just put the full tensor on the first device.
      shards[0] = tensor;
      shards.resize(1);
@jonb377 jonb377 (Collaborator) Apr 10, 2024

How does this work for a computation, since we need to feed each device some input data?

e.g. based on your unit test, what happens if we run:

    x = torch.randn(3, 2)
    xx = x.to(xm.xla_device())  # xx is device data
    xt = xs._mark_manual_sharding(xx)

    ones = torch.ones(3, 2).to(xm.xla_device()) # ones is replicated to all devices
    print(xt + ones)  # What will happen here?

Contributor

XLA should assume that xt is sharded manually, so it's expected to be replicated as well. The purpose of MANUAL is to support custom kernels and to prevent XLA from overriding the manual sharding.

Collaborator Author

Good question. I would expect it to behave as it would on a single device. Let me double-check as well.

    xt = xs._mark_manual_sharding(xx)

    hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt.global_tensor])
    self.assertIn('parameter(0), sharding={manual}', hlo)
Contributor

Great!

@yeounoh yeounoh (Contributor) left a comment

LGTM. I'll leave the correctness review of distributed checkpointing with manual sharding to @jonb377 and his unit tests.

    @@ -1100,6 +1100,26 @@ def test_global_mesh(self):

        self.assertEqual(id(mesh), id(expected_mesh))

      def test__mark_manual_sharding(self):
Contributor

nit: even though it's testing the _-prefixed API, let's keep the name as test_mark_manual_sharding.

@alanwaketan (Collaborator Author)

Here is the new TPU CI run: https://github.com/pytorch/xla/actions/runs/8652176761

@alanwaketan (Collaborator Author)

All tests passed. I'm going to merge it. Let me know if I need to follow up on anything.

@alanwaketan alanwaketan merged commit e5513ff into master Apr 12, 2024
20 checks passed
lausannel pushed a commit to AlibabaPAI/xla that referenced this pull request Aug 6, 2024
baoleai pushed a commit to AlibabaPAI/xla that referenced this pull request Aug 6, 2024