[SPMD] Support mark_sharding on IRs (#5301)
Summary:
This pull request fixes the recompilation issue in xs.mark_sharding().
xtensor->GetXlaData() will compile the program when xtensor holds an IR value, in
order to obtain the BackendData. I believe this is not intended, given that the
error message suggests only data-type xtensors are supported.

Test Plan:
PJRT_DEVICE=TPU XLA_USE_SPMD=1 python test/spmd/test_xla_sharding.py
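
For context, a minimal sketch of the scenario this change targets, assuming an SPMD-enabled environment (PJRT_DEVICE=TPU, XLA_USE_SPMD=1, four devices) and the xs.Mesh / xs.mark_sharding API exercised by the test below; the exact module path for xs is an assumption based on this era of the codebase, not something stated in the commit:

    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.experimental.xla_sharding as xs  # assumed module path

    # Assumed: running under PJRT_DEVICE=TPU XLA_USE_SPMD=1 with 4 devices.
    num_devices = 4
    mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))

    xt1 = torch.randn(1, 128).to(xm.xla_device())
    xt2 = torch.randn(1, 128).to(xm.xla_device())

    # `actual` is backed by a pending IR value, not by device data.
    actual = xt1 + xt2

    # Previously this call reached xtensor->GetXlaData() and compiled the pending
    # graph just to fetch BackendData; now it only attaches the sharding spec.
    xs.mark_sharding(actual, mesh, (0, 1))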
alanwaketan committed Jul 13, 2023
1 parent caf5168 · commit 6fe5cb9
Showing 2 changed files with 28 additions and 1 deletion.
test/spmd/test_xla_sharding.py: 18 additions & 1 deletion
@@ -490,7 +490,6 @@ def test_xla_sharded_hlo_dump(self):
     # scalar 5 should be replicated
     self.assertIn('%p0.2 = f32[] parameter(0), sharding={replicated}', hlo)
 
-  @unittest.skip("TODO(alanwaketan): Implement IR sharding to re-enable this.")
   def test_2d_tensor_3d_mesh(self):
     ct1 = torch.randn(16, 16, device='cpu')
     ct2 = torch.randn(16, 16, device='cpu')
@@ -567,6 +566,24 @@ def test_hybrid_mesh(self, xla_device_mock, device_attributes_mock):
     self.assertEqual(hybrid_mesh.get_logical_mesh().tolist(),
                      [[0, 1], [2, 3], [4, 5], [6, 7]])
 
+  def test_mark_sharding_ir(self):
+    t1 = torch.randn(1, 128, device='cpu')
+    t2 = torch.randn(1, 128, device='cpu')
+    expected = t1 + t2
+
+    xt1 = t1.to(xm.xla_device())
+    xt2 = t2.to(xm.xla_device())
+    actual = xt1 + xt2
+    xs.mark_sharding(actual, self._get_mesh((1, self.n_devices)), (0, 1))
+
+    if self.n_devices > 1:
+      annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
+          [str(i) for i in range(self.n_devices)]))
+      self.assertEqual(annotation,
+                       torch_xla._XLAC._get_xla_sharding_spec(actual))
+
+    self.assertTrue(torch.allclose(expected, actual.cpu()))
+
 
 if __name__ == '__main__':
   test = unittest.main()
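
As a note on the assertion in the new test: the expected annotation is just the flattened device assignment rendered in HLO sharding syntax. For example, with four devices the expression evaluates to {devices=[1,4]0,1,2,3}, as the short snippet below illustrates (n_devices is a stand-in for self.n_devices):

    n_devices = 4  # stand-in for self.n_devices in the test
    annotation = '{devices=[1,%d]%s}' % (n_devices, ','.join(
        [str(i) for i in range(n_devices)]))
    print(annotation)  # prints: {devices=[1,4]0,1,2,3}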
torch_xla/csrc/init_python_bindings.cpp: 10 additions & 0 deletions
@@ -1324,6 +1324,16 @@ void InitXlaModuleBindings(py::module m) {
             xtensor->shape(),
             static_cast<XlaDeviceType>(xtensor->GetDevice().type())));
 
+        // For IR values, we directly attach the sharding spec to the xtensor.
+        if (xtensor->CurrentIrValue()) {
+          // TODO(alanwaketan): Do we want to check if there is any existing
+          // sharding spec? It seems okay to directly overwrite it.
+          xtensor->SetShardingSpec(*new_sharding_spec);
+          return;
+        }
+
+        // For data, we need to deal with the data transfers between
+        // host and device.
         at::Tensor cpu_tensor;
         if (xtensor->CurrentTensorData().has_value()) {
           TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
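
From the Python side, the two branches above roughly correspond to whether a tensor is still a pending IR value or has already been materialized on the device. A rough sketch of making that distinction visible, using the existing _get_xla_tensors_text debug binding (the printed IR dump is only illustrative):

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    xt1 = torch.randn(1, 128).to(xm.xla_device())  # backed by device data after the transfer
    xt2 = torch.randn(1, 128).to(xm.xla_device())

    actual = xt1 + xt2  # holds a pending IR value until the graph is executed

    # Dump the lazy IR for `actual`; the unevaluated add is the case handled by the
    # new CurrentIrValue() branch, while xt1/xt2 would take the data-transfer path.
    print(torch_xla._XLAC._get_xla_tensors_text([actual]))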
