From 13cfb7ac8af4433e25ee50767727589b1f03faaa Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Wed, 20 Sep 2023 06:46:37 +0000 Subject: [PATCH 1/4] Add API to assemble CPU shards to a sharded tensor --- test/spmd/test_xla_sharding.py | 59 ++++++++++++++++++++ torch_xla/csrc/init_python_bindings.cpp | 55 ++++++++++++++++++ torch_xla/csrc/xla_sharding_util.cpp | 3 +- torch_xla/csrc/xla_sharding_util.h | 3 +- torch_xla/experimental/xla_sharded_tensor.py | 13 +++++ torch_xla/runtime.py | 6 ++ 6 files changed, 137 insertions(+), 2 deletions(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 07fa79bc658..5a644b9cd7b 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -866,6 +866,65 @@ def test_op_sharding_cache(self): xs.mark_sharding(v, mesh, (0, None)) self.assertEqual(met.counter_value("CreateOpSharding"), 2) + def test_from_cpu_shards(self): + # Create an OpSharding with all devices on a single axis + mesh = self._get_mesh((self.n_devices,)) + op_sharding = mesh.get_op_sharding((0,)) + + # Infer the global shape from the OpSharding + shards = [torch.LongTensor([i]) for i in range(self.n_devices)] + global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding) + self.assertTrue( + torch.allclose(global_tensor.cpu(), torch.arange(self.n_devices))) + + # Test incorrect number of shards + with self.assertRaises(RuntimeError): + XLAShardedTensor.from_cpu_shards(shards[:-1], op_sharding) + + # Test an invalid global shape - too many values. + with self.assertRaises(RuntimeError): + bad_size = torch.Size((2 * self.n_devices,)) + XLAShardedTensor.from_cpu_shards(shards, op_sharding, bad_size) + + if self.n_devices > 1: + # Test an invalid global shape - incorrect rank + with self.assertRaises(RuntimeError): + bad_size = torch.Size((self.n_devices // 2, 2)) + XLAShardedTensor.from_cpu_shards(shards, op_sharding, bad_size) + + # Test a valid global shape - restrict the number of meaningful values + # to 1, treating the rest as padding. + global_shape = torch.Size((1,)) + global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding, + global_shape) + self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(1))) + + # Use a 2d mesh + mesh_2d = self._get_mesh((self.n_devices // 2, 2)) + shards = [torch.LongTensor([[i]]) for i in range(self.n_devices)] + op_sharding = mesh_2d.get_op_sharding((0, 1)) + global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding) + expected = torch.arange(self.n_devices).reshape(2, self.n_devices // 2) + self.assertTrue(torch.allclose(global_tensor.cpu(), expected)) + + # Use partial replication. With eight devices, the mesh and shards are: + # mesh_2d shards + # | TPU:0 | TPU:1 | | 0 | 1 | + # | TPU:2 | TPU:3 | | 0 | 1 | + # | TPU:4 | TPU:5 | | 0 | 1 | + # | TPU:6 | TPU:7 | | 0 | 1 | + # + # Partial replication along the 0th axis represents a global tensor + # of torch.Tensor([[0, 1]]). 
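+    # Note: the k-th shard in the list is placed on the k-th local device, so
+    # every column of devices in the mesh holds a replica of the same value.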
+ shards = [torch.LongTensor([[i]]) for i in range(2)] * ( + self.n_devices // 2) + op_sharding = mesh_2d.get_op_sharding((None, 1)) + global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding) + expected = torch.arange(self.n_devices) + self.assertTrue( + torch.allclose(global_tensor.cpu(), + torch.arange(2).reshape(1, 2))) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 2db6006151e..9d3eac6c435 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1660,6 +1660,61 @@ void InitXlaModuleBindings(py::module m) { } return std::nullopt; }); + // Reassemble the CPU shards into a global tensor. A new sharded tensor is + // created from the local shards with the provided sharding annotation + // attached. The order of the shards should coincide with the order of + // devices returned by `xr.local_runtime_devices()`. + m.def( + "_global_tensor_from_cpu_shards", + [](const std::vector& shards, const xla::OpSharding& sharding, + std::optional>& global_shape) -> at::Tensor { + XLA_CHECK(UseVirtualDevice()) + << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; + auto local_devices = runtime::GetComputationClient()->GetLocalDevices(); + XLA_CHECK(local_devices.size() == shards.size()) + << "Must specify a tensor for each local device"; + // Set the global shape according to the input, or infer the global + // shape based on the tiling and shard shape. + auto tile_shape = sharding.tile_assignment_dimensions(); + if (global_shape.has_value()) { + XLA_CHECK(global_shape->size() == shards[0].sizes().size()) + << "Shard rank must match global tensor rank, got global rank " + << global_shape->size() << " and shard rank " + << shards[0].sizes().size(); + // The global shape must be achievable through padding + for (int dim = 0; dim < shards[0].sizes().size(); ++dim) { + auto max_size = tile_shape[dim] * shards[0].sizes()[0]; + XLA_CHECK(global_shape.value()[dim] <= max_size) + << "Invalid global shape " << global_shape.value() + << " for the provided shards and OpSharding: dimension " << dim + << " must be less than or equal to " << max_size; + } + } else { + global_shape = std::make_optional>(); + for (int dim = 0; dim < shards[0].sizes().size(); ++dim) { + auto global_dim = tile_shape[dim] * shards[0].sizes()[0]; + global_shape->push_back(global_dim); + } + } + + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(shards[0], nullptr); + for (int dim = 0; dim < tensor_shape.rank(); ++dim) { + tensor_shape.set_dimensions(dim, global_shape.value()[dim]); + } + + auto sharding_spec = + std::make_shared(sharding, tensor_shape); + auto data_handle = WrapXlaData(ShardingUtil::CreateShardedData( + shards, local_devices, sharding_spec)); + XLATensorPtr xla_tensor = XLATensor::Create(std::move(data_handle)); + xla_tensor->SetShardingSpec(*sharding_spec); + auto tensor = bridge::AtenFromXlaTensor(std::move(xla_tensor)); + return torch::autograd::make_variable(tensor, + shards[0].requires_grad()); + }, + py::arg("shards"), py::arg("sharding"), + py::arg("global_shape") = py::none()); // Returns the local shards of the tensor, with values taken from the // underlying ComputationClient::GetDataShards. 
As such, the shards will // contain any padding that was applied to ensure they all have the same diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 7be8ca1d3ae..4b9327e30fe 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -690,7 +690,8 @@ void ShardingUtil::PrepareOutputShardingPropagation( } runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( - std::vector& local_shards, std::vector& devices, + const std::vector& local_shards, + const std::vector& devices, const XLATensor::ShardingSpecPtr& sharding_spec) { XLA_CHECK(local_shards.size() == devices.size()) << "A device must be speficied for each shard"; diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 8b7cc7d02f8..2d3a14a9e0a 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -144,7 +144,8 @@ class ShardingUtil { // Transfers the individual shards to the devices and returns a DataPtr for // the PjRtShardedData wrapping the shards. static runtime::ComputationClient::DataPtr CreateShardedData( - std::vector& shards, std::vector& devices, + const std::vector& shards, + const std::vector& devices, const XLATensor::ShardingSpecPtr& sharding_spec); }; diff --git a/torch_xla/experimental/xla_sharded_tensor.py b/torch_xla/experimental/xla_sharded_tensor.py index ce423b3918f..a5a48dcf53e 100644 --- a/torch_xla/experimental/xla_sharded_tensor.py +++ b/torch_xla/experimental/xla_sharded_tensor.py @@ -102,6 +102,19 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs): r.global_tensor = elem.detach() if r.requires_grad else elem return r + @staticmethod + def from_cpu_shards(shards: List[torch.Tensor], + sharding: torch_xla._XLAC.OpSharding, + global_shape: torch.Size = None): + """ + Create an XLAShardedTensor from the given list of CPU shards. The order of + the shards determines which device it will be placed on, coinciding with + the device order returned by `torch_xla.runtime.local_runtime_devices`. + """ + return XLAShardedTensor( + torch_xla._XLAC._global_tensor_from_cpu_shards(shards, sharding, + global_shape)) + # Shards on the devices are materialized/available after the lazy # execution of the partitioned HLO graph. Each XLAShard points # to torch.Tensor. 
The shards represent a snapshot on CPU, detached diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index 3087f3c80f6..37f9a68f6c6 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -131,6 +131,12 @@ def global_device_count() -> int: return len(torch_xla._XLAC._xla_get_all_devices()) +@requires_pjrt +def local_runtime_devices() -> List[str]: + """Returns addressable devices as a list of string.""" + return torch_xla._XLAC._xla_get_runtime_devices() + + @requires_pjrt def world_size() -> int: """Returns the total number of configured logical devices.""" From 4413bc7af101ce6a2ebe8f6004764c0e4c3d91a3 Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Wed, 20 Sep 2023 22:52:00 +0000 Subject: [PATCH 2/4] Handle replicated sharding --- test/spmd/test_xla_sharding.py | 6 +++ torch_xla/csrc/init_python_bindings.cpp | 55 ++++++++++++++----------- 2 files changed, 36 insertions(+), 25 deletions(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 5a644b9cd7b..5ea4bfddcb2 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -881,6 +881,12 @@ def test_from_cpu_shards(self): with self.assertRaises(RuntimeError): XLAShardedTensor.from_cpu_shards(shards[:-1], op_sharding) + # Test replicated sharding. The result should equal a shard. + op_sharding = mesh.get_op_sharding((None,)) + shards = [torch.LongTensor([0])] * self.n_devices + global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding) + self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0])) + # Test an invalid global shape - too many values. with self.assertRaises(RuntimeError): bad_size = torch.Size((2 * self.n_devices,)) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 9d3eac6c435..62ee529df6b 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1663,37 +1663,33 @@ void InitXlaModuleBindings(py::module m) { // Reassemble the CPU shards into a global tensor. A new sharded tensor is // created from the local shards with the provided sharding annotation // attached. The order of the shards should coincide with the order of - // devices returned by `xr.local_runtime_devices()`. + // devices returned by `torch_xla.runtime.local_runtime_devices()`. m.def( "_global_tensor_from_cpu_shards", [](const std::vector& shards, const xla::OpSharding& sharding, - std::optional>& global_shape) -> at::Tensor { + std::optional>& global_shape) -> at::Tensor { XLA_CHECK(UseVirtualDevice()) << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; auto local_devices = runtime::GetComputationClient()->GetLocalDevices(); XLA_CHECK(local_devices.size() == shards.size()) - << "Must specify a tensor for each local device"; - // Set the global shape according to the input, or infer the global - // shape based on the tiling and shard shape. 
- auto tile_shape = sharding.tile_assignment_dimensions(); - if (global_shape.has_value()) { - XLA_CHECK(global_shape->size() == shards[0].sizes().size()) - << "Shard rank must match global tensor rank, got global rank " - << global_shape->size() << " and shard rank " - << shards[0].sizes().size(); - // The global shape must be achievable through padding - for (int dim = 0; dim < shards[0].sizes().size(); ++dim) { - auto max_size = tile_shape[dim] * shards[0].sizes()[0]; - XLA_CHECK(global_shape.value()[dim] <= max_size) - << "Invalid global shape " << global_shape.value() - << " for the provided shards and OpSharding: dimension " << dim - << " must be less than or equal to " << max_size; - } - } else { - global_shape = std::make_optional>(); - for (int dim = 0; dim < shards[0].sizes().size(); ++dim) { - auto global_dim = tile_shape[dim] * shards[0].sizes()[0]; - global_shape->push_back(global_dim); + << "Must specify a shard for each local device"; + + if (!global_shape.has_value()) { + // Set a default value for the global shape based on the sharding + // type. + if (sharding.type() == xla::OpSharding::OTHER) { + // Infer the global shape to be the shard shape scaled by the tiling + // dimensionality. + auto tile_shape = sharding.tile_assignment_dimensions(); + global_shape = std::vector(); + for (int dim = 0; dim < shards[0].sizes().size(); ++dim) { + auto global_dim = tile_shape[dim] * shards[0].sizes()[dim]; + global_shape->push_back(global_dim); + } + } else { + // If the sharding type is not tiled, the tensor shards are + // replicated. + global_shape = shards[0].sizes().vec(); } } @@ -1702,9 +1698,18 @@ void InitXlaModuleBindings(py::module m) { for (int dim = 0; dim < tensor_shape.rank(); ++dim) { tensor_shape.set_dimensions(dim, global_shape.value()[dim]); } - auto sharding_spec = std::make_shared(sharding, tensor_shape); + + // Verify that the shard shape is correct for the global shape and + // sharding spec. + auto expected_shard_shape = ShardingUtil::GetShardShape(sharding_spec); + for (auto shard : shards) { + XLA_CHECK(shard.sizes() == expected_shard_shape) + << "Input shard shape must include padding: " << shard.sizes() + << " vs " << expected_shard_shape; + } + auto data_handle = WrapXlaData(ShardingUtil::CreateShardedData( shards, local_devices, sharding_spec)); XLATensorPtr xla_tensor = XLATensor::Create(std::move(data_handle)); From bf329c7b0672236ade09b2ec10c05a19e6b0bbd5 Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Thu, 21 Sep 2023 01:31:09 +0000 Subject: [PATCH 3/4] Move validations into get_op_sharding --- torch_xla/experimental/xla_sharding.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index 95f4a88128b..21d0e2e570a 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -87,6 +87,14 @@ def get_op_sharding(self, Return the OpSharding for the given partition spec. This is an expensive operation as the mesh grows, so the value is cached for reuse. """ + partition_spec = _translate_named_partition_spec(self, partition_spec) + flat_specs = np.hstack([d for d in partition_spec]) + specs = [d for d in flat_specs if d is not None] + assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \ + f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape." 
+ assert len(specs) == len(np.unique(specs)), \ + f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." + tile_assignment = _get_tile_assignment(self, partition_spec) if len(tile_assignment.shape) > len(partition_spec): # Use partial replication for sharding a tensor over a higher-rank mesh @@ -482,19 +490,12 @@ def mark_sharding( assert num_devices > 0, "This requires XLA supported device(s)." assert mesh.size() == num_devices, \ f"{mesh.mesh_shape} is not mappable over {num_devices} devices." - partition_spec = _translate_named_partition_spec(mesh, partition_spec) # We only allow fully specified `partition_spec` to be applicable, as opposed # to filling in the unspecified replicated dims. Fully specified `partiion_spec` # should be of the same rank as `t`. This is to support partial replication # where the group assignment may vary with different input ranks. assert len(t.shape) == len(partition_spec), \ f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." - flat_specs = np.hstack([d for d in partition_spec]) - specs = [d for d in flat_specs if d is not None] - assert all(d >= 0 and d < len(mesh.mesh_shape) for d in specs), \ - f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape." - assert len(specs) == len(np.unique(specs)), \ - f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." op_sharding = mesh.get_op_sharding(partition_spec) From c03a5672354c0fc088812dbd4031e6cbafeb12d5 Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Tue, 26 Sep 2023 02:14:53 +0000 Subject: [PATCH 4/4] Improve tests and error handling --- test/spmd/test_xla_sharding.py | 156 +++++++++++++------ torch_xla/csrc/init_python_bindings.cpp | 27 ++-- torch_xla/experimental/xla_sharded_tensor.py | 13 -- torch_xla/runtime.py | 6 - 4 files changed, 129 insertions(+), 73 deletions(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 5ea4bfddcb2..512f048cf65 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -866,70 +866,138 @@ def test_op_sharding_cache(self): xs.mark_sharding(v, mesh, (0, None)) self.assertEqual(met.counter_value("CreateOpSharding"), 2) - def test_from_cpu_shards(self): + def test_from_cpu_shards_replicated(self): + from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards + # Create an OpSharding with all devices on a single axis mesh = self._get_mesh((self.n_devices,)) - op_sharding = mesh.get_op_sharding((0,)) + partition_spec = (None,) + op_sharding = mesh.get_op_sharding(partition_spec) + shards = [torch.arange(4)] * self.n_devices + + # No shape should result in the shape of a single shard. 
+ global_tensor = from_cpu_shards(shards, op_sharding) + self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0])) + + # Specify a valid shape for the global tensor + global_tensor = from_cpu_shards(shards, op_sharding, shards[0].shape) + self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0])) + + # All invalid shapes should raise + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((5,))) + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((3,))) + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((2, 2))) + + def test_from_cpu_shards_tiled(self): + from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards - # Infer the global shape from the OpSharding + # Create an OpSharding with all devices on a single axis + mesh = self._get_mesh((self.n_devices,)) + partition_spec = (0,) + op_sharding = mesh.get_op_sharding(partition_spec) shards = [torch.LongTensor([i]) for i in range(self.n_devices)] - global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding) + + global_tensor = from_cpu_shards(shards, op_sharding) self.assertTrue( torch.allclose(global_tensor.cpu(), torch.arange(self.n_devices))) # Test incorrect number of shards with self.assertRaises(RuntimeError): - XLAShardedTensor.from_cpu_shards(shards[:-1], op_sharding) - - # Test replicated sharding. The result should equal a shard. - op_sharding = mesh.get_op_sharding((None,)) - shards = [torch.LongTensor([0])] * self.n_devices - global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding) - self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0])) + from_cpu_shards(shards[:-1], op_sharding) # Test an invalid global shape - too many values. with self.assertRaises(RuntimeError): - bad_size = torch.Size((2 * self.n_devices,)) - XLAShardedTensor.from_cpu_shards(shards, op_sharding, bad_size) + from_cpu_shards(shards, op_sharding, torch.Size((self.n_devices * 2,))) + + # Test an invalid global shape - incorrect rank + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((1, self.n_devices))) + + # Test a valid global shape - restrict the number of meaningful values + # to 1, treating the rest as padding. + global_tensor = from_cpu_shards(shards, op_sharding, torch.Size((1,))) + self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(1))) + + def test_from_cpu_shards_2d(self): + from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards + + # Create an appropriate 2D mesh for the number of devices + if self.n_devices >= 4: + mesh_shape = (self.n_devices // 2, 2) + else: + mesh_shape = (1, self.n_devices) + mesh_2d = self._get_mesh(mesh_shape) + + # Replicated sharding + shards = [torch.LongTensor([self.n_devices])] * self.n_devices + partition_spec = (None, None) + op_sharding = mesh_2d.get_op_sharding(partition_spec) + global_tensor = from_cpu_shards(shards, op_sharding) + self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0])) if self.n_devices > 1: - # Test an invalid global shape - incorrect rank - with self.assertRaises(RuntimeError): - bad_size = torch.Size((self.n_devices // 2, 2)) - XLAShardedTensor.from_cpu_shards(shards, op_sharding, bad_size) - - # Test a valid global shape - restrict the number of meaningful values - # to 1, treating the rest as padding. 
- global_shape = torch.Size((1,)) - global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding, - global_shape) - self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(1))) - - # Use a 2d mesh - mesh_2d = self._get_mesh((self.n_devices // 2, 2)) + # Tiled sharding shards = [torch.LongTensor([[i]]) for i in range(self.n_devices)] - op_sharding = mesh_2d.get_op_sharding((0, 1)) - global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding) + partition_spec = (0, 1) + op_sharding = mesh_2d.get_op_sharding(partition_spec) + global_tensor = from_cpu_shards(shards, op_sharding) expected = torch.arange(self.n_devices).reshape(2, self.n_devices // 2) self.assertTrue(torch.allclose(global_tensor.cpu(), expected)) - # Use partial replication. With eight devices, the mesh and shards are: - # mesh_2d shards - # | TPU:0 | TPU:1 | | 0 | 1 | - # | TPU:2 | TPU:3 | | 0 | 1 | - # | TPU:4 | TPU:5 | | 0 | 1 | - # | TPU:6 | TPU:7 | | 0 | 1 | - # - # Partial replication along the 0th axis represents a global tensor - # of torch.Tensor([[0, 1]]). + # Partially replicated sharding shards = [torch.LongTensor([[i]]) for i in range(2)] * ( self.n_devices // 2) - op_sharding = mesh_2d.get_op_sharding((None, 1)) - global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding) - expected = torch.arange(self.n_devices) - self.assertTrue( - torch.allclose(global_tensor.cpu(), - torch.arange(2).reshape(1, 2))) + partition_spec = (None, 1) + op_sharding = mesh_2d.get_op_sharding(partition_spec) + global_tensor = from_cpu_shards(shards, op_sharding) + # Partial replication along the 0th axis represents a global tensor + # of torch.Tensor([[0, 1]]). + expected = torch.arange(2).reshape(1, 2) + self.assertTrue(torch.allclose(global_tensor.cpu(), expected)) + + def test_from_cpu_shards_global_shape(self): + from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards + + mesh = self._get_mesh((self.n_devices,)) + numel = self.n_devices**2 + # The global tensor is torch.arange(numel). + shards = [ + torch.arange(self.n_devices) + (i * self.n_devices) + for i in range(self.n_devices) + ] + partition_spec = (0,) + op_sharding = mesh.get_op_sharding(partition_spec) + + # No global shape specified will include all data from the shards + global_tensor = from_cpu_shards(shards, op_sharding) + self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(numel))) + + # Too large of a global shape will error out + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((numel + 1,))) + + if self.n_devices > 1: + # When the global tensor has fewer elements than the sum of its shards, + # there are two cases: + + # Case 1: If the global shape is within n_devices of numel, the excess + # data is treated as padding and ignored. + for delta in range(self.n_devices): + size = torch.Size((numel - delta,)) + global_tensor = from_cpu_shards(shards, op_sharding, size) + expected = torch.arange(size[0]) + self.assertTrue(torch.allclose(global_tensor.cpu(), expected)) + + # Case 2: Otherwise, it is not possible to have that much padding in a + # sharded tensor, and the shards are incompatible with the shape. 
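+      # For example, with four devices each shard holds four elements, so a
+      # global shape of (12,) would imply a shard shape of (3,), which does
+      # not match the provided shards of shape (4,).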
+ with self.assertRaises(RuntimeError): + shape = torch.Size((numel - self.n_devices,)) + from_cpu_shards(shards, op_sharding, shape) + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((1,))) if __name__ == '__main__': diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 62ee529df6b..c796e662fcf 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -33,6 +33,7 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ir.h" #include "torch_xla/csrc/ir_dump_util.h" +#include "torch_xla/csrc/layout_manager.h" #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/runtime/computation_client.h" @@ -1667,12 +1668,17 @@ void InitXlaModuleBindings(py::module m) { m.def( "_global_tensor_from_cpu_shards", [](const std::vector& shards, const xla::OpSharding& sharding, - std::optional>& global_shape) -> at::Tensor { + std::optional>& global_shape) -> at::Tensor { XLA_CHECK(UseVirtualDevice()) << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; auto local_devices = runtime::GetComputationClient()->GetLocalDevices(); XLA_CHECK(local_devices.size() == shards.size()) << "Must specify a shard for each local device"; + XLA_CHECK(!global_shape.has_value() || + global_shape.value().size() == shards[0].sizes().size()) + << "Global shape rank must agree with shard rank: expected rank " + << shards[0].sizes().size() << ", got " + << global_shape.value().size(); if (!global_shape.has_value()) { // Set a default value for the global shape based on the sharding @@ -1681,23 +1687,24 @@ void InitXlaModuleBindings(py::module m) { // Infer the global shape to be the shard shape scaled by the tiling // dimensionality. auto tile_shape = sharding.tile_assignment_dimensions(); - global_shape = std::vector(); + global_shape = std::vector(); for (int dim = 0; dim < shards[0].sizes().size(); ++dim) { auto global_dim = tile_shape[dim] * shards[0].sizes()[dim]; global_shape->push_back(global_dim); } - } else { - // If the sharding type is not tiled, the tensor shards are - // replicated. + } else if (sharding.type() == xla::OpSharding::REPLICATED) { global_shape = shards[0].sizes().vec(); + } else { + XLA_ERROR() << "Unsupported OpSharding type: " << sharding.type(); } } - xla::Shape tensor_shape = - CreateComputationShapeFromTensor(shards[0], nullptr); - for (int dim = 0; dim < tensor_shape.rank(); ++dim) { - tensor_shape.set_dimensions(dim, global_shape.value()[dim]); - } + auto device = GetVirtualDevice(); + auto primitive_type = + MakeXlaPrimitiveType(shards[0].type().scalarType(), &device); + xla::Shape tensor_shape = MakeArrayShapeFromDimensions( + global_shape.value(), /*dynamic_dimensions=*/{}, primitive_type, + static_cast(device.type())); auto sharding_spec = std::make_shared(sharding, tensor_shape); diff --git a/torch_xla/experimental/xla_sharded_tensor.py b/torch_xla/experimental/xla_sharded_tensor.py index a5a48dcf53e..ce423b3918f 100644 --- a/torch_xla/experimental/xla_sharded_tensor.py +++ b/torch_xla/experimental/xla_sharded_tensor.py @@ -102,19 +102,6 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs): r.global_tensor = elem.detach() if r.requires_grad else elem return r - @staticmethod - def from_cpu_shards(shards: List[torch.Tensor], - sharding: torch_xla._XLAC.OpSharding, - global_shape: torch.Size = None): - """ - Create an XLAShardedTensor from the given list of CPU shards. 
The order of - the shards determines which device it will be placed on, coinciding with - the device order returned by `torch_xla.runtime.local_runtime_devices`. - """ - return XLAShardedTensor( - torch_xla._XLAC._global_tensor_from_cpu_shards(shards, sharding, - global_shape)) - # Shards on the devices are materialized/available after the lazy # execution of the partitioned HLO graph. Each XLAShard points # to torch.Tensor. The shards represent a snapshot on CPU, detached diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index 37f9a68f6c6..3087f3c80f6 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -131,12 +131,6 @@ def global_device_count() -> int: return len(torch_xla._XLAC._xla_get_all_devices()) -@requires_pjrt -def local_runtime_devices() -> List[str]: - """Returns addressable devices as a list of string.""" - return torch_xla._XLAC._xla_get_runtime_devices() - - @requires_pjrt def world_size() -> int: """Returns the total number of configured logical devices."""
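Usage sketch (not part of this series): the snippet below shows one way the new binding could be exercised after PATCH 4/4, where `_global_tensor_from_cpu_shards` is called directly rather than through the removed `XLAShardedTensor.from_cpu_shards` wrapper. It assumes an SPMD-capable runtime such as a single TPU host, so that the global and addressable device counts coincide; the mesh shape and shard values are arbitrary choices for illustration, and `Mesh`, `use_spmd`, and `global_runtime_device_count` come from the existing torch_xla APIs.

    import numpy as np
    import torch
    import torch_xla
    import torch_xla.runtime as xr
    import torch_xla.experimental.xla_sharding as xs

    # SPMD must be enabled before assembling shards; the binding checks this.
    xr.use_spmd()

    num_devices = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.arange(num_devices), (num_devices,))
    op_sharding = mesh.get_op_sharding((0,))

    # One CPU shard per local device, in the runtime's local device order.
    shards = [torch.LongTensor([i]) for i in range(num_devices)]
    global_tensor = torch_xla._XLAC._global_tensor_from_cpu_shards(
        shards, op_sharding)

    # Transferring the assembled tensor back to CPU yields
    # torch.arange(num_devices), matching the 1D tiled OpSharding above.
    print(global_tensor.cpu())

Partial replication and higher-rank tilings follow the same pattern: build the OpSharding from the mesh and partition spec, and order the shard list to match the local devices, as in the tests above.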