Add API to assemble CPU shards to a sharded tensor #5630

Merged: 4 commits, Oct 5, 2023
133 changes: 133 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -866,6 +866,139 @@ def test_op_sharding_cache(self):
xs.mark_sharding(v, mesh, (0, None))
self.assertEqual(met.counter_value("CreateOpSharding"), 2)

def test_from_cpu_shards_replicated(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards

# Create an OpSharding with all devices on a single axis
mesh = self._get_mesh((self.n_devices,))
partition_spec = (None,)
op_sharding = mesh.get_op_sharding(partition_spec)
shards = [torch.arange(4)] * self.n_devices

# No shape should result in the shape of a single shard.
global_tensor = from_cpu_shards(shards, op_sharding)
self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))

# Specify a valid shape for the global tensor
global_tensor = from_cpu_shards(shards, op_sharding, shards[0].shape)
self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))

# All invalid shapes should raise
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((5,)))
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((3,)))
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((2, 2)))

def test_from_cpu_shards_tiled(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards

# Create an OpSharding with all devices on a single axis
mesh = self._get_mesh((self.n_devices,))
partition_spec = (0,)
op_sharding = mesh.get_op_sharding(partition_spec)
shards = [torch.LongTensor([i]) for i in range(self.n_devices)]

global_tensor = from_cpu_shards(shards, op_sharding)
self.assertTrue(
torch.allclose(global_tensor.cpu(), torch.arange(self.n_devices)))

# Test incorrect number of shards
with self.assertRaises(RuntimeError):
from_cpu_shards(shards[:-1], op_sharding)

# Test an invalid global shape - too many values.
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((self.n_devices * 2,)))

# Test an invalid global shape - incorrect rank
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((1, self.n_devices)))

# Test a valid global shape - restrict the number of meaningful values
# to 1, treating the rest as padding.
global_tensor = from_cpu_shards(shards, op_sharding, torch.Size((1,)))
self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(1)))

def test_from_cpu_shards_2d(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards

# Create an appropriate 2D mesh for the number of devices
if self.n_devices >= 4:
mesh_shape = (self.n_devices // 2, 2)
else:
mesh_shape = (1, self.n_devices)
mesh_2d = self._get_mesh(mesh_shape)

# Replicated sharding
shards = [torch.LongTensor([self.n_devices])] * self.n_devices
partition_spec = (None, None)
op_sharding = mesh_2d.get_op_sharding(partition_spec)
global_tensor = from_cpu_shards(shards, op_sharding)
self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))

if self.n_devices > 1:
# Tiled sharding
shards = [torch.LongTensor([[i]]) for i in range(self.n_devices)]
partition_spec = (0, 1)
op_sharding = mesh_2d.get_op_sharding(partition_spec)
global_tensor = from_cpu_shards(shards, op_sharding)
expected = torch.arange(self.n_devices).reshape(*mesh_shape)
self.assertTrue(torch.allclose(global_tensor.cpu(), expected))

# Partially replicated sharding
shards = [torch.LongTensor([[i]]) for i in range(2)] * (
self.n_devices // 2)
partition_spec = (None, 1)
op_sharding = mesh_2d.get_op_sharding(partition_spec)
global_tensor = from_cpu_shards(shards, op_sharding)
# Partial replication along the 0th axis represents a global tensor
# of torch.Tensor([[0, 1]]).
expected = torch.arange(2).reshape(1, 2)
self.assertTrue(torch.allclose(global_tensor.cpu(), expected))

def test_from_cpu_shards_global_shape(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards

mesh = self._get_mesh((self.n_devices,))
numel = self.n_devices**2
# The global tensor is torch.arange(numel).
shards = [
torch.arange(self.n_devices) + (i * self.n_devices)
for i in range(self.n_devices)
]
partition_spec = (0,)
op_sharding = mesh.get_op_sharding(partition_spec)

# No global shape specified will include all data from the shards
global_tensor = from_cpu_shards(shards, op_sharding)
self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(numel)))

# Too large of a global shape will error out
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((numel + 1,)))

if self.n_devices > 1:
# When the global tensor has fewer elements than the sum of its shards,
# there are two cases:

# Case 1: If the global shape is within n_devices of numel, the excess
# data is treated as padding and ignored.
for delta in range(self.n_devices):
size = torch.Size((numel - delta,))
global_tensor = from_cpu_shards(shards, op_sharding, size)
expected = torch.arange(size[0])
self.assertTrue(torch.allclose(global_tensor.cpu(), expected))

# Case 2: Otherwise, it is not possible to have that much padding in a
# sharded tensor, and the shards are incompatible with the shape.
with self.assertRaises(RuntimeError):
shape = torch.Size((numel - self.n_devices,))
from_cpu_shards(shards, op_sharding, shape)
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((1,)))


if __name__ == '__main__':
test = unittest.main()
67 changes: 67 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -33,6 +33,7 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/ir_dump_util.h"
#include "torch_xla/csrc/layout_manager.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/runtime/computation_client.h"
@@ -1660,6 +1661,72 @@ void InitXlaModuleBindings(py::module m) {
}
return std::nullopt;
});
// Reassemble the CPU shards into a global tensor. A new sharded tensor is
// created from the local shards with the provided sharding annotation
// attached. The order of the shards should coincide with the order of
// devices returned by `torch_xla.runtime.local_runtime_devices()`.
m.def(
"_global_tensor_from_cpu_shards",
[](const std::vector<at::Tensor>& shards, const xla::OpSharding& sharding,
std::optional<std::vector<int64_t>>& global_shape) -> at::Tensor {
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
auto local_devices = runtime::GetComputationClient()->GetLocalDevices();
XLA_CHECK(local_devices.size() == shards.size())
<< "Must specify a shard for each local device";
XLA_CHECK(!global_shape.has_value() ||
global_shape.value().size() == shards[0].sizes().size())
<< "Global shape rank must agree with shard rank: expected rank "
<< shards[0].sizes().size() << ", got "
<< global_shape.value().size();

if (!global_shape.has_value()) {
// Set a default value for the global shape based on the sharding
// type.
if (sharding.type() == xla::OpSharding::OTHER) {
// Infer the global shape to be the shard shape scaled by the tiling
// dimensionality.

Contributor: This may not hold true in the case of uneven tiling. Let's make a note on this.

Contributor: Or, actually, for the right-most shard in each dim we could add up the sizes, since the padding is always on the last dims.

Contributor: OK, we can't do that, since "Input shard shape must include padding: " << shard.sizes().

Collaborator Author: Let me know if I should revisit this. I decided on this approach because padding can cross over multiple devices, e.g. sharding a tensor with shape (1, 2) over the mesh (1, 4) will produce shards with shape (1, 1), with no real data on the last two devices.

Collaborator: ping @yeounoh

Contributor: I think for the most part the user should specify the global shape. I'm thinking we should actually make it explicit, not optional... but if not, this way of handling it by inferring a default shape is good.

Collaborator Author: That's a good point... I chose to let it be inferred for convenience, since for e.g. distributed data loading the shards are provided on CPU with the correct padded local shape, and deriving the global shape could be difficult for more sophisticated shardings.

I'll go ahead and land with it optional for now, since we're keeping the API private. Thanks Yeounoh!
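For reference, the inference rule under discussion is a per-dimension product of the tile count and the (padded) shard extent. A minimal plain-Python sketch of that rule, using the (1, 2)-on-(1, 4) example from the thread (the helper name is illustrative, not part of this PR):

```python
# Sketch of the default global-shape inference for tiled (OpSharding::OTHER)
# shardings: tile_dims mirrors sharding.tile_assignment_dimensions(), and
# shard_shape is the padded shape shared by every CPU shard.
def infer_global_shape(tile_dims, shard_shape):
  return [t * s for t, s in zip(tile_dims, shard_shape)]

# A (1, 2) tensor tiled over a (1, 4) mesh yields four (1, 1) shards, two of
# them pure padding, so the inferred shape is (1, 4) rather than (1, 2);
# this is why the optional explicit global_shape argument exists.
assert infer_global_shape((1, 4), (1, 1)) == [1, 4]
```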
auto tile_shape = sharding.tile_assignment_dimensions();
global_shape = std::vector<int64_t>();
for (int dim = 0; dim < shards[0].sizes().size(); ++dim) {
auto global_dim = tile_shape[dim] * shards[0].sizes()[dim];
global_shape->push_back(global_dim);
}
} else if (sharding.type() == xla::OpSharding::REPLICATED) {
global_shape = shards[0].sizes().vec();
} else {
XLA_ERROR() << "Unsupported OpSharding type: " << sharding.type();
}
}

auto device = GetVirtualDevice();
auto primitive_type =
MakeXlaPrimitiveType(shards[0].type().scalarType(), &device);
xla::Shape tensor_shape = MakeArrayShapeFromDimensions(
global_shape.value(), /*dynamic_dimensions=*/{}, primitive_type,
static_cast<XlaDeviceType>(device.type()));
auto sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);

// Verify that the shard shape is correct for the global shape and
// sharding spec.
auto expected_shard_shape = ShardingUtil::GetShardShape(sharding_spec);
for (auto shard : shards) {
XLA_CHECK(shard.sizes() == expected_shard_shape)
<< "Input shard shape must include padding: " << shard.sizes()
<< " vs " << expected_shard_shape;
}

auto data_handle = WrapXlaData(ShardingUtil::CreateShardedData(
shards, local_devices, sharding_spec));
XLATensorPtr xla_tensor = XLATensor::Create(std::move(data_handle));
xla_tensor->SetShardingSpec(*sharding_spec);
auto tensor = bridge::AtenFromXlaTensor(std::move(xla_tensor));
return torch::autograd::make_variable(tensor,
shards[0].requires_grad());
},
py::arg("shards"), py::arg("sharding"),
py::arg("global_shape") = py::none());
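A minimal single-host usage sketch of the new binding, mirroring the tests above. The runtime and mesh helpers (`xr.use_spmd`, `xr.global_runtime_device_count`, `xs.Mesh`, `mesh.get_op_sharding`) are the ones used elsewhere in this test suite; treat the exact setup as an assumption rather than part of this PR:

```python
import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

# Reassemble one CPU shard per local device into a single sharded tensor.
xr.use_spmd()  # the binding requires SPMD mode (see the XLA_CHECK above)
num_devices = xr.global_runtime_device_count()  # assumes a single host
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices,))
op_sharding = mesh.get_op_sharding((0,))  # tile dim 0 across all devices

# Shards must already be padded to a common shard shape; their order follows
# the local runtime device order.
shards = [torch.arange(4) + 4 * i for i in range(num_devices)]

from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
global_tensor = from_cpu_shards(shards, op_sharding)  # global shape inferred
# The global shape can also be passed explicitly, e.g. to drop trailing padding:
same_tensor = from_cpu_shards(shards, op_sharding,
                              torch.Size((4 * num_devices,)))
```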
// Returns the local shards of the tensor, with values taken from the
// underlying ComputationClient::GetDataShards. As such, the shards will
// contain any padding that was applied to ensure they all have the same
3 changes: 2 additions & 1 deletion torch_xla/csrc/xla_sharding_util.cpp
@@ -690,7 +690,8 @@ void ShardingUtil::PrepareOutputShardingPropagation(
}

runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
std::vector<at::Tensor>& local_shards, std::vector<std::string>& devices,
const std::vector<at::Tensor>& local_shards,
const std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec) {
XLA_CHECK(local_shards.size() == devices.size())
<< "A device must be speficied for each shard";
3 changes: 2 additions & 1 deletion torch_xla/csrc/xla_sharding_util.h
@@ -144,7 +144,8 @@ class ShardingUtil {
// Transfers the individual shards to the devices and returns a DataPtr for
// the PjRtShardedData wrapping the shards.
static runtime::ComputationClient::DataPtr CreateShardedData(
std::vector<at::Tensor>& shards, std::vector<std::string>& devices,
const std::vector<at::Tensor>& shards,
const std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec);
};

15 changes: 8 additions & 7 deletions torch_xla/experimental/xla_sharding.py
@@ -87,6 +87,14 @@ def get_op_sharding(self,
Return the OpSharding for the given partition spec. This is an expensive
operation as the mesh grows, so the value is cached for reuse.
"""
partition_spec = _translate_named_partition_spec(self, partition_spec)
Contributor: Thanks for refactoring 👍

flat_specs = np.hstack([d for d in partition_spec])
specs = [d for d in flat_specs if d is not None]
assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \
f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
assert len(specs) == len(np.unique(specs)), \
f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."

tile_assignment = _get_tile_assignment(self, partition_spec)
if len(tile_assignment.shape) > len(partition_spec):
# Use partial replication for sharding a tensor over a higher-rank mesh
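To illustrate the validation that now lives in `get_op_sharding` (moved here from `mark_sharding` below), a hedged sketch against a hypothetical 2x2 mesh; the device count is an assumption:

```python
import numpy as np
import torch_xla.experimental.xla_sharding as xs

mesh = xs.Mesh(np.arange(4), (2, 2))  # assumes four devices

mesh.get_op_sharding((0, 1))     # OK: each mesh axis used at most once
mesh.get_op_sharding((None, 1))  # OK: replicate dim 0, shard dim 1 on mesh axis 1
mesh.get_op_sharding((0, 2))     # AssertionError: 2 is out of bounds for mesh_shape
mesh.get_op_sharding((0, 0))     # AssertionError: mesh axis 0 appears twice
```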
@@ -482,19 +490,12 @@ def mark_sharding(
assert num_devices > 0, "This requires XLA supported device(s)."
assert mesh.size() == num_devices, \
f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
partition_spec = _translate_named_partition_spec(mesh, partition_spec)
# We only allow fully specified `partition_spec` to be applicable, as opposed
# to filling in the unspecified replicated dims. Fully specified `partition_spec`
# should be of the same rank as `t`. This is to support partial replication
# where the group assignment may vary with different input ranks.
assert len(t.shape) == len(partition_spec), \
f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
flat_specs = np.hstack([d for d in partition_spec])
specs = [d for d in flat_specs if d is not None]
assert all(d >= 0 and d < len(mesh.mesh_shape) for d in specs), \
f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
assert len(specs) == len(np.unique(specs)), \
f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."

op_sharding = mesh.get_op_sharding(partition_spec)
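For context on the fully-specified `partition_spec` requirement described in the comment above, a short hedged sketch (the tensor shape and device count are assumptions, and SPMD is enabled as in the earlier sketch):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

mesh = xs.Mesh(np.arange(4), (2, 2))  # assumes four devices

# partition_spec needs one entry per tensor dimension; None replicates that dim.
t = torch.randn(8, 16).to(xm.xla_device())
xs.mark_sharding(t, mesh, (0, None))  # shard rows over mesh axis 0, replicate cols
# xs.mark_sharding(t, mesh, (0,))     # fails: spec rank (1) != tensor rank (2)
```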
