Revert "Add API to assemble CPU shards to a sharded tensor" #5680

Merged · 1 commit · Oct 5, 2023
133 changes: 0 additions & 133 deletions test/spmd/test_xla_sharding.py
@@ -900,139 +900,6 @@ def test_op_sharding_cache(self):
xs.mark_sharding(v, mesh, (0, None))
self.assertEqual(met.counter_value("CreateOpSharding"), 2)

def test_from_cpu_shards_replicated(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards

# Create an OpSharding with all devices on a single axis
mesh = self._get_mesh((self.n_devices,))
partition_spec = (None,)
op_sharding = mesh.get_op_sharding(partition_spec)
shards = [torch.arange(4)] * self.n_devices

# No shape should result in the shape of a single shard.
global_tensor = from_cpu_shards(shards, op_sharding)
self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))

# Specify a valid shape for the global tensor
global_tensor = from_cpu_shards(shards, op_sharding, shards[0].shape)
self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))

# All invalid shapes should raise
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((5,)))
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((3,)))
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((2, 2)))

def test_from_cpu_shards_tiled(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards

# Create an OpSharding with all devices on a single axis
mesh = self._get_mesh((self.n_devices,))
partition_spec = (0,)
op_sharding = mesh.get_op_sharding(partition_spec)
shards = [torch.LongTensor([i]) for i in range(self.n_devices)]

global_tensor = from_cpu_shards(shards, op_sharding)
self.assertTrue(
torch.allclose(global_tensor.cpu(), torch.arange(self.n_devices)))

# Test incorrect number of shards
with self.assertRaises(RuntimeError):
from_cpu_shards(shards[:-1], op_sharding)

# Test an invalid global shape - too many values.
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((self.n_devices * 2,)))

# Test an invalid global shape - incorrect rank
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((1, self.n_devices)))

# Test a valid global shape - restrict the number of meaningful values
# to 1, treating the rest as padding.
global_tensor = from_cpu_shards(shards, op_sharding, torch.Size((1,)))
self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(1)))

def test_from_cpu_shards_2d(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards

# Create an appropriate 2D mesh for the number of devices
if self.n_devices >= 4:
mesh_shape = (self.n_devices // 2, 2)
else:
mesh_shape = (1, self.n_devices)
mesh_2d = self._get_mesh(mesh_shape)

# Replicated sharding
shards = [torch.LongTensor([self.n_devices])] * self.n_devices
partition_spec = (None, None)
op_sharding = mesh_2d.get_op_sharding(partition_spec)
global_tensor = from_cpu_shards(shards, op_sharding)
self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))

if self.n_devices > 1:
# Tiled sharding
shards = [torch.LongTensor([[i]]) for i in range(self.n_devices)]
partition_spec = (0, 1)
op_sharding = mesh_2d.get_op_sharding(partition_spec)
global_tensor = from_cpu_shards(shards, op_sharding)
expected = torch.arange(self.n_devices).reshape(2, self.n_devices // 2)
self.assertTrue(torch.allclose(global_tensor.cpu(), expected))

# Partially replicated sharding
shards = [torch.LongTensor([[i]]) for i in range(2)] * (
self.n_devices // 2)
partition_spec = (None, 1)
op_sharding = mesh_2d.get_op_sharding(partition_spec)
global_tensor = from_cpu_shards(shards, op_sharding)
# Partial replication along the 0th axis represents a global tensor
# of torch.Tensor([[0, 1]]).
expected = torch.arange(2).reshape(1, 2)
self.assertTrue(torch.allclose(global_tensor.cpu(), expected))

def test_from_cpu_shards_global_shape(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards

mesh = self._get_mesh((self.n_devices,))
numel = self.n_devices**2
# The global tensor is torch.arange(numel).
shards = [
torch.arange(self.n_devices) + (i * self.n_devices)
for i in range(self.n_devices)
]
partition_spec = (0,)
op_sharding = mesh.get_op_sharding(partition_spec)

# No global shape specified will include all data from the shards
global_tensor = from_cpu_shards(shards, op_sharding)
self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(numel)))

# Too large of a global shape will error out
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((numel + 1,)))

if self.n_devices > 1:
# When the global tensor has fewer elements than the sum of its shards,
# there are two cases:

# Case 1: If the global shape is within n_devices of numel, the excess
# data is treated as padding and ignored.
for delta in range(self.n_devices):
size = torch.Size((numel - delta,))
global_tensor = from_cpu_shards(shards, op_sharding, size)
expected = torch.arange(size[0])
self.assertTrue(torch.allclose(global_tensor.cpu(), expected))

# Case 2: Otherwise, it is not possible to have that much padding in a
# sharded tensor, and the shards are incompatible with the shape.
with self.assertRaises(RuntimeError):
shape = torch.Size((numel - self.n_devices,))
from_cpu_shards(shards, op_sharding, shape)
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((1,)))


if __name__ == '__main__':
test = unittest.main()
67 changes: 0 additions & 67 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -34,7 +34,6 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/ir_dump_util.h"
#include "torch_xla/csrc/layout_manager.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/runtime/computation_client.h"
@@ -1664,72 +1663,6 @@ void InitXlaModuleBindings(py::module m) {
}
return std::nullopt;
});
// Reassemble the CPU shards into a global tensor. A new sharded tensor is
// created from the local shards with the provided sharding annotation
// attached. The order of the shards should coincide with the order of
// devices returned by `torch_xla.runtime.local_runtime_devices()`.
m.def(
"_global_tensor_from_cpu_shards",
[](const std::vector<at::Tensor>& shards, const xla::OpSharding& sharding,
std::optional<std::vector<int64_t>>& global_shape) -> at::Tensor {
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
auto local_devices = runtime::GetComputationClient()->GetLocalDevices();
XLA_CHECK(local_devices.size() == shards.size())
<< "Must specify a shard for each local device";
XLA_CHECK(!global_shape.has_value() ||
global_shape.value().size() == shards[0].sizes().size())
<< "Global shape rank must agree with shard rank: expected rank "
<< shards[0].sizes().size() << ", got "
<< global_shape.value().size();

if (!global_shape.has_value()) {
// Set a default value for the global shape based on the sharding
// type.
if (sharding.type() == xla::OpSharding::OTHER) {
// Infer the global shape to be the shard shape scaled by the tiling
// dimensionality.
auto tile_shape = sharding.tile_assignment_dimensions();
global_shape = std::vector<int64_t>();
for (int dim = 0; dim < shards[0].sizes().size(); ++dim) {
auto global_dim = tile_shape[dim] * shards[0].sizes()[dim];
global_shape->push_back(global_dim);
}
} else if (sharding.type() == xla::OpSharding::REPLICATED) {
global_shape = shards[0].sizes().vec();
} else {
XLA_ERROR() << "Unsupported OpSharding type: " << sharding.type();
}
}

auto device = GetVirtualDevice();
auto primitive_type =
MakeXlaPrimitiveType(shards[0].type().scalarType(), &device);
xla::Shape tensor_shape = MakeArrayShapeFromDimensions(
global_shape.value(), /*dynamic_dimensions=*/{}, primitive_type,
static_cast<XlaDeviceType>(device.type()));
auto sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);

// Verify that the shard shape is correct for the global shape and
// sharding spec.
auto expected_shard_shape = ShardingUtil::GetShardShape(sharding_spec);
for (auto shard : shards) {
XLA_CHECK(shard.sizes() == expected_shard_shape)
<< "Input shard shape must include padding: " << shard.sizes()
<< " vs " << expected_shard_shape;
}

auto data_handle = WrapXlaData(ShardingUtil::CreateShardedData(
shards, local_devices, sharding_spec));
XLATensorPtr xla_tensor = XLATensor::Create(std::move(data_handle));
xla_tensor->SetShardingSpec(*sharding_spec);
auto tensor = bridge::AtenFromXlaTensor(std::move(xla_tensor));
return torch::autograd::make_variable(tensor,
shards[0].requires_grad());
},
py::arg("shards"), py::arg("sharding"),
py::arg("global_shape") = py::none());
// Returns the local shards of the tensor, with values taken from the
// underlying ComputationClient::GetDataShards. As such, the shards will
// contain any padding that was applied to ensure they all have the same
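For context, the reverted `_global_tensor_from_cpu_shards` binding above was driven from Python the way the deleted tests show. The following is a minimal sketch, assuming a single host with SPMD enabled, so the devices returned by `torch_xla.runtime.local_runtime_devices()` are exactly the devices in the mesh:

import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

# The binding checks UseVirtualDevice(), so SPMD must be enabled first.
xr.use_spmd()

# One CPU shard per local device, ordered like local_runtime_devices().
n_devices = len(xr.local_runtime_devices())
mesh = xs.Mesh(np.arange(n_devices), (n_devices,))
op_sharding = mesh.get_op_sharding((0,))
shards = [torch.LongTensor([i]) for i in range(n_devices)]

# Assemble the shards into a tiled global tensor; when no global shape is
# passed, it is inferred from the tile assignment and the shard shape.
global_tensor = torch_xla._XLAC._global_tensor_from_cpu_shards(shards, op_sharding)
assert torch.allclose(global_tensor.cpu(), torch.arange(n_devices))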
3 changes: 1 addition & 2 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -706,8 +706,7 @@ void ShardingUtil::PrepareOutputShardingPropagation(
}

runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
const std::vector<at::Tensor>& local_shards,
const std::vector<std::string>& devices,
std::vector<at::Tensor>& local_shards, std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec) {
XLA_CHECK(local_shards.size() == devices.size())
<< "A device must be speficied for each shard";
3 changes: 1 addition & 2 deletions torch_xla/csrc/xla_sharding_util.h
@@ -147,8 +147,7 @@ class ShardingUtil {
// Transfers the individual shards to the devices and returns a DataPtr for
// the PjRtShardedData wrapping the shards.
static runtime::ComputationClient::DataPtr CreateShardedData(
const std::vector<at::Tensor>& shards,
const std::vector<std::string>& devices,
std::vector<at::Tensor>& shards, std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec);
};

15 changes: 7 additions & 8 deletions torch_xla/experimental/xla_sharding.py
@@ -87,14 +87,6 @@ def get_op_sharding(self,
Return the OpSharding for the given partition spec. This is an expensive
operation as the mesh grows, so the value is cached for reuse.
"""
partition_spec = _translate_named_partition_spec(self, partition_spec)
flat_specs = np.hstack([d for d in partition_spec])
specs = [d for d in flat_specs if d is not None]
assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \
f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
assert len(specs) == len(np.unique(specs)), \
f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."

tile_assignment = _get_tile_assignment(self, partition_spec)
if len(tile_assignment.shape) > len(partition_spec):
# Use partial replication for sharding a tensor over a higher-rank mesh
@@ -490,12 +482,19 @@ def mark_sharding(
assert num_devices > 0, "This requires XLA supported device(s)."
assert mesh.size() == num_devices, \
f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
partition_spec = _translate_named_partition_spec(mesh, partition_spec)
# We only allow fully specified `partition_spec` to be applicable, as opposed
# to filling in the unspecified replicated dims. Fully specified `partiion_spec`
# should be of the same rank as `t`. This is to support partial replication
# where the group assignment may vary with different input ranks.
assert len(t.shape) == len(partition_spec), \
f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
flat_specs = np.hstack([d for d in partition_spec])
specs = [d for d in flat_specs if d is not None]
assert all(d >= 0 and d < len(mesh.mesh_shape) for d in specs), \
f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
assert len(specs) == len(np.unique(specs)), \
f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."

op_sharding = mesh.get_op_sharding(partition_spec)

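As a usage note on the relocated checks: the out-of-bound and duplicate-mesh-dimension assertions now run in `mark_sharding` (before `get_op_sharding` is called), and `partition_spec` must be fully specified with one entry per tensor dimension, using `None` for replicated dimensions. A minimal sketch, assuming four addressable devices arranged in a (2, 2) mesh:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

mesh = xs.Mesh(np.arange(4), (2, 2))
t = torch.randn(8, 16).to(xm.xla_device())

# Valid: one spec entry per tensor dimension; None replicates that dimension.
xs.mark_sharding(t, mesh, (0, None))

# Invalid: mesh dimension 0 appears twice, which now trips the
# duplicate-dimension assertion in mark_sharding rather than get_op_sharding.
# xs.mark_sharding(t, mesh, (0, 0))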