Skip to content

Commit

Permalink
Add API to assemble CPU shards to a sharded tensor (#5681)
Browse files Browse the repository at this point in the history
* Add API to assemble CPU shards to a sharded tensor

* Handle replicated sharding

* Move validations into get_op_sharding

* Improve tests and error handling

* Don't WrapXlaData

* Fix test for v3
  • Loading branch information
jonb377 authored and golechwierowicz committed Jan 12, 2024
1 parent 6e67f37 commit 6570258
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 9 deletions.
133 changes: 133 additions & 0 deletions test/spmd/test_xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,139 @@ def test_op_sharding_cache(self):
xs.mark_sharding(v, mesh, (0, None))
self.assertEqual(met.counter_value("CreateOpSharding"), 2)

def test_from_cpu_shards_replicated(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards

# Create an OpSharding with all devices on a single axis
mesh = self._get_mesh((self.n_devices,))
partition_spec = (None,)
op_sharding = mesh.get_op_sharding(partition_spec)
shards = [torch.arange(4)] * self.n_devices

# No shape should result in the shape of a single shard.
global_tensor = from_cpu_shards(shards, op_sharding)
self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))

# Specify a valid shape for the global tensor
global_tensor = from_cpu_shards(shards, op_sharding, shards[0].shape)
self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))

# All invalid shapes should raise
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((5,)))
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((3,)))
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((2, 2)))

def test_from_cpu_shards_tiled(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards

# Create an OpSharding with all devices on a single axis
mesh = self._get_mesh((self.n_devices,))
partition_spec = (0,)
op_sharding = mesh.get_op_sharding(partition_spec)
shards = [torch.LongTensor([i]) for i in range(self.n_devices)]

global_tensor = from_cpu_shards(shards, op_sharding)
self.assertTrue(
torch.allclose(global_tensor.cpu(), torch.arange(self.n_devices)))

# Test incorrect number of shards
with self.assertRaises(RuntimeError):
from_cpu_shards(shards[:-1], op_sharding)

# Test an invalid global shape - too many values.
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((self.n_devices * 2,)))

# Test an invalid global shape - incorrect rank
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((1, self.n_devices)))

# Test a valid global shape - restrict the number of meaningful values
# to 1, treating the rest as padding.
global_tensor = from_cpu_shards(shards, op_sharding, torch.Size((1,)))
self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(1)))

def test_from_cpu_shards_2d(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards

# Create an appropriate 2D mesh for the number of devices
if self.n_devices >= 4:
mesh_shape = (self.n_devices // 2, 2)
else:
mesh_shape = (1, self.n_devices)
mesh_2d = self._get_mesh(mesh_shape)

# Replicated sharding
shards = [torch.LongTensor([self.n_devices])] * self.n_devices
partition_spec = (None, None)
op_sharding = mesh_2d.get_op_sharding(partition_spec)
global_tensor = from_cpu_shards(shards, op_sharding)
self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))

if self.n_devices > 1:
# Tiled sharding
shards = [torch.LongTensor([[i]]) for i in range(self.n_devices)]
partition_spec = (0, 1)
op_sharding = mesh_2d.get_op_sharding(partition_spec)
global_tensor = from_cpu_shards(shards, op_sharding)
expected = torch.arange(self.n_devices).reshape(*mesh_shape)
self.assertTrue(torch.allclose(global_tensor.cpu(), expected))

# Partially replicated sharding
shards = [torch.LongTensor([[i]]) for i in range(2)] * (
self.n_devices // 2)
partition_spec = (None, 1)
op_sharding = mesh_2d.get_op_sharding(partition_spec)
global_tensor = from_cpu_shards(shards, op_sharding)
# Partial replication along the 0th axis represents a global tensor
# of torch.Tensor([[0, 1]]).
expected = torch.arange(2).reshape(1, 2)
self.assertTrue(torch.allclose(global_tensor.cpu(), expected))

def test_from_cpu_shards_global_shape(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards

mesh = self._get_mesh((self.n_devices,))
numel = self.n_devices**2
# The global tensor is torch.arange(numel).
shards = [
torch.arange(self.n_devices) + (i * self.n_devices)
for i in range(self.n_devices)
]
partition_spec = (0,)
op_sharding = mesh.get_op_sharding(partition_spec)

# No global shape specified will include all data from the shards
global_tensor = from_cpu_shards(shards, op_sharding)
self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(numel)))

# Too large of a global shape will error out
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((numel + 1,)))

if self.n_devices > 1:
# When the global tensor has fewer elements than the sum of its shards,
# there are two cases:

# Case 1: If the global shape is within n_devices of numel, the excess
# data is treated as padding and ignored.
for delta in range(self.n_devices):
size = torch.Size((numel - delta,))
global_tensor = from_cpu_shards(shards, op_sharding, size)
expected = torch.arange(size[0])
self.assertTrue(torch.allclose(global_tensor.cpu(), expected))

# Case 2: Otherwise, it is not possible to have that much padding in a
# sharded tensor, and the shards are incompatible with the shape.
with self.assertRaises(RuntimeError):
shape = torch.Size((numel - self.n_devices,))
from_cpu_shards(shards, op_sharding, shape)
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((1,)))


if __name__ == '__main__':
test = unittest.main()
Expand Down
67 changes: 67 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/ir_dump_util.h"
#include "torch_xla/csrc/layout_manager.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/runtime/computation_client.h"
Expand Down Expand Up @@ -1663,6 +1664,72 @@ void InitXlaModuleBindings(py::module m) {
}
return std::nullopt;
});
// Reassemble the CPU shards into a global tensor. A new sharded tensor is
// created from the local shards with the provided sharding annotation
// attached. The order of the shards should coincide with the order of
// devices returned by `torch_xla.runtime.local_runtime_devices()`.
m.def(
"_global_tensor_from_cpu_shards",
[](const std::vector<at::Tensor>& shards, const xla::OpSharding& sharding,
std::optional<std::vector<int64_t>>& global_shape) -> at::Tensor {
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
auto local_devices = runtime::GetComputationClient()->GetLocalDevices();
XLA_CHECK(local_devices.size() == shards.size())
<< "Must specify a shard for each local device";
XLA_CHECK(!global_shape.has_value() ||
global_shape.value().size() == shards[0].sizes().size())
<< "Global shape rank must agree with shard rank: expected rank "
<< shards[0].sizes().size() << ", got "
<< global_shape.value().size();

if (!global_shape.has_value()) {
// Set a default value for the global shape based on the sharding
// type.
if (sharding.type() == xla::OpSharding::OTHER) {
// Infer the global shape to be the shard shape scaled by the tiling
// dimensionality.
auto tile_shape = sharding.tile_assignment_dimensions();
global_shape = std::vector<int64_t>();
for (int dim = 0; dim < shards[0].sizes().size(); ++dim) {
auto global_dim = tile_shape[dim] * shards[0].sizes()[dim];
global_shape->push_back(global_dim);
}
} else if (sharding.type() == xla::OpSharding::REPLICATED) {
global_shape = shards[0].sizes().vec();
} else {
XLA_ERROR() << "Unsupported OpSharding type: " << sharding.type();
}
}

auto device = GetVirtualDevice();
auto primitive_type =
MakeXlaPrimitiveType(shards[0].type().scalarType(), &device);
xla::Shape tensor_shape = MakeArrayShapeFromDimensions(
global_shape.value(), /*dynamic_dimensions=*/{}, primitive_type,
static_cast<XlaDeviceType>(device.type()));
auto sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);

// Verify that the shard shape is correct for the global shape and
// sharding spec.
auto expected_shard_shape = ShardingUtil::GetShardShape(sharding_spec);
for (auto shard : shards) {
XLA_CHECK(shard.sizes() == expected_shard_shape)
<< "Input shard shape must include padding: " << shard.sizes()
<< " vs " << expected_shard_shape;
}

auto data_handle = ShardingUtil::CreateShardedData(
shards, local_devices, sharding_spec);
XLATensorPtr xla_tensor = XLATensor::Create(std::move(data_handle));
xla_tensor->SetShardingSpec(*sharding_spec);
auto tensor = bridge::AtenFromXlaTensor(std::move(xla_tensor));
return torch::autograd::make_variable(tensor,
shards[0].requires_grad());
},
py::arg("shards"), py::arg("sharding"),
py::arg("global_shape") = py::none());
// Returns the local shards of the tensor, with values taken from the
// underlying ComputationClient::GetDataShards. As such, the shards will
// contain any padding that was applied to ensure they all have the same
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,8 @@ void ShardingUtil::PrepareOutputShardingPropagation(
}

runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
std::vector<at::Tensor>& local_shards, std::vector<std::string>& devices,
const std::vector<at::Tensor>& local_shards,
const std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec) {
XLA_CHECK(local_shards.size() == devices.size())
<< "A device must be speficied for each shard";
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/xla_sharding_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ class ShardingUtil {
// Transfers the individual shards to the devices and returns a DataPtr for
// the PjRtShardedData wrapping the shards.
static runtime::ComputationClient::DataPtr CreateShardedData(
std::vector<at::Tensor>& shards, std::vector<std::string>& devices,
const std::vector<at::Tensor>& shards,
const std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec);
};

Expand Down
15 changes: 8 additions & 7 deletions torch_xla/experimental/xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ def get_op_sharding(self,
Return the OpSharding for the given partition spec. This is an expensive
operation as the mesh grows, so the value is cached for reuse.
"""
partition_spec = _translate_named_partition_spec(self, partition_spec)
flat_specs = np.hstack([d for d in partition_spec])
specs = [d for d in flat_specs if d is not None]
assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \
f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
assert len(specs) == len(np.unique(specs)), \
f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."

tile_assignment = _get_tile_assignment(self, partition_spec)
if len(tile_assignment.shape) > len(partition_spec):
# Use partial replication for sharding a tensor over a higher-rank mesh
Expand Down Expand Up @@ -482,19 +490,12 @@ def mark_sharding(
assert num_devices > 0, "This requires XLA supported device(s)."
assert mesh.size() == num_devices, \
f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
partition_spec = _translate_named_partition_spec(mesh, partition_spec)
# We only allow fully specified `partition_spec` to be applicable, as opposed
# to filling in the unspecified replicated dims. Fully specified `partiion_spec`
# should be of the same rank as `t`. This is to support partial replication
# where the group assignment may vary with different input ranks.
assert len(t.shape) == len(partition_spec), \
f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
flat_specs = np.hstack([d for d in partition_spec])
specs = [d for d in flat_specs if d is not None]
assert all(d >= 0 and d < len(mesh.mesh_shape) for d in specs), \
f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
assert len(specs) == len(np.unique(specs)), \
f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."

op_sharding = mesh.get_op_sharding(partition_spec)

Expand Down

0 comments on commit 6570258

Please sign in to comment.