diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index d41f89b5f6d..db303302e09 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -42,6 +42,20 @@ def test_xla_sharded_tensor(self):
     # TODO(244003536) add more tests for XLAShardedTensror.
     self.assertTrue(isinstance(xst1, XLAShardedTensor))

+  def test_xla_sharded_tensor_repr(self):
+    xt = torch.randn(128, 128).to(xm.xla_device())
+    model = self.SimpleLinear().to(xm.xla_device())
+
+    mesh = self._get_mesh((1, self.n_devices))
+    partition_spec = (0, 1)
+    xst = xs.mark_sharding(xt, mesh, partition_spec)
+    self.assertTrue(isinstance(xst, XLAShardedTensor))
+
+    xt_output = model(xt)
+    self.assertTrue('XLAShardedTensor' not in str(xt_output))
+    xst_output = model(xst)
+    self.assertTrue('XLAShardedTensor' in str(xst_output))
+
   def test_sharded_tensor_debug_info(self):
     partition_spec = (0, 1)
     xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py
index 8e2e89f75f4..2945502dcc2 100644
--- a/torch_xla/distributed/spmd/xla_sharded_tensor.py
+++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py
@@ -142,6 +142,10 @@ def sharding_type(self) -> 'ShardingType':
     return ShardingType(sharding_type)

   def __repr__(self):
+    if not hasattr(self, "global_tensor"):
+      # Materialize a CPU copy of the sharded global_tensor for printing,
+      # keeping the actual data sharded on the XLA devices.
+      return str(self.cpu())
     return f"XLAShardedTensor({self.global_tensor})"

   @classmethod
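
For context (not part of the patch), below is a minimal standalone sketch of the behavior the new test exercises. The mesh construction and the plain nn.Linear stand in for the test suite's _get_mesh and SimpleLinear helpers, and the runtime/mesh calls are assumptions about the current torch_xla SPMD API rather than part of this change.

    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs
    from torch_xla.distributed.spmd import XLAShardedTensor

    # Build a 1 x N device mesh (assumed equivalent to the test's _get_mesh helper).
    num_devices = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))

    xt = torch.randn(128, 128).to(xm.xla_device())
    xst = xs.mark_sharding(xt, mesh, (0, 1))
    assert isinstance(xst, XLAShardedTensor)

    # xst was created by mark_sharding, so it carries global_tensor and
    # repr() takes the original f-string path.
    print(repr(xst))

    # The result of an op on an XLAShardedTensor may not have global_tensor
    # set; with the new fallback, repr() prints a CPU copy instead of raising
    # AttributeError.
    out = torch.nn.Linear(128, 64).to(xm.xla_device())(xst)
    print(repr(out))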