[SPMD] Support manual sharding #6915
Conversation
Interesting! I've been wondering what manual sharding is for.
torch_xla/csrc/xla_sharding_util.h
Outdated
// The returned tensors will be in 1:1 correspondence with the `devices`
// vector, so the `i`th result will belong on the `i`th device.
// the `tile_assignment`; MANUAL sharding results in shards where only the
// first device holds the full data; the returned tensor shards vector is
only the first device holds the full data
Is this by definition of manual sharding?
This is not by definition, but an implementation choice on our side. A more proper example would be a list of tensors (as in DTensor), where each tensor is an individual full shard.
@yeounoh Will that be replicated then?
Per our offline discussion, we abstain from manual sharding on input data.
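To make the shard layout discussed in this thread concrete, here is a minimal torch-only sketch (not the actual C++ implementation; the helper name is hypothetical): tiled sharding splits the tensor across devices, while under the manual-sharding implementation choice only the first device holds the full data.

import torch

def build_shards(tensor, sharding_type, num_devices):
    # Illustrative stand-in for the shard construction discussed above.
    if sharding_type == "manual":
        # Implementation choice: only the first entry holds the full tensor.
        return [tensor]
    # Tiled sharding: each device holds one slice along dim 0.
    return list(torch.chunk(tensor, num_devices, dim=0))

x = torch.randn(8, 2)
print(len(build_shards(x, "manual", 4)))  # 1 shard, the full tensor
print(len(build_shards(x, "tiled", 4)))   # 4 shards, one per device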
std::vector<at::Tensor> cpu_shards =
    XlaDataToTensors(WrapXlaData(shard_handles), element_types);
result.reserve(cpu_shards.size() / shards_per_tensor);
for (int i = 0; i < cpu_shards.size(); i += shards_per_tensor) {
Calling XlaDataToTensors on each tensor individually will slow down d2h transfers for async checkpointing, since PjRt won't be able to fully utilize transfer parallelization.
Do we expect manually-sharded tensors to contain actual device data generally, or will they usually be IR? If just IR, maybe we can add an assertion to prevent access here.
I'd rather keep it functional for both cases -- shouldn't it be asynchronous anyway, not blocking the actual training run?
This is interesting. I was not aware of this performance optimization...
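As a side note, a pure-Python sketch of the batching concern raised in this thread; the to_cpu callable is a hypothetical stand-in for a batched device-to-host conversion, not a torch_xla API. Converting all handles in one call lets the runtime overlap transfers, while a per-handle loop serializes them.

def fetch_batched(handles, to_cpu):
    # One conversion call over all handles: transfers can overlap.
    return to_cpu(handles)

def fetch_one_by_one(handles, to_cpu):
    # One call per handle: each transfer finishes before the next starts.
    return [to_cpu([h])[0] for h in handles]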
torch_xla/csrc/xla_sharding_util.cpp
Outdated
} else if (sharding.type() == xla::OpSharding::MANUAL) {
  // Just put the full tensor on the first device.
  shards[0] = tensor;
  shards.resize(1);
How does this work for a computation, since we need to feed each device some input data?
e.g. based on your unit test, what happens if we run:
x = torch.randn(3, 2)
xx = x.to(xm.xla_device()) # xx is device data
xt = xs._mark_manual_sharding(xx)
ones = torch.ones(3, 2).to(xm.xla_device()) # ones is replicated to all devices
print(xt + ones) # What will happen here?
XLA should assume that xt is sharded manually, so the result is expected to be replicated as well. The purpose of MANUAL is to support custom kernels and to prevent XLA from overriding the manual sharding.
Good question. I would expect it to behave as a single device. Let me double-check as well.
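One hedged way to answer the question in this thread is to dump the HLO for the result and inspect the sharding annotations. This assumes an SPMD-enabled TPU environment with this change applied; note that per the final summary, applying manual sharding to actual input data may error out, so treat this as a probe rather than guaranteed-working code.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD mode

x = torch.randn(3, 2)
xx = x.to(xm.xla_device())                    # xx is device data
xt = xs._mark_manual_sharding(xx)             # the new API from this PR
ones = torch.ones(3, 2).to(xm.xla_device())   # replicated to all devices
out = xt + ones

# Inspect the pending HLO for sharding={manual} / replicated annotations.
t = out.global_tensor if hasattr(out, "global_tensor") else out
print(torch_xla._XLAC._get_xla_tensors_hlo([t]))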
test/spmd/test_xla_sharding.py
Outdated
xt = xs._mark_manual_sharding(xx)

hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt.global_tensor])
self.assertIn('parameter(0), sharding={manual}', hlo)
Great!
LGTM. I'll leave the correctness review of distributed checkpointing with manual sharding to @jonb377 and his unit tests.
test/spmd/test_xla_sharding.py
Outdated
@@ -1100,6 +1100,26 @@ def test_global_mesh(self):

    self.assertEqual(id(mesh), id(expected_mesh))

  def test__mark_manual_sharding(self):
nit: even though it's testing the _-prefixed API, let's keep it as test_mark_manual_sharding.
Here is the new TPU CI run: https://github.com/pytorch/xla/actions/runs/8652176761
TPU test here: https://github.com/pytorch/xla/actions/runs/8654305716
All tests passed. I'm going to merge it. Let me know if I need to follow up on anything.
Summary:
This pull request makes SPMD support the manual sharding type via a new private API, _mark_manual_sharding. I don't expect users will need to call this function explicitly.
Besides adding support for the sharding annotation, we also need to define the behavior of the data shards. For data, the current behavior is to error out.
Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test__mark_manual_sharding
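For reference, a minimal sketch of what the added unit test appears to exercise, reconstructed from the (partly outdated) fragments quoted in this review; the merged test may differ, in particular around the error-out behavior for data shards noted in the summary.

import unittest

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs


class ManualShardingTest(unittest.TestCase):

  @classmethod
  def setUpClass(cls):
    # Assumes an SPMD-capable environment (e.g. PJRT_DEVICE=TPU).
    xr.use_spmd()

  def test__mark_manual_sharding(self):
    xx = torch.randn(3, 2).to(xm.xla_device())
    xt = xs._mark_manual_sharding(xx)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt.global_tensor])
    self.assertIn('parameter(0), sharding={manual}', hlo)


if __name__ == '__main__':
  unittest.main()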