Commit

Improve tests and error handling
jonb377 committed Sep 26, 2023
1 parent bf329c7 commit 7bd0555
Showing 4 changed files with 126 additions and 73 deletions.
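
For context, the tests in this commit drive the torch_xla._XLAC._global_tensor_from_cpu_shards binding directly rather than the removed XLAShardedTensor.from_cpu_shards wrapper. A minimal usage sketch follows; it is an illustration only, and the mesh and device-count helpers (xs.Mesh, xr.global_runtime_device_count) are assumptions drawn from the surrounding test suite, not part of this change.

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

# Illustrative sketch, not part of this commit.
xr.use_spmd()  # the binding requires the SPMD virtual device to be enabled

# Assumed helpers: a 1D mesh over all devices, tiled along its only axis.
n_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(n_devices)), (n_devices,))
op_sharding = mesh.get_op_sharding((0,))

# One CPU shard per local device, in local device order.
shards = [torch.LongTensor([i]) for i in range(n_devices)]

# The global shape argument is optional; when omitted, it is inferred from
# the OpSharding (see the init_python_bindings.cpp change below).
global_tensor = torch_xla._XLAC._global_tensor_from_cpu_shards(shards, op_sharding)
assert torch.equal(global_tensor.cpu(), torch.arange(n_devices))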
153 changes: 109 additions & 44 deletions test/spmd/test_xla_sharding.py
@@ -866,70 +866,135 @@ def test_op_sharding_cache(self):
xs.mark_sharding(v, mesh, (0, None))
self.assertEqual(met.counter_value("CreateOpSharding"), 2)

def test_from_cpu_shards(self):
def test_from_cpu_shards_replicated(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards

# Create an OpSharding with all devices on a single axis
mesh = self._get_mesh((self.n_devices,))
op_sharding = mesh.get_op_sharding((0,))
partition_spec = (None,)
op_sharding = mesh.get_op_sharding(partition_spec)
shards = [torch.arange(4)] * self.n_devices

# No shape should result in the shape of a single shard.
global_tensor = from_cpu_shards(shards, op_sharding)
self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))

# Infer the global shape from the OpSharding
# Specify a valid shape for the global tensor
global_tensor = from_cpu_shards(shards, op_sharding, shards[0].shape)
self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))

# All invalid shapes should raise
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((5,)))
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((3,)))
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((2, 2)))

def test_from_cpu_shards_tiled(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards

# Create an OpSharding with all devices on a single axis
mesh = self._get_mesh((self.n_devices,))
partition_spec = (0,)
op_sharding = mesh.get_op_sharding(partition_spec)
shards = [torch.LongTensor([i]) for i in range(self.n_devices)]
global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding)


global_tensor = from_cpu_shards(shards, op_sharding)
self.assertTrue(
torch.allclose(global_tensor.cpu(), torch.arange(self.n_devices)))

# Test incorrect number of shards
with self.assertRaises(RuntimeError):
XLAShardedTensor.from_cpu_shards(shards[:-1], op_sharding)

# Test replicated sharding. The result should equal a shard.
op_sharding = mesh.get_op_sharding((None,))
shards = [torch.LongTensor([0])] * self.n_devices
global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding)
self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))
from_cpu_shards(shards[:-1], op_sharding)

# Test an invalid global shape - too many values.
with self.assertRaises(RuntimeError):
bad_size = torch.Size((2 * self.n_devices,))
XLAShardedTensor.from_cpu_shards(shards, op_sharding, bad_size)
from_cpu_shards(shards, op_sharding, torch.Size((self.n_devices * 2,)))

# Test an invalid global shape - incorrect rank
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((1, self.n_devices)))

# Test a valid global shape - restrict the number of meaningful values
# to 1, treating the rest as padding.
global_tensor = from_cpu_shards(shards, op_sharding, torch.Size((1,)))
self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(1)))

def test_from_cpu_shards_2d(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards

# Create an appropriate 2D mesh for the number of devices
if self.n_devices >= 4:
mesh_shape = (self.n_devices // 2, 2)
else:
mesh_shape = (1, self.n_devices)
mesh_2d = self._get_mesh(mesh_shape)

# Replicated sharding
shards = [torch.LongTensor([self.n_devices])] * self.n_devices
partition_spec = (None, None)
op_sharding = mesh_2d.get_op_sharding(partition_spec)
global_tensor = from_cpu_shards(shards, op_sharding)
self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))

if self.n_devices > 1:
# Test an invalid global shape - incorrect rank
with self.assertRaises(RuntimeError):
bad_size = torch.Size((self.n_devices // 2, 2))
XLAShardedTensor.from_cpu_shards(shards, op_sharding, bad_size)

# Test a valid global shape - restrict the number of meaningful values
# to 1, treating the rest as padding.
global_shape = torch.Size((1,))
global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding,
global_shape)
self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(1)))

# Use a 2d mesh
mesh_2d = self._get_mesh((self.n_devices // 2, 2))
# Tiled sharding
shards = [torch.LongTensor([[i]]) for i in range(self.n_devices)]
op_sharding = mesh_2d.get_op_sharding((0, 1))
global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding)
partition_spec = (0, 1)
op_sharding = mesh_2d.get_op_sharding(partition_spec)
global_tensor = from_cpu_shards(shards, op_sharding)
expected = torch.arange(self.n_devices).reshape(2, self.n_devices // 2)
self.assertTrue(torch.allclose(global_tensor.cpu(), expected))

# Use partial replication. With eight devices, the mesh and shards are:
# mesh_2d shards
# | TPU:0 | TPU:1 | | 0 | 1 |
# | TPU:2 | TPU:3 | | 0 | 1 |
# | TPU:4 | TPU:5 | | 0 | 1 |
# | TPU:6 | TPU:7 | | 0 | 1 |
#
# Partial replication along the 0th axis represents a global tensor
# of torch.Tensor([[0, 1]]).
# Partially replicated sharding
shards = [torch.LongTensor([[i]]) for i in range(2)] * (
self.n_devices // 2)
op_sharding = mesh_2d.get_op_sharding((None, 1))
global_tensor = XLAShardedTensor.from_cpu_shards(shards, op_sharding)
expected = torch.arange(self.n_devices)
self.assertTrue(
torch.allclose(global_tensor.cpu(),
torch.arange(2).reshape(1, 2)))
partition_spec = (None, 1)
op_sharding = mesh_2d.get_op_sharding(partition_spec)
global_tensor = from_cpu_shards(shards, op_sharding)
# Partial replication along the 0th axis represents a global tensor
# of torch.Tensor([[0, 1]]).
expected = torch.arange(2).reshape(1, 2)
self.assertTrue(torch.allclose(global_tensor.cpu(), expected))

def test_from_cpu_shards_global_shape(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards

mesh = self._get_mesh((self.n_devices,))
numel = self.n_devices ** 2
# The global tensor is torch.arange(numel).
shards = [torch.arange(self.n_devices) + (i * self.n_devices) for i in range(self.n_devices)]
partition_spec = (0,)
op_sharding = mesh.get_op_sharding(partition_spec)

# No global shape specified will include all data from the shards
global_tensor = from_cpu_shards(shards, op_sharding)
self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(numel)))

# Too large of a global shape will error out
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((numel + 1,)))

if self.n_devices > 1:
# When the global tensor has fewer elements than the sum of its shards,
# there are two cases:

# Case 1: If the global shape is within n_devices of numel, the excess
# data is treated as padding and ignored.
for delta in range(self.n_devices):
size = numel - delta
global_tensor = from_cpu_shards(shards, op_sharding, torch.Size((size,)))
self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(size)))

# Case 2: Otherwise, it is not possible to have that much padding in a
# sharded tensor, and the shards are incompatible with the shape.
with self.assertRaises(RuntimeError):
shape = torch.Size((numel - self.n_devices,))
from_cpu_shards(shards, op_sharding, shape)
with self.assertRaises(RuntimeError):
from_cpu_shards(shards, op_sharding, torch.Size((1,)))


if __name__ == '__main__':
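As a concrete illustration of the padding rule exercised in test_from_cpu_shards_global_shape above (a sketch with an assumed device count, not part of the commit):

# Worked example assuming 4 devices: each shard holds 4 elements, so the
# shards carry numel = 16 values in total.
n_devices = 4
numel = n_devices ** 2

# Case 1: global sizes within n_devices of numel keep the 4-element shard
# shape and simply treat the trailing elements of the last shard as padding.
valid_sizes = [numel - d for d in range(n_devices)]  # [16, 15, 14, 13]

# Case 2: anything smaller implies a shard shape of ceil(size / n_devices) < 4,
# which no longer matches the provided shards, so the binding raises.
invalid_size = numel - n_devices  # 12 -> RuntimeError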
27 changes: 17 additions & 10 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -33,6 +33,7 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/ir_dump_util.h"
#include "torch_xla/csrc/layout_manager.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/runtime/computation_client.h"
@@ -1667,12 +1668,17 @@ void InitXlaModuleBindings(py::module m) {
m.def(
"_global_tensor_from_cpu_shards",
[](const std::vector<at::Tensor>& shards, const xla::OpSharding& sharding,
std::optional<std::vector<long int>>& global_shape) -> at::Tensor {
std::optional<std::vector<int64_t>>& global_shape) -> at::Tensor {
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
auto local_devices = runtime::GetComputationClient()->GetLocalDevices();
XLA_CHECK(local_devices.size() == shards.size())
<< "Must specify a shard for each local device";
XLA_CHECK(!global_shape.has_value() ||
global_shape.value().size() == shards[0].sizes().size())
<< "Global shape rank must agree with shard rank: expected rank "
<< shards[0].sizes().size() << ", got "
<< global_shape.value().size();

if (!global_shape.has_value()) {
// Set a default value for the global shape based on the sharding
@@ -1681,23 +1687,24 @@
// Infer the global shape to be the shard shape scaled by the tiling
// dimensionality.
auto tile_shape = sharding.tile_assignment_dimensions();
global_shape = std::vector<long int>();
global_shape = std::vector<int64_t>();
for (int dim = 0; dim < shards[0].sizes().size(); ++dim) {
auto global_dim = tile_shape[dim] * shards[0].sizes()[dim];
global_shape->push_back(global_dim);
}
} else {
// If the sharding type is not tiled, the tensor shards are
// replicated.
} else if (sharding.type() == xla::OpSharding::REPLICATED) {
global_shape = shards[0].sizes().vec();
} else {
XLA_ERROR() << "Unsupported OpSharding type: " << sharding.type();
}
}

xla::Shape tensor_shape =
CreateComputationShapeFromTensor(shards[0], nullptr);
for (int dim = 0; dim < tensor_shape.rank(); ++dim) {
tensor_shape.set_dimensions(dim, global_shape.value()[dim]);
}
auto device = GetVirtualDevice();
auto primitive_type =
MakeXlaPrimitiveType(shards[0].type().scalarType(), &device);
xla::Shape tensor_shape = MakeArrayShapeFromDimensions(
global_shape.value(), /*dynamic_dimensions=*/{}, primitive_type,
static_cast<XlaDeviceType>(device.type()));
auto sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);

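The default global-shape inference added above can be summarized as follows; this is a Python restatement for illustration only, not the C++ implementation itself.

# Sketch of the inference rule: tiled shardings scale each shard dimension by
# the tile assignment along that dimension; replicated shardings reuse the
# shard shape; any other sharding type is rejected.
def infer_global_shape(shard_shape, tile_dims=None, replicated=False):
  if replicated:
    return list(shard_shape)
  if tile_dims is None:
    raise RuntimeError("Unsupported OpSharding type")
  return [t * s for t, s in zip(tile_dims, shard_shape)]

assert infer_global_shape((3, 5), tile_dims=(4, 2)) == [12, 10]
assert infer_global_shape((3, 5), replicated=True) == [3, 5]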
13 changes: 0 additions & 13 deletions torch_xla/experimental/xla_sharded_tensor.py
@@ -102,19 +102,6 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs):
r.global_tensor = elem.detach() if r.requires_grad else elem
return r

@staticmethod
def from_cpu_shards(shards: List[torch.Tensor],
sharding: torch_xla._XLAC.OpSharding,
global_shape: torch.Size = None):
"""
Create an XLAShardedTensor from the given list of CPU shards. The order of
the shards determines which device it will be placed on, coinciding with
the device order returned by `torch_xla.runtime.local_runtime_devices`.
"""
return XLAShardedTensor(
torch_xla._XLAC._global_tensor_from_cpu_shards(shards, sharding,
global_shape))

# Shards on the devices are materialized/available after the lazy
# execution of the partitioned HLO graph. Each XLAShard points
# to torch.Tensor. The shards represent a snapshot on CPU, detached
6 changes: 0 additions & 6 deletions torch_xla/runtime.py
@@ -131,12 +131,6 @@ def global_device_count() -> int:
return len(torch_xla._XLAC._xla_get_all_devices())


@requires_pjrt
def local_runtime_devices() -> List[str]:
"""Returns addressable devices as a list of string."""
return torch_xla._XLAC._xla_get_runtime_devices()


@requires_pjrt
def world_size() -> int:
"""Returns the total number of configured logical devices."""
