diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 53974700fa6..af2db93a806 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -490,7 +490,6 @@ def test_xla_sharded_hlo_dump(self):
     # scalar 5 should be replicated
     self.assertIn('%p0.2 = f32[] parameter(0), sharding={replicated}', hlo)
 
-  @unittest.skip("TODO(alanwaketan): Implement IR sharding to re-enable this.")
   def test_2d_tensor_3d_mesh(self):
     ct1 = torch.randn(16, 16, device='cpu')
     ct2 = torch.randn(16, 16, device='cpu')
@@ -567,6 +566,24 @@ def test_hybrid_mesh(self, xla_device_mock, device_attributes_mock):
     self.assertEqual(hybrid_mesh.get_logical_mesh().tolist(),
                      [[0, 1], [2, 3], [4, 5], [6, 7]])
 
+  def test_mark_sharding_ir(self):
+    t1 = torch.randn(1, 128, device='cpu')
+    t2 = torch.randn(1, 128, device='cpu')
+    expected = t1 + t2
+
+    xt1 = t1.to(xm.xla_device())
+    xt2 = t2.to(xm.xla_device())
+    actual = xt1 + xt2
+    xs.mark_sharding(actual, self._get_mesh((1, self.n_devices)), (0, 1))
+
+    if self.n_devices > 1:
+      annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
+          [str(i) for i in range(self.n_devices)]))
+      self.assertEqual(annotation,
+                       torch_xla._XLAC._get_xla_sharding_spec(actual))
+
+    self.assertTrue(torch.allclose(expected, actual.cpu()))
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 9bb51cf1f05..339a7e79319 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1324,6 +1324,16 @@ void InitXlaModuleBindings(py::module m) {
             xtensor->shape(),
             static_cast<XlaDeviceType>(xtensor->GetDevice().type())));
 
+    // For IR values, we directly attach the sharding spec to the xtensor.
+    if (xtensor->CurrentIrValue()) {
+      // TODO(alanwaketan): Do we want to check if there is any existing
+      // sharding spec? It seems okay to directly overwrite it.
+      xtensor->SetShardingSpec(*new_sharding_spec);
+      return;
+    }
+
+    // For data, we need to deal with the data transfers between
+    // host and device.
     at::Tensor cpu_tensor;
     if (xtensor->CurrentTensorData().has_value()) {
       TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
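
Note (not part of the diff): below is a minimal sketch of how the IR-sharding path added above gets exercised from Python, mirroring the new test. It assumes a torch_xla build with SPMD support, at least one XLA device, and the xs.Mesh / xs.mark_sharding API from torch_xla.experimental.xla_sharding; the device-count helper and the final print are illustrative choices, not part of this change.

# Sketch only; assumes torch_xla with SPMD support and the APIs used in the test above.
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

num_devices = len(xm.get_xla_supported_devices())
# A 1 x num_devices logical mesh, analogous to self._get_mesh((1, self.n_devices)).
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))

xt1 = torch.randn(1, 128).to(xm.xla_device())
xt2 = torch.randn(1, 128).to(xm.xla_device())
# 'actual' is an unmaterialized IR value; with the binding change above,
# mark_sharding attaches the sharding spec directly to the tensor's IR
# instead of going through the host/device data-transfer path.
actual = xt1 + xt2
xs.mark_sharding(actual, mesh, (0, 1))
print(torch_xla._XLAC._get_xla_sharding_spec(actual))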