Handle replicated sharding
jonb377 committed Sep 20, 2023
1 parent 13cfb7a commit 4413bc7
Showing 2 changed files with 36 additions and 25 deletions.
6 changes: 6 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -881,6 +881,12 @@ def test_from_cpu_shards(self):
     with self.assertRaises(RuntimeError):
       XLAShardedTensor.from_cpu_shards(shards[:-1], op_sharding)
 
+    # Test replicated sharding. The result should equal a shard.
+    op_sharding = mesh.get_op_sharding((None,))
+    shards = [torch.LongTensor([0])] * self.n_devices
+    global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding)
+    self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))
+
     # Test an invalid global shape - too many values.
     with self.assertRaises(RuntimeError):
       bad_size = torch.Size((2 * self.n_devices,))
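For context, a rough end-to-end usage sketch of the path exercised by the new test, assuming an SPMD-capable XLA runtime and the `torch_xla.experimental.*` import locations of this era; the import paths and the device-count helper are assumptions, not taken from this diff:

    import numpy as np
    import torch
    import torch_xla.runtime as xr
    import torch_xla.experimental.xla_sharding as xs  # assumed import path
    from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor  # assumed import path

    xr.use_spmd()
    num_devices = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.arange(num_devices), (num_devices,))

    # A replicated annotation: every local device holds an identical copy.
    # Shards must be ordered like torch_xla.runtime.local_runtime_devices().
    op_sharding = mesh.get_op_sharding((None,))
    shards = [torch.LongTensor([0])] * num_devices

    # Reassemble the CPU shards; for replicated sharding the global tensor
    # equals any single shard.
    global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding)
    assert torch.allclose(global_tensor.cpu(), shards[0])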
55 changes: 30 additions & 25 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1663,37 +1663,33 @@ void InitXlaModuleBindings(py::module m) {
   // Reassemble the CPU shards into a global tensor. A new sharded tensor is
   // created from the local shards with the provided sharding annotation
   // attached. The order of the shards should coincide with the order of
-  // devices returned by `xr.local_runtime_devices()`.
+  // devices returned by `torch_xla.runtime.local_runtime_devices()`.
   m.def(
       "_global_tensor_from_cpu_shards",
       [](const std::vector<at::Tensor>& shards, const xla::OpSharding& sharding,
-         std::optional<std::vector<int>>& global_shape) -> at::Tensor {
+         std::optional<std::vector<long int>>& global_shape) -> at::Tensor {
         XLA_CHECK(UseVirtualDevice())
             << "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
         auto local_devices = runtime::GetComputationClient()->GetLocalDevices();
         XLA_CHECK(local_devices.size() == shards.size())
-            << "Must specify a tensor for each local device";
-        // Set the global shape according to the input, or infer the global
-        // shape based on the tiling and shard shape.
-        auto tile_shape = sharding.tile_assignment_dimensions();
-        if (global_shape.has_value()) {
-          XLA_CHECK(global_shape->size() == shards[0].sizes().size())
-              << "Shard rank must match global tensor rank, got global rank "
-              << global_shape->size() << " and shard rank "
-              << shards[0].sizes().size();
-          // The global shape must be achievable through padding
-          for (int dim = 0; dim < shards[0].sizes().size(); ++dim) {
-            auto max_size = tile_shape[dim] * shards[0].sizes()[0];
-            XLA_CHECK(global_shape.value()[dim] <= max_size)
-                << "Invalid global shape " << global_shape.value()
-                << " for the provided shards and OpSharding: dimension " << dim
-                << " must be less than or equal to " << max_size;
-          }
-        } else {
-          global_shape = std::make_optional<std::vector<int>>();
-          for (int dim = 0; dim < shards[0].sizes().size(); ++dim) {
-            auto global_dim = tile_shape[dim] * shards[0].sizes()[0];
-            global_shape->push_back(global_dim);
-          }
-        }
+            << "Must specify a shard for each local device";
+
+        if (!global_shape.has_value()) {
+          // Set a default value for the global shape based on the sharding
+          // type.
+          if (sharding.type() == xla::OpSharding::OTHER) {
+            // Infer the global shape to be the shard shape scaled by the tiling
+            // dimensionality.
+            auto tile_shape = sharding.tile_assignment_dimensions();
+            global_shape = std::vector<long int>();
+            for (int dim = 0; dim < shards[0].sizes().size(); ++dim) {
+              auto global_dim = tile_shape[dim] * shards[0].sizes()[dim];
+              global_shape->push_back(global_dim);
+            }
+          } else {
+            // If the sharding type is not tiled, the tensor shards are
+            // replicated.
+            global_shape = shards[0].sizes().vec();
+          }
+        }

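The default global-shape inference added above reduces to simple arithmetic. A minimal Python sketch, for illustration only; the helper name is hypothetical and not part of the diff:

    # Sketch of the default global-shape inference in the binding above.
    def infer_global_shape(shard_shape, tile_dims=None):
      # tile_dims=None stands in for a non-tiled (replicated) OpSharding.
      if tile_dims is None:
        return list(shard_shape)
      # Tiled (OpSharding::OTHER): scale each shard dimension by the number
      # of tiles along that dimension.
      return [t * s for t, s in zip(tile_dims, shard_shape)]

    # 4 devices tiled along dim 0: (2, 3) shards -> (8, 3) global shape.
    assert infer_global_shape((2, 3), tile_dims=(4, 1)) == [8, 3]
    # Replicated: the global shape is just the shard shape.
    assert infer_global_shape((2, 3)) == [2, 3]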
@@ -1702,9 +1698,18 @@ void InitXlaModuleBindings(py::module m) {
         for (int dim = 0; dim < tensor_shape.rank(); ++dim) {
           tensor_shape.set_dimensions(dim, global_shape.value()[dim]);
         }
 
         auto sharding_spec =
             std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+
+        // Verify that the shard shape is correct for the global shape and
+        // sharding spec.
+        auto expected_shard_shape = ShardingUtil::GetShardShape(sharding_spec);
+        for (auto shard : shards) {
+          XLA_CHECK(shard.sizes() == expected_shard_shape)
+              << "Input shard shape must include padding: " << shard.sizes()
+              << " vs " << expected_shard_shape;
+        }
 
         auto data_handle = WrapXlaData(ShardingUtil::CreateShardedData(
             shards, local_devices, sharding_spec));
         XLATensorPtr xla_tensor = XLATensor::Create(std::move(data_handle));
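The new check requires every input CPU shard to already carry its padding. A worked sketch with hypothetical numbers, assuming the expected shard shape is the padded per-shard extent (ceil of global size over tile count along each tiled dimension), which is what the `ShardingUtil::GetShardShape` comparison above enforces:

    import math

    # Hypothetical case: a global tensor of shape (5,) tiled across 4 devices.
    global_shape = (5,)
    tile_dims = (4,)

    # The expected shard shape is the padded one: ceil(5 / 4) == 2. All four
    # input CPU shards must therefore have shape (2,), with the last shard
    # carrying one element of padding.
    expected_shard_shape = [math.ceil(g / t) for g, t in zip(global_shape, tile_dims)]
    assert expected_shard_shape == [2]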
