diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index ce2cae18dd6c..1b128164a22b 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -900,6 +900,139 @@ def test_op_sharding_cache(self): xs.mark_sharding(v, mesh, (0, None)) self.assertEqual(met.counter_value("CreateOpSharding"), 2) + def test_from_cpu_shards_replicated(self): + from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards + + # Create an OpSharding with all devices on a single axis + mesh = self._get_mesh((self.n_devices,)) + partition_spec = (None,) + op_sharding = mesh.get_op_sharding(partition_spec) + shards = [torch.arange(4)] * self.n_devices + + # No shape should result in the shape of a single shard. + global_tensor = from_cpu_shards(shards, op_sharding) + self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0])) + + # Specify a valid shape for the global tensor + global_tensor = from_cpu_shards(shards, op_sharding, shards[0].shape) + self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0])) + + # All invalid shapes should raise + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((5,))) + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((3,))) + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((2, 2))) + + def test_from_cpu_shards_tiled(self): + from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards + + # Create an OpSharding with all devices on a single axis + mesh = self._get_mesh((self.n_devices,)) + partition_spec = (0,) + op_sharding = mesh.get_op_sharding(partition_spec) + shards = [torch.LongTensor([i]) for i in range(self.n_devices)] + + global_tensor = from_cpu_shards(shards, op_sharding) + self.assertTrue( + torch.allclose(global_tensor.cpu(), torch.arange(self.n_devices))) + + # Test incorrect number of shards + with self.assertRaises(RuntimeError): + from_cpu_shards(shards[:-1], op_sharding) + + # Test an invalid global shape - too many values. + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((self.n_devices * 2,))) + + # Test an invalid global shape - incorrect rank + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((1, self.n_devices))) + + # Test a valid global shape - restrict the number of meaningful values + # to 1, treating the rest as padding. + global_tensor = from_cpu_shards(shards, op_sharding, torch.Size((1,))) + self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(1))) + + def test_from_cpu_shards_2d(self): + from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards + + # Create an appropriate 2D mesh for the number of devices + if self.n_devices >= 4: + mesh_shape = (self.n_devices // 2, 2) + else: + mesh_shape = (1, self.n_devices) + mesh_2d = self._get_mesh(mesh_shape) + + # Replicated sharding + shards = [torch.LongTensor([self.n_devices])] * self.n_devices + partition_spec = (None, None) + op_sharding = mesh_2d.get_op_sharding(partition_spec) + global_tensor = from_cpu_shards(shards, op_sharding) + self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0])) + + if self.n_devices > 1: + # Tiled sharding + shards = [torch.LongTensor([[i]]) for i in range(self.n_devices)] + partition_spec = (0, 1) + op_sharding = mesh_2d.get_op_sharding(partition_spec) + global_tensor = from_cpu_shards(shards, op_sharding) + expected = torch.arange(self.n_devices).reshape(*mesh_shape) + self.assertTrue(torch.allclose(global_tensor.cpu(), expected)) + + # Partially replicated sharding + shards = [torch.LongTensor([[i]]) for i in range(2)] * ( + self.n_devices // 2) + partition_spec = (None, 1) + op_sharding = mesh_2d.get_op_sharding(partition_spec) + global_tensor = from_cpu_shards(shards, op_sharding) + # Partial replication along the 0th axis represents a global tensor + # of torch.Tensor([[0, 1]]). + expected = torch.arange(2).reshape(1, 2) + self.assertTrue(torch.allclose(global_tensor.cpu(), expected)) + + def test_from_cpu_shards_global_shape(self): + from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards + + mesh = self._get_mesh((self.n_devices,)) + numel = self.n_devices**2 + # The global tensor is torch.arange(numel). + shards = [ + torch.arange(self.n_devices) + (i * self.n_devices) + for i in range(self.n_devices) + ] + partition_spec = (0,) + op_sharding = mesh.get_op_sharding(partition_spec) + + # No global shape specified will include all data from the shards + global_tensor = from_cpu_shards(shards, op_sharding) + self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(numel))) + + # Too large of a global shape will error out + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((numel + 1,))) + + if self.n_devices > 1: + # When the global tensor has fewer elements than the sum of its shards, + # there are two cases: + + # Case 1: If the global shape is within n_devices of numel, the excess + # data is treated as padding and ignored. + for delta in range(self.n_devices): + size = torch.Size((numel - delta,)) + global_tensor = from_cpu_shards(shards, op_sharding, size) + expected = torch.arange(size[0]) + self.assertTrue(torch.allclose(global_tensor.cpu(), expected)) + + # Case 2: Otherwise, it is not possible to have that much padding in a + # sharded tensor, and the shards are incompatible with the shape. + with self.assertRaises(RuntimeError): + shape = torch.Size((numel - self.n_devices,)) + from_cpu_shards(shards, op_sharding, shape) + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((1,))) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index b3957d7a68f7..421066ba72cd 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -34,6 +34,7 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ir.h" #include "torch_xla/csrc/ir_dump_util.h" +#include "torch_xla/csrc/layout_manager.h" #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/runtime/computation_client.h" @@ -1663,6 +1664,72 @@ void InitXlaModuleBindings(py::module m) { } return std::nullopt; }); + // Reassemble the CPU shards into a global tensor. A new sharded tensor is + // created from the local shards with the provided sharding annotation + // attached. The order of the shards should coincide with the order of + // devices returned by `torch_xla.runtime.local_runtime_devices()`. + m.def( + "_global_tensor_from_cpu_shards", + [](const std::vector& shards, const xla::OpSharding& sharding, + std::optional>& global_shape) -> at::Tensor { + XLA_CHECK(UseVirtualDevice()) + << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; + auto local_devices = runtime::GetComputationClient()->GetLocalDevices(); + XLA_CHECK(local_devices.size() == shards.size()) + << "Must specify a shard for each local device"; + XLA_CHECK(!global_shape.has_value() || + global_shape.value().size() == shards[0].sizes().size()) + << "Global shape rank must agree with shard rank: expected rank " + << shards[0].sizes().size() << ", got " + << global_shape.value().size(); + + if (!global_shape.has_value()) { + // Set a default value for the global shape based on the sharding + // type. + if (sharding.type() == xla::OpSharding::OTHER) { + // Infer the global shape to be the shard shape scaled by the tiling + // dimensionality. + auto tile_shape = sharding.tile_assignment_dimensions(); + global_shape = std::vector(); + for (int dim = 0; dim < shards[0].sizes().size(); ++dim) { + auto global_dim = tile_shape[dim] * shards[0].sizes()[dim]; + global_shape->push_back(global_dim); + } + } else if (sharding.type() == xla::OpSharding::REPLICATED) { + global_shape = shards[0].sizes().vec(); + } else { + XLA_ERROR() << "Unsupported OpSharding type: " << sharding.type(); + } + } + + auto device = GetVirtualDevice(); + auto primitive_type = + MakeXlaPrimitiveType(shards[0].type().scalarType(), &device); + xla::Shape tensor_shape = MakeArrayShapeFromDimensions( + global_shape.value(), /*dynamic_dimensions=*/{}, primitive_type, + static_cast(device.type())); + auto sharding_spec = + std::make_shared(sharding, tensor_shape); + + // Verify that the shard shape is correct for the global shape and + // sharding spec. + auto expected_shard_shape = ShardingUtil::GetShardShape(sharding_spec); + for (auto shard : shards) { + XLA_CHECK(shard.sizes() == expected_shard_shape) + << "Input shard shape must include padding: " << shard.sizes() + << " vs " << expected_shard_shape; + } + + auto data_handle = ShardingUtil::CreateShardedData( + shards, local_devices, sharding_spec); + XLATensorPtr xla_tensor = XLATensor::Create(std::move(data_handle)); + xla_tensor->SetShardingSpec(*sharding_spec); + auto tensor = bridge::AtenFromXlaTensor(std::move(xla_tensor)); + return torch::autograd::make_variable(tensor, + shards[0].requires_grad()); + }, + py::arg("shards"), py::arg("sharding"), + py::arg("global_shape") = py::none()); // Returns the local shards of the tensor, with values taken from the // underlying ComputationClient::GetDataShards. As such, the shards will // contain any padding that was applied to ensure they all have the same diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index f7da463fb647..cde74256eeee 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -706,7 +706,8 @@ void ShardingUtil::PrepareOutputShardingPropagation( } runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( - std::vector& local_shards, std::vector& devices, + const std::vector& local_shards, + const std::vector& devices, const XLATensor::ShardingSpecPtr& sharding_spec) { XLA_CHECK(local_shards.size() == devices.size()) << "A device must be speficied for each shard"; diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 4a595f4e99b0..32060c7fc098 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -147,7 +147,8 @@ class ShardingUtil { // Transfers the individual shards to the devices and returns a DataPtr for // the PjRtShardedData wrapping the shards. static runtime::ComputationClient::DataPtr CreateShardedData( - std::vector& shards, std::vector& devices, + const std::vector& shards, + const std::vector& devices, const XLATensor::ShardingSpecPtr& sharding_spec); }; diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index 95f4a88128bb..21d0e2e570ac 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -87,6 +87,14 @@ def get_op_sharding(self, Return the OpSharding for the given partition spec. This is an expensive operation as the mesh grows, so the value is cached for reuse. """ + partition_spec = _translate_named_partition_spec(self, partition_spec) + flat_specs = np.hstack([d for d in partition_spec]) + specs = [d for d in flat_specs if d is not None] + assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \ + f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape." + assert len(specs) == len(np.unique(specs)), \ + f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." + tile_assignment = _get_tile_assignment(self, partition_spec) if len(tile_assignment.shape) > len(partition_spec): # Use partial replication for sharding a tensor over a higher-rank mesh @@ -482,19 +490,12 @@ def mark_sharding( assert num_devices > 0, "This requires XLA supported device(s)." assert mesh.size() == num_devices, \ f"{mesh.mesh_shape} is not mappable over {num_devices} devices." - partition_spec = _translate_named_partition_spec(mesh, partition_spec) # We only allow fully specified `partition_spec` to be applicable, as opposed # to filling in the unspecified replicated dims. Fully specified `partiion_spec` # should be of the same rank as `t`. This is to support partial replication # where the group assignment may vary with different input ranks. assert len(t.shape) == len(partition_spec), \ f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." - flat_specs = np.hstack([d for d in partition_spec]) - specs = [d for d in flat_specs if d is not None] - assert all(d >= 0 and d < len(mesh.mesh_shape) for d in specs), \ - f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape." - assert len(specs) == len(np.unique(specs)), \ - f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." op_sharding = mesh.get_op_sharding(partition_spec)