From c7a04ac7eb7d7b67f0a355d4b443e49475701928 Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Fri, 10 Nov 2023 22:49:30 -0800 Subject: [PATCH] Implement mark_sharding as a custom op to support dynamo spmd activation sharding (#5712) --- test/spmd/test_dynamo_spmd.py | 52 ++++++++- torch_xla/csrc/aten_autograd_ops.cpp | 12 --- torch_xla/csrc/aten_autograd_ops.h | 11 ++ torch_xla/csrc/init_python_bindings.cpp | 94 ++++++----------- torch_xla/csrc/xla_sharding_util.cpp | 135 ++++++++++++++++++++++++ torch_xla/csrc/xla_sharding_util.h | 8 ++ torch_xla/experimental/xla_sharding.py | 53 +++++++--- 7 files changed, 276 insertions(+), 89 deletions(-) diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index 22cd29804137..2874d5783bcd 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -15,7 +15,7 @@ class SimpleLinear(nn.Module): - def __init__(self): + def __init__(self, mesh=None): super(SimpleLinear, self).__init__() self.fc1 = nn.Linear(128, 128) self.relu = nn.ReLU() @@ -23,8 +23,14 @@ def __init__(self): # Add an additional 1x1 layer at the end to ensure the final layer # is not sharded. self.fc3 = nn.Linear(1, 1) + # If mesh is not none, we'll do a mark sharding inside the forward function + # to ensure dynamo can recognize and trace it in a torch compile. + self.mesh = mesh def forward(self, x): + if self.mesh and 'xla' in str(self.fc2.weight.device): + xs.mark_sharding( + self.fc2.weight, self.mesh, (1, 0), use_dynamo_custom_op=True) y = self.relu(self.fc1(x)) z = self.fc2(y) return self.fc3(z) @@ -171,6 +177,50 @@ def test_dynamo_input_sharding_threashold(self): else: del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] + def test_dynamo_spmd_mark_sharding_outside_of_compile(self): + device = xm.xla_device() + linear = SimpleLinear().to(device) + linear.eval() + xla_x = torch.randn(1, 128, device=device) + xs.mark_sharding( + linear.fc2.weight, + self._get_mesh((1, self.n_devices)), (1, 0), + use_dynamo_custom_op=True) + xla_res = linear(xla_x) + xm.mark_step() + + dynamo_linear = torch.compile(linear, backend="openxla") + dynamo_res = dynamo_linear(xla_x) + torch.allclose(xla_res.cpu(), dynamo_res.cpu()) + + # Ensure that another run with same input does not trigger additional compilation + compile_count = met.metric_data('CompileTime')[0] + dynamo_res = dynamo_linear(xla_x) + self.assertEqual(met.metric_data('CompileTime')[0], compile_count) + + def test_mark_sharding_inside_compile(self): + met.clear_counters() + device = xm.xla_device() + mesh = self._get_mesh((1, self.n_devices)) + + # Passing this `mesh` as a parameter to `SimpleLinear` will call the dynamo custom op + # variant of mark_sharding inside the forward function. 
+ linear = SimpleLinear(mesh=mesh).to(device) + linear.eval() + + xla_x = torch.randn(1, 128, device=device) + xla_res = linear(xla_x) + xm.mark_step() + + dynamo_linear = torch.compile(linear, backend="openxla") + dynamo_res = dynamo_linear(xla_x) + torch.allclose(xla_res.cpu(), dynamo_res.cpu()) + + # Ensure that another run with same input does not trigger additional compilation + compile_count = met.metric_data('CompileTime')[0] + dynamo_res = dynamo_linear(xla_x) + self.assertEqual(met.metric_data('CompileTime')[0], compile_count) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/csrc/aten_autograd_ops.cpp b/torch_xla/csrc/aten_autograd_ops.cpp index 08c6b27f92cf..81cfdfb4f428 100644 --- a/torch_xla/csrc/aten_autograd_ops.cpp +++ b/torch_xla/csrc/aten_autograd_ops.cpp @@ -253,17 +253,5 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self, return grad; } -TORCH_LIBRARY(xla, m) { - m.def( - "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " - "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", - torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward))); - - m.def( - "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " - "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " - "-> Tensor", - torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward))); -} } // namespace aten_autograd_ops } // namespace torch_xla diff --git a/torch_xla/csrc/aten_autograd_ops.h b/torch_xla/csrc/aten_autograd_ops.h index be063b76620d..d1cc8a980482 100644 --- a/torch_xla/csrc/aten_autograd_ops.h +++ b/torch_xla/csrc/aten_autograd_ops.h @@ -46,6 +46,17 @@ struct MaxPool3dAutogradFunction torch::autograd::variable_list grad_output); }; +torch::Tensor max_pool2d_forward(torch::Tensor self, + torch::IntArrayRef kernel_size, + torch::IntArrayRef stride, + torch::IntArrayRef padding, + torch::IntArrayRef dilation, bool ceil_mode); + +torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self, + torch::IntArrayRef kernel_size, + torch::IntArrayRef stride, + torch::IntArrayRef padding, bool ceil_mode); + } // namespace aten_autograd_ops } // namespace torch_xla diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index b81f0978d27b..1391b73a16c5 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -29,6 +29,7 @@ #include "pybind11/pytypes.h" #include "pybind11/stl_bind.h" #include "torch_xla/csrc/XLANativeFunctions.h" +#include "torch_xla/csrc/aten_autograd_ops.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/dtype.h" @@ -1561,72 +1562,39 @@ void InitXlaModuleBindings(py::module m) { tile_assignment, group_assignment, replication_groups, ShardingUtil::ShardingType(sharding_type)); })); - m.def("_xla_mark_sharding", [](const at::Tensor& input, - xla::OpSharding sharding) { - TORCH_LAZY_COUNTER("XlaMarkSharding", 1); - XLA_CHECK(UseVirtualDevice()) - << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - auto new_sharding_spec = std::make_shared( - sharding, MakeShapeWithDeviceLayout( - xtensor->shape(), - static_cast(xtensor->GetDevice().type()))); - - // For Non DeviceData IR values, we directly attach the sharding spec - // to the xtensor. 
- const DeviceData* device_data_node = nullptr; - if (xtensor->CurrentIrValue()) { - device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); - if (!device_data_node) { - tensor_methods::custom_sharding_(xtensor, new_sharding_spec); - return; - } - } + m.def("_xla_mark_sharding", + [](const at::Tensor& input, xla::OpSharding sharding) { + ShardingUtil::xla_mark_sharding(input, sharding); + }); + m.def("_xla_mark_sharding_dynamo_custom_op", + [](const at::Tensor& input, const py::list& tile_assignment, + const py::list& group_assignment, const py::list& replication_groups, + int sharding_type) { + c10::List tile_assignment_list = + c10::List(); + for (auto t : tile_assignment) { + tile_assignment_list.push_back( + at::IntArrayRef(t.cast>())); + } - // For data, we need to deal with the data transfers between - // host and device. - at::Tensor cpu_tensor; - if (xtensor->CurrentTensorData().has_value()) { - TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); - // When virtual device is enabled for SPMD, we defer the initial - // data transfer to the device and retain the original data on the - // host, until the sharded data transfer. - cpu_tensor = xtensor->CurrentTensorData().value(); - } else { - // A new input tensor is not expected to be sharded. But sometimes, - // the same input is called for sharding annotation over multiple steps, - // in which case we can skip if it's the same sharding; however, if it's - // the same input with a different sharding then we block & ask the user - // to clear the existing sharding first. - auto current_sharding_spec = xtensor->sharding_spec(); - if (current_sharding_spec && (current_sharding_spec->sharding.type() != - xla::OpSharding::REPLICATED)) { - XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec, - *current_sharding_spec)) - << "Existing annotation must be cleared first."; - return; - } + c10::List group_assignment_list = + c10::List(); + for (auto t : group_assignment) { + group_assignment_list.push_back( + at::IntArrayRef(t.cast>())); + } - // If the at::Tensor data is not present, we need to re-download the - // tensor from the physical device to CPU. In that case, the value - // must be present on the backend device. - XLA_CHECK((xtensor->CurrentDataHandle() && - xtensor->CurrentDataHandle()->HasValue()) || - device_data_node != nullptr) - << "Cannot shard tensor. Data does not present on any device."; - std::vector xla_tensors{xtensor}; - cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; - } - auto xla_data = CreateTensorsData( - std::vector{cpu_tensor}, - std::vector{new_sharding_spec}, - std::vector{GetVirtualDevice().toString()})[0]; - xtensor->SetXlaData(xla_data); - xtensor->SetShardingSpec(*new_sharding_spec); + c10::List replication_groups_list = + c10::List(); + for (auto t : replication_groups) { + replication_groups_list.push_back( + at::IntArrayRef(t.cast>())); + } - // Register sharded tensor data. 
- XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); - }); + xla_mark_sharding_dynamo_custom_op( + input, tile_assignment_list, group_assignment_list, + replication_groups_list, sharding_type); + }); m.def("_xla_clear_sharding", [](const at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); xtensor->ClearShardingSpec(); diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 42c1e9e66fa1..4fb304d37dde 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -6,6 +6,8 @@ #include #include "torch/csrc/lazy/core/ir_util.h" +#include "torch_xla/csrc/aten_autograd_ops.h" +#include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/helpers.h" @@ -15,7 +17,9 @@ #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/tensor.h" +#include "torch_xla/csrc/tensor_methods.h" #include "torch_xla/csrc/tensor_util.h" +#include "torch_xla/csrc/xla_graph_executor.h" #include "tsl/profiler/lib/traceme.h" #include "xla/execution_options_util.h" #include "xla/hlo/ir/hlo_module.h" @@ -743,4 +747,135 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( source_tensors, GetVirtualDevice().toString(), global_shape, sharding); } +void ShardingUtil::xla_mark_sharding(const at::Tensor& input, + xla::OpSharding sharding) { + TORCH_LAZY_COUNTER("XlaMarkSharding", 1); + XLA_CHECK(UseVirtualDevice()) + << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + auto new_sharding_spec = std::make_shared( + sharding, MakeShapeWithDeviceLayout( + xtensor->shape(), + static_cast(xtensor->GetDevice().type()))); + + // For Non DeviceData IR values, we directly attach the sharding spec + // to the xtensor. + const DeviceData* device_data_node = nullptr; + if (xtensor->CurrentIrValue()) { + device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); + if (!device_data_node) { + tensor_methods::custom_sharding_(xtensor, new_sharding_spec); + return; + } + } + + // For data, we need to deal with the data transfers between + // host and device. + at::Tensor cpu_tensor; + if (xtensor->CurrentTensorData().has_value()) { + TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); + // When virtual device is enabled for SPMD, we defer the initial + // data transfer to the device and retain the original data on the + // host, until the sharded data transfer. + cpu_tensor = xtensor->CurrentTensorData().value(); + } else { + // A new input tensor is not expected to be sharded. But sometimes, + // the same input is called for sharding annotation over multiple steps, + // in which case we can skip if it's the same sharding; however, if it's + // the same input with a different sharding then we block & ask the user + // to clear the existing sharding first. + auto current_sharding_spec = xtensor->sharding_spec(); + if (current_sharding_spec && (current_sharding_spec->sharding.type() != + xla::OpSharding::REPLICATED)) { + XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec, + *current_sharding_spec)) + << "Existing annotation must be cleared first."; + return; + } + + // If the at::Tensor data is not present, we need to re-download the + // tensor from the physical device to CPU. In that case, the value + // must be present on the backend device. 
+ XLA_CHECK((xtensor->CurrentDataHandle() && + xtensor->CurrentDataHandle()->HasValue()) || + device_data_node != nullptr) + << "Cannot shard tensor. Data does not present on any device."; + std::vector xla_tensors{xtensor}; + cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; + } + auto xla_data = CreateTensorsData( + std::vector{cpu_tensor}, + std::vector{new_sharding_spec}, + std::vector{GetVirtualDevice().toString()})[0]; + xtensor->SetXlaData(xla_data); + xtensor->SetShardingSpec(*new_sharding_spec); + + // Register sharded tensor data. + XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); +} + +void xla_mark_sharding_dynamo_custom_op( + const at::Tensor& input, c10::List tile_assignment, + c10::List group_assignment, + c10::List replication_groups, int64_t sharding_type) { + py::list tile_assignment_py = py::list(); + for (int i = 0; i < tile_assignment.size(); i++) { + py::list pylist = py::list(); + for (int64_t t : tile_assignment[i].get().toIntList()) { + pylist.append(t); + } + tile_assignment_py.append(pylist); + } + + py::list group_assignment_py = py::list(); + for (int i = 0; i < group_assignment.size(); i++) { + py::list pylist = py::list(); + for (int64_t t : group_assignment[i].get().toIntList()) { + pylist.append(t); + } + group_assignment_py.append(pylist); + } + + py::list replication_groups_py = py::list(); + for (int i = 0; i < replication_groups.size(); i++) { + py::list pylist = py::list(); + for (int64_t t : replication_groups[i].get().toIntList()) { + pylist.append(t); + } + replication_groups_py.append(pylist); + } + + xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding( + tile_assignment_py, group_assignment_py, replication_groups_py, + ShardingUtil::ShardingType(sharding_type)); + + ShardingUtil::xla_mark_sharding(input, op_sharding); +} + +// Macro for defining a function that will be run at static initialization time +// to define a library of operators in the namespace. Used to define a new set +// of custom operators that do not already exist in PyTorch. 
+TORCH_LIBRARY(xla, m) { + m.def( + "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " + "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", + torch::dispatch( + c10::DispatchKey::XLA, + TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward))); + + m.def( + "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " + "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " + "-> Tensor", + torch::dispatch( + c10::DispatchKey::XLA, + TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward))); + m.def( + "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] " + "tile_assignment, int[][] group_assignment, int[][] replication_groups, " + "int sharding_type) -> ()", + torch::dispatch(c10::DispatchKey::XLA, + TORCH_FN(torch_xla::xla_mark_sharding_dynamo_custom_op))); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 32060c7fc098..3e600be68715 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -150,8 +150,16 @@ class ShardingUtil { const std::vector& shards, const std::vector& devices, const XLATensor::ShardingSpecPtr& sharding_spec); + + static void xla_mark_sharding(const at::Tensor& input, + xla::OpSharding sharding); }; +void xla_mark_sharding_dynamo_custom_op( + const at::Tensor& input, c10::List tile_assignment, + c10::List group_assignment, + c10::List replication_groups, int64_t sharding_type); + } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_XLA_SHARDING_UTIL_H_ diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index 21d0e2e570ac..1b12513fc2ee 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -82,7 +82,8 @@ def get_axis_name_idx(self, name: str) -> int: @functools.lru_cache(maxsize=None) def get_op_sharding(self, - partition_spec: Tuple) -> torch_xla._XLAC.OpSharding: + partition_spec: Tuple, + flatten_opsharding=False) -> torch_xla._XLAC.OpSharding: """ Return the OpSharding for the given partition spec. This is an expensive operation as the mesh grows, so the value is cached for reuse. 
@@ -104,9 +105,15 @@ def get_op_sharding(self,
     replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
     group_assignment, replication_groups = _get_group_assignment(
         sharding_type, tile_assignment, len(partition_spec), replicate_dims)
-    return torch_xla._XLAC.OpSharding(tile_assignment.tolist(),
-                                      group_assignment, replication_groups,
-                                      int(sharding_type))
+
+    # If flatten_opsharding is True, return the OpSharding fields as a flat tuple
+    if flatten_opsharding:
+      return (tile_assignment.tolist(), group_assignment, replication_groups,
+              int(sharding_type))
+    else:
+      return torch_xla._XLAC.OpSharding(tile_assignment.tolist(),
+                                        group_assignment, replication_groups,
+                                        int(sharding_type))
 
 
 # HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4
@@ -449,9 +456,10 @@ def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple):
 
 
 @xr.requires_pjrt
-def mark_sharding(
-    t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
-    partition_spec: Tuple[Union[Tuple, int, str, None]]) -> XLAShardedTensor:
+def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor],
+                  mesh: Mesh,
+                  partition_spec: Tuple[Union[Tuple, int, str, None]],
+                  use_dynamo_custom_op: bool = False) -> XLAShardedTensor:
   """
     Annotates the tensor provided with XLA partition spec. Internally, it
     annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.
@@ -471,6 +479,9 @@ def mark_sharding(
         >> mesh_shape = (4, 2)
         >> partition_spec = (0, None)
 
+    use_dynamo_custom_op (bool): if set to True, call the dynamo custom op variant of mark_sharding
+      so that the sharding annotation is recognizable and traceable by dynamo.
+
     Examples
     —------------------------------
     mesh_shape = (4, 2)
@@ -497,13 +508,29 @@ def mark_sharding(
   assert len(t.shape) == len(partition_spec), \
     f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
 
-  op_sharding = mesh.get_op_sharding(partition_spec)
+  if use_dynamo_custom_op:
+    tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding(
+        partition_spec, flatten_opsharding=True)
 
-  if isinstance(t, XLAShardedTensor):
-    torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding)
-    return t
-  torch_xla._XLAC._xla_mark_sharding(t, op_sharding)
-  return XLAShardedTensor(t)
+    if isinstance(t, XLAShardedTensor):
+      torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(
+          t.global_tensor, tile_assignment, group_assignment,
+          replication_groups, sharding_type)
+      return t
+    else:
+      torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(
+          t, tile_assignment, group_assignment, replication_groups,
+          sharding_type)
+      return XLAShardedTensor(t)
+  else:
+    op_sharding = mesh.get_op_sharding(partition_spec)
+
+    if isinstance(t, XLAShardedTensor):
+      torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding)
+      return t
+    else:
+      torch_xla._XLAC._xla_mark_sharding(t, op_sharding)
+      return XLAShardedTensor(t)
 
 
 def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
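
Usage note (not part of the patch): a minimal sketch of the new `use_dynamo_custom_op=True` path, pieced together from the tests added above. The `MyLinear` module, its layer sizes, and the mesh/device setup are illustrative assumptions; the sketch assumes an SPMD-capable multi-device runtime and the APIs referenced elsewhere in this patch (`torch_xla.runtime.use_spmd()`, `xs.Mesh`, the `openxla` dynamo backend).

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

xr.use_spmd()  # required: mark_sharding checks that the virtual SPMD device is enabled

device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))


class MyLinear(nn.Module):
  """Illustrative module; mirrors SimpleLinear from the test above."""

  def __init__(self, mesh):
    super().__init__()
    self.fc = nn.Linear(128, 64)
    self.mesh = mesh

  def forward(self, x):
    # The custom-op variant lets dynamo recognize and trace the sharding
    # annotation inside the compiled region instead of graph-breaking on it.
    xs.mark_sharding(self.fc.weight, self.mesh, (1, 0), use_dynamo_custom_op=True)
    return self.fc(x)


model = MyLinear(mesh).to(device)
model.eval()
x = torch.randn(8, 128, device=device)

compiled = torch.compile(model, backend="openxla")
out = compiled(x)  # sharding annotation is captured via the custom op registered in this patch
xm.mark_step()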