Add API to assemble CPU shards to a sharded tensor
jonb377 committed Sep 20, 2023
1 parent 2fda248 commit 34cc1d3
Showing 6 changed files with 132 additions and 2 deletions.
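
For orientation, here is a minimal usage sketch of the API this commit adds, pieced together from the tests below. The `xs` import path, mesh construction, and single-host assumption are editorial, not part of this diff:

import torch
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs  # import path assumed
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor

xr.use_spmd()  # the new binding requires the virtual SPMD device

# One CPU shard per local device, ordered like xr.local_runtime_devices().
num_devices = len(xr.local_runtime_devices())  # single-host setup assumed
mesh = xs.Mesh(list(range(num_devices)), (num_devices,))
shards = [torch.LongTensor([i]) for i in range(num_devices)]

# Assemble the shards into one sharded tensor on the virtual XLA device.
global_tensor = XLAShardedTensor.from_cpu_shards(shards,
                                                 mesh.get_op_sharding((0,)))
print(global_tensor.cpu())  # tensor([0, 1, ..., num_devices - 1])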
59 changes: 59 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -866,6 +866,65 @@ def test_op_sharding_cache(self):
xs.mark_sharding(v, mesh, (0, None))
self.assertEqual(met.counter_value("CreateOpSharding"), 2)

def test_from_cpu_shards(self):
# Create an OpSharding with all devices on a single axis
mesh = self._get_mesh((self.n_devices,))
op_sharding = mesh.get_op_sharding((0,))

# Infer the global shape from the OpSharding
shards = [torch.LongTensor([i]) for i in range(self.n_devices)]
global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding)
self.assertTrue(
torch.allclose(global_tensor.cpu(), torch.arange(self.n_devices)))

# Test incorrect number of shards
with self.assertRaises(RuntimeError):
XLAShardedTensor.from_cpu_shards(shards[:-1], op_sharding)

# Test an invalid global shape - too many values.
with self.assertRaises(RuntimeError):
bad_size = torch.Size((2 * self.n_devices,))
XLAShardedTensor.from_cpu_shards(shards, op_sharding, bad_size)

if self.n_devices > 1:
# Test an invalid global shape - incorrect rank
with self.assertRaises(RuntimeError):
bad_size = torch.Size((self.n_devices // 2, 2))
XLAShardedTensor.from_cpu_shards(shards, op_sharding, bad_size)

# Test a valid global shape - restrict the number of meaningful values
# to 1, treating the rest as padding.
global_shape = torch.Size((1,))
global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding,
global_shape)
self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(1)))

# Use a 2d mesh
mesh_2d = self._get_mesh((self.n_devices // 2, 2))
shards = [torch.LongTensor([[i]]) for i in range(self.n_devices)]
op_sharding = mesh_2d.get_op_sharding((0, 1))
global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding)
expected = torch.arange(self.n_devices).reshape(self.n_devices // 2, 2)
self.assertTrue(torch.allclose(global_tensor.cpu(), expected))

# Use partial replication. With eight devices, the mesh and shards are:
# mesh_2d shards
# | TPU:0 | TPU:1 | | 0 | 1 |
# | TPU:2 | TPU:3 | | 0 | 1 |
# | TPU:4 | TPU:5 | | 0 | 1 |
# | TPU:6 | TPU:7 | | 0 | 1 |
#
# Partial replication along the 0th axis represents a global tensor
# of torch.Tensor([[0, 1]]).
shards = [torch.LongTensor([[i]]) for i in range(2)] * (
self.n_devices // 2)
op_sharding = mesh_2d.get_op_sharding((None, 1))
global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding)
expected = torch.arange(2).reshape(1, 2)
self.assertTrue(torch.allclose(global_tensor.cpu(), expected))


if __name__ == '__main__':
test = unittest.main()
50 changes: 50 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1660,6 +1660,56 @@ void InitXlaModuleBindings(py::module m) {
}
return std::nullopt;
});
// Reassemble the CPU shards into a global tensor. A new sharded tensor is
// created from the local shards with the provided sharding annotation
// attached. The order of the shards should coincide with the order of
// devices returned by `xr.local_runtime_devices()`.
m.def(
"_global_tensor_from_cpu_shards",
[](const std::vector<at::Tensor>& shards, const xla::OpSharding& sharding,
std::optional<std::vector<int>>& global_shape) -> at::Tensor {
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
auto local_devices = runtime::GetComputationClient()->GetLocalDevices();
XLA_CHECK(local_devices.size() == shards.size())
<< "Must specify a tensor for each local device";
// Set the global shape according to the input, or infer the global
// shape based on the tiling and shard shape.
auto tile_shape = sharding.tile_assignment_dimensions();
if (global_shape.has_value()) {
XLA_CHECK(global_shape->size() == shards[0].sizes().size())
<< "Shard rank must match global tensor rank, got global rank "
<< global_shape->size() << " and shard rank "
<< shards[0].sizes().size();
// The global shape must be achievable through padding
for (int dim = 0; dim < shards[0].sizes().size(); ++dim) {
auto max_size = tile_shape[dim] * shards[0].sizes()[dim];
XLA_CHECK(global_shape.value()[dim] <= max_size)
    << "Invalid global shape for the provided shards and OpSharding: "
    << "dimension " << dim << " must be less than or equal to "
    << max_size;
}
} else {
global_shape = std::make_optional<std::vector<int>>();
for (int dim = 0; dim < shards[0].sizes().size(); ++dim) {
auto global_dim = tile_shape[dim] * shards[0].sizes()[dim];
global_shape->push_back(global_dim);
}
}

xla::Shape tensor_shape =
CreateComputationShapeFromTensor(shards[0], nullptr);
for (int dim = 0; dim < tensor_shape.rank(); ++dim) {
tensor_shape.set_dimensions(dim, global_shape.value()[dim]);
}

auto sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
auto data_handle = WrapXlaData(ShardingUtil::CreateShardedData(
shards, local_devices, sharding_spec));
XLATensorPtr xla_tensor = XLATensor::Create(std::move(data_handle));
xla_tensor->SetShardingSpec(*sharding_spec);
auto tensor = bridge::AtenFromXlaTensor(std::move(xla_tensor));
return torch::autograd::make_variable(tensor,
shards[0].requires_grad());
},
py::arg("shards"), py::arg("sharding"),
py::arg("global_shape") = py::none());
// Returns the local shards of the tensor, with values taken from the
// underlying ComputationClient::GetDataShards. As such, the shards will
// contain any padding that was applied to ensure they all have the same
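For reference, the shape inference in the binding above reduces to a simple rule: each global dimension is the tile count along that dimension (from `tile_assignment_dimensions`) times the shard's extent along it, and an explicit `global_shape` may only shrink that product, with the excess treated as padding. A rough, illustrative Python rendering of the rule (not part of this diff):

def infer_global_shape(shard_shape, tile_assignment_dimensions):
  # Illustrative only: each global dimension is
  # tiles-along-dim * shard-extent-along-dim.
  return [tiles * extent
          for tiles, extent in zip(tile_assignment_dimensions, shard_shape)]

# Shards of shape (1, 1) tiled over a (4, 2) mesh assemble into a 4x2 tensor.
assert infer_global_shape((1, 1), (4, 2)) == [4, 2]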
3 changes: 2 additions & 1 deletion torch_xla/csrc/xla_sharding_util.cpp
@@ -690,7 +690,8 @@ void ShardingUtil::PrepareOutputShardingPropagation(
}

runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
std::vector<at::Tensor>& local_shards, std::vector<std::string>& devices,
const std::vector<at::Tensor>& local_shards,
const std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec) {
XLA_CHECK(local_shards.size() == devices.size())
<< "A device must be speficied for each shard";
3 changes: 2 additions & 1 deletion torch_xla/csrc/xla_sharding_util.h
@@ -144,7 +144,8 @@ class ShardingUtil {
// Transfers the individual shards to the devices and returns a DataPtr for
// the PjRtShardedData wrapping the shards.
static runtime::ComputationClient::DataPtr CreateShardedData(
std::vector<at::Tensor>& shards, std::vector<std::string>& devices,
const std::vector<at::Tensor>& shards,
const std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec);
};

13 changes: 13 additions & 0 deletions torch_xla/experimental/xla_sharded_tensor.py
@@ -102,6 +102,19 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs):
r.global_tensor = elem.detach() if r.requires_grad else elem
return r

@staticmethod
def from_cpu_shards(shards: List[torch.Tensor],
sharding: torch_xla._XLAC.OpSharding,
global_shape: torch.Size = None):
"""
Create an XLAShardedTensor from the given list of CPU shards. The order of
the shards determines which device each shard is placed on, coinciding with
the device order returned by `torch_xla.runtime.local_runtime_devices`. If
`global_shape` is omitted, the global shape is inferred from the sharding
and the shard shape; if given, it must be reachable from them by padding.
"""
return XLAShardedTensor(
torch_xla._XLAC._global_tensor_from_cpu_shards(shards, sharding,
global_shape))

# Shards on the devices are materialized/available after the lazy
# execution of the partitioned HLO graph. Each XLAShard points
# to torch.Tensor. The shards represent a snapshot on CPU, detached
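The wrapper above does not show the optional `global_shape` argument in use. A self-contained sketch of that path, mirroring the padding case in the test file; the imports and mesh setup are the same assumptions as in the sketch at the top of this page:

import torch
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs  # import path assumed
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor

xr.use_spmd()
num_devices = len(xr.local_runtime_devices())  # single-host setup assumed
mesh = xs.Mesh(list(range(num_devices)), (num_devices,))
shards = [torch.LongTensor([i]) for i in range(num_devices)]

# Without global_shape the assembled tensor has num_devices elements. Passing
# torch.Size((1,)) keeps only the first value and treats the rest as padding.
padded = XLAShardedTensor.from_cpu_shards(
    shards, mesh.get_op_sharding((0,)), global_shape=torch.Size((1,)))
print(padded.cpu())  # tensor([0])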
6 changes: 6 additions & 0 deletions torch_xla/runtime.py
@@ -131,6 +131,12 @@ def global_device_count() -> int:
return len(torch_xla._XLAC._xla_get_all_devices())


@requires_pjrt
def local_runtime_devices() -> List[str]:
"""Returns addressable devices as a list of string."""
return torch_xla._XLAC._xla_get_runtime_devices()


@requires_pjrt
def world_size() -> int:
"""Returns the total number of configured logical devices."""
