From 7fdb15fc3f9bf1f69c86ea51d888a1104b97b2f6 Mon Sep 17 00:00:00 2001
From: jeffhataws
Date: Mon, 11 Dec 2023 15:39:13 -0800
Subject: [PATCH] Add out-of-place reduce-scatter coalescing (#6058)

---
 test/test_mp_reduce_scatter.py          | 28 ++++++++
 torch_xla/core/xla_model.py             | 21 ++++--
 torch_xla/csrc/init_python_bindings.cpp | 94 ++++++++++++++++---------
 torch_xla/csrc/tensor_methods.cpp       | 27 +++++--
 torch_xla/csrc/tensor_methods.h         | 10 ++-
 5 files changed, 135 insertions(+), 45 deletions(-)

diff --git a/test/test_mp_reduce_scatter.py b/test/test_mp_reduce_scatter.py
index 26979720046d..1ef61d3aa794 100644
--- a/test/test_mp_reduce_scatter.py
+++ b/test/test_mp_reduce_scatter.py
@@ -55,6 +55,34 @@ def _mp_fn(index):
 
     xm.rendezvous('test_reduce_scatter_list_input')
 
+    # Testing reduce-scatter with list input and output
+    output_list = [
+        torch.rand((32, shard_size * world_size, 32))
+        for _ in range(input_list_size)
+    ]
+    xoutput_list = [output.to(device) for output in output_list]
+
+    # TODO: fix the broken case with pin_layout=True
+    res_list = xm.reduce_scatter(
+        xm.REDUCE_SUM,
+        xrand_list,
+        scale,
+        scatter_dim,
+        world_size,
+        output=xoutput_list,
+        pin_layout=False)
+
+    assert (xoutput_list == res_list)
+    for i, res in enumerate(xoutput_list):
+      expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
+      xm.mark_step()
+
+      slice_idx = torch.tensor(
+          list(range(index * shard_size, (index + 1) * shard_size)))
+      expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
+      assert res.cpu().allclose(expected)
+
+    xm.rendezvous('test_reduce_scatter_list_input_output')
   else:
     print(
         'Default device {} is not a TPU device'.format(device), file=sys.stderr)
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index d52d22d6e412..e5a1ea08b28c 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -785,13 +785,24 @@ def reduce_scatter(reduce_type,
   elif isinstance(input, list) and all(
       isinstance(v, torch.Tensor) for v in input):
     if output != None:
-      raise RuntimeError(
-          "For xm.reduce_scatter with list of tensors input, output != None is not yet supported."
-      )
+      if not isinstance(output, list) or any(
+          not isinstance(v, torch.Tensor) for v in output):
+        raise TypeError(
+            f"`output` needs to be a list of Tensors, but given {type(output)}."
+        )
+      if len(output) != len(input):
+        raise ValueError("`output` length doesn't match `input` length: "
+                         f"{len(output)} vs {len(input)}.")
+      # Call the out-of-place version of the reduce_scatter
+      new_token = torch_xla._XLAC._xla_reduce_scatter_coalesced_out(
+          reduce_type, output, input, token, scale, scatter_dim, shard_count,
+          groups or [], pin_layout)
+      torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
+      return output
     result = torch_xla._XLAC._xla_reduce_scatter_coalesced(
-        reduce_type, output or [], input, token, scale, scatter_dim,
-        shard_count, groups or [], pin_layout)
+        reduce_type, input, token, scale, scatter_dim, shard_count, groups or
+        [], pin_layout)
     torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
     return result[:-1]
   else:
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index c3ab1fb78570..6869871f3aa1 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -236,22 +236,32 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> ReduceScatter(
       std::make_shared<torch::lazy::Value>(new_token));
 }
 
+std::shared_ptr<torch::lazy::Value> ReduceScatterOut(
+    const std::string& reduce_type, at::Tensor& output, const at::Tensor& input,
+    const std::shared_ptr<torch::lazy::Value>& token, double scale,
+    int64_t scatter_dim, int64_t shard_count,
+    const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
+  XLATensorPtr out = bridge::GetXlaTensor(output);
+  torch::lazy::Value new_token;
+  new_token = tensor_methods::reduce_scatter_out(
+      out, bridge::GetXlaTensor(input), *token, GetReduceType(reduce_type),
+      scale, scatter_dim, shard_count, replica_groups, pin_layout);
+  return std::make_shared<torch::lazy::Value>(new_token);
+}
+
 std::pair<std::vector<at::Tensor>, std::shared_ptr<torch::lazy::Value>>
 ReduceScatterCoalesced(const std::string& reduce_type,
-                       const std::vector<at::Tensor>& outputs,
                        const std::vector<at::Tensor>& inputs,
                        const std::shared_ptr<torch::lazy::Value>& token,
                        double scale, int64_t scatter_dim, int64_t shard_count,
                        const std::vector<std::vector<int64_t>>& replica_groups,
                        bool pin_layout) {
-  std::vector<XLATensorPtr> xtensors_out =
-      GetXlaTensors(outputs, /*want_all=*/true);
   std::vector<XLATensorPtr> xtensors = GetXlaTensors(inputs, /*want_all=*/true);
   std::vector<XLATensorPtr> result;
   torch::lazy::Value new_token;
   std::tie(result, new_token) = tensor_methods::reduce_scatter_coalesced(
-      xtensors_out, xtensors, *token, GetReduceType(reduce_type), scale,
-      scatter_dim, shard_count, replica_groups, pin_layout);
+      xtensors, *token, GetReduceType(reduce_type), scale, scatter_dim,
+      shard_count, replica_groups, pin_layout);
   std::vector<at::Tensor> aten_result;
   for (auto& xt : result) {
     aten_result.emplace_back(bridge::AtenFromXlaTensor(std::move(xt)));
@@ -259,16 +269,19 @@ ReduceScatterCoalesced(const std::string& reduce_type,
   return {aten_result, std::make_shared<torch::lazy::Value>(new_token)};
 }
 
-std::shared_ptr<torch::lazy::Value> ReduceScatterOut(
-    const std::string& reduce_type, at::Tensor& output, const at::Tensor& input,
+std::shared_ptr<torch::lazy::Value> ReduceScatterCoalescedOut(
+    const std::string& reduce_type, std::vector<at::Tensor>& outputs,
+    const std::vector<at::Tensor>& inputs,
     const std::shared_ptr<torch::lazy::Value>& token, double scale,
     int64_t scatter_dim, int64_t shard_count,
    const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
-  XLATensorPtr out = bridge::GetXlaTensor(output);
+  std::vector<XLATensorPtr> xtensors_out =
+      GetXlaTensors(outputs, /*want_all=*/true);
+  std::vector<XLATensorPtr> xtensors = GetXlaTensors(inputs, /*want_all=*/true);
   torch::lazy::Value new_token;
-  new_token = tensor_methods::reduce_scatter_out(
-      out, bridge::GetXlaTensor(input), *token, GetReduceType(reduce_type),
-      scale, scatter_dim, shard_count, replica_groups, pin_layout);
+  new_token = tensor_methods::reduce_scatter_coalesced_out(
+      xtensors_out, xtensors, *token, GetReduceType(reduce_type), scale,
+      scatter_dim, shard_count, replica_groups, pin_layout);
   return std::make_shared<torch::lazy::Value>(new_token);
 }
 
@@ -1346,45 +1359,62 @@ void InitXlaModuleBindings(py::module m) {
           result_tuple[1] = new_token;
           return result_tuple;
         });
-  m.def("_xla_reduce_scatter_coalesced",
-        [](const std::string& reduce_type, std::vector<at::Tensor>& outputs,
-           const std::vector<at::Tensor>& inputs,
+  m.def("_xla_reduce_scatter_out",
+        [](const std::string& reduce_type, at::Tensor& output,
+           const at::Tensor& input,
            const std::shared_ptr<torch::lazy::Value>& token, double scale,
           int64_t scatter_dim, int64_t shard_count, const py::list& groups,
           bool pin_layout) {
          std::vector<std::vector<int64_t>> replica_groups =
              CreateReduceGroups(groups);
-         std::vector<at::Tensor> result;
+         at::Tensor result;
          std::shared_ptr<torch::lazy::Value> new_token;
          {
            NoGilSection nogil;
-           std::tie(result, new_token) = ReduceScatterCoalesced(
-               reduce_type, outputs, inputs, token, scale, scatter_dim,
-               shard_count, replica_groups, pin_layout);
-         }
-         auto result_list = py::list(result.size() + 1);
-         for (int i = 0; i < result.size(); ++i) {
-           result_list[i] = torch::autograd::make_variable(
-               result[i], /*requires_grad=*/result[i].requires_grad());
+           new_token = ReduceScatterOut(reduce_type, output, input, token,
+                                        scale, scatter_dim, shard_count,
+                                        replica_groups, pin_layout);
          }
-         result_list[result.size()] = new_token;
-         return result_list;
+         return new_token;
        });
-  m.def("_xla_reduce_scatter_out",
-        [](const std::string& reduce_type, at::Tensor& output,
-           const at::Tensor& input,
+  m.def(
+      "_xla_reduce_scatter_coalesced",
+      [](const std::string& reduce_type, const std::vector<at::Tensor>& inputs,
+         const std::shared_ptr<torch::lazy::Value>& token, double scale,
+         int64_t scatter_dim, int64_t shard_count, const py::list& groups,
+         bool pin_layout) {
+        std::vector<std::vector<int64_t>> replica_groups =
+            CreateReduceGroups(groups);
+        std::vector<at::Tensor> result;
+        std::shared_ptr<torch::lazy::Value> new_token;
+        {
+          NoGilSection nogil;
+          std::tie(result, new_token) = ReduceScatterCoalesced(
+              reduce_type, inputs, token, scale, scatter_dim, shard_count,
+              replica_groups, pin_layout);
+        }
+        auto result_list = py::list(result.size() + 1);
+        for (int i = 0; i < result.size(); ++i) {
+          result_list[i] = torch::autograd::make_variable(
+              result[i], /*requires_grad=*/result[i].requires_grad());
+        }
+        result_list[result.size()] = new_token;
+        return result_list;
+      });
+  m.def("_xla_reduce_scatter_coalesced_out",
+        [](const std::string& reduce_type, std::vector<at::Tensor>& outputs,
+           const std::vector<at::Tensor>& inputs,
          const std::shared_ptr<torch::lazy::Value>& token, double scale,
          int64_t scatter_dim, int64_t shard_count, const py::list& groups,
          bool pin_layout) {
         std::vector<std::vector<int64_t>> replica_groups =
             CreateReduceGroups(groups);
-        at::Tensor result;
         std::shared_ptr<torch::lazy::Value> new_token;
         {
           NoGilSection nogil;
-          new_token = ReduceScatterOut(reduce_type, output, input, token,
-                                       scale, scatter_dim, shard_count,
-                                       replica_groups, pin_layout);
+          new_token = ReduceScatterCoalescedOut(
+              reduce_type, outputs, inputs, token, scale, scatter_dim,
+              shard_count, replica_groups, pin_layout);
         }
         return new_token;
       });
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 818b31e474ae..a91c67ea545e 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -393,14 +393,12 @@ torch::lazy::Value reduce_scatter_out(XLATensorPtr& output,
 }
 
 std::pair<std::vector<XLATensorPtr>, torch::lazy::Value>
-reduce_scatter_coalesced(const std::vector<XLATensorPtr>& outputs,
-                         const std::vector<XLATensorPtr>& inputs,
+reduce_scatter_coalesced(const std::vector<XLATensorPtr>& inputs,
                          const torch::lazy::Value& token,
                          AllReduceType reduce_type, double scale,
                          int64_t scatter_dim, int64_t shard_count,
                          std::vector<std::vector<int64_t>> groups,
                          bool pin_layout) {
-  XLA_CHECK(outputs.empty() || outputs.size() == inputs.size());
   std::vector<torch::lazy::Value> input_values;
   input_values.reserve(inputs.size());
   for (auto& input : inputs) {
@@ -412,13 +410,30 @@ reduce_scatter_coalesced(const std::vector<XLATensorPtr>& outputs,
   std::vector<XLATensorPtr> result;
   for (size_t i = 0; i < inputs.size(); ++i) {
     result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i)));
-    if (!outputs.empty()) {
-      outputs[i]->SetIrValue(torch::lazy::Value(node, i));
-    }
   }
   return {result, torch::lazy::Value(node, inputs.size())};
 }
 
+torch::lazy::Value reduce_scatter_coalesced_out(
+    const std::vector<XLATensorPtr>& outputs,
+    const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
+    AllReduceType reduce_type, double scale, int64_t scatter_dim,
+    int64_t shard_count, std::vector<std::vector<int64_t>> groups,
+    bool pin_layout) {
+  std::vector<torch::lazy::Value> input_values;
+  input_values.reserve(inputs.size());
+  for (auto& input : inputs) {
+    input_values.push_back(input->GetIrValue());
+  }
+  torch::lazy::NodePtr node = torch::lazy::MakeNode<ReduceScatterCoalesced>(
+      reduce_type, input_values, token, scale, scatter_dim, shard_count,
+      std::move(groups), pin_layout);
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    outputs[i]->SetIrValue(torch::lazy::Value(node, i));
+  }
+  return torch::lazy::Value(node, inputs.size());
+}
+
 std::pair<XLATensorPtr, torch::lazy::Value> all_to_all(
     const XLATensorPtr& input, const torch::lazy::Value& token,
     int64_t split_dimension, int64_t concat_dimension, int64_t split_count,
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 9d8226e2154d..8abae7336442 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -34,14 +34,20 @@ torch::lazy::Value reduce_scatter_out(XLATensorPtr& output,
                                       bool pin_layout);
 
 std::pair<std::vector<XLATensorPtr>, torch::lazy::Value>
-reduce_scatter_coalesced(const std::vector<XLATensorPtr>& outputs,
-                         const std::vector<XLATensorPtr>& inputs,
+reduce_scatter_coalesced(const std::vector<XLATensorPtr>& inputs,
                          const torch::lazy::Value& token,
                          AllReduceType reduce_type, double scale,
                          int64_t scatter_dim, int64_t shard_count,
                          std::vector<std::vector<int64_t>> groups,
                          bool pin_layout);
 
+torch::lazy::Value reduce_scatter_coalesced_out(
+    const std::vector<XLATensorPtr>& outputs,
+    const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
+    AllReduceType reduce_type, double scale, int64_t scatter_dim,
+    int64_t shard_count, std::vector<std::vector<int64_t>> groups,
+    bool pin_layout);
+
 std::pair<XLATensorPtr, torch::lazy::Value> all_to_all(
     const XLATensorPtr& input, const torch::lazy::Value& token,
     int64_t split_dimension, int64_t concat_dimension, int64_t split_count,
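
For reference, a minimal sketch of how the new out-of-place path is driven from Python: passing a pre-allocated `output` list to `xm.reduce_scatter` routes the call through `_xla_reduce_scatter_coalesced_out`, which overwrites the output tensors' IR values with the scattered shards and returns the same list. Shapes and constants mirror `test_mp_reduce_scatter.py`; the `xmp.spawn` launch and the use of `torch.zeros_like` for the outputs are illustrative assumptions, not part of this patch.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  scale = 1.0 / world_size
  scatter_dim = 1
  shard_size = 2

  # Coalesced inputs: each tensor's size along scatter_dim must be
  # divisible by world_size.
  inputs = [
      torch.rand((32, shard_size * world_size, 32)).to(device)
      for _ in range(5)
  ]
  # A pre-allocated output list selects the new out-of-place code path;
  # the outputs' contents are replaced by the reduce-scatter results.
  outputs = [torch.zeros_like(t) for t in inputs]

  # Per the TODO in the test above, pin_layout=True is still broken here.
  results = xm.reduce_scatter(
      xm.REDUCE_SUM,
      inputs,
      scale,
      scatter_dim,
      world_size,
      output=outputs,
      pin_layout=False)
  xm.mark_step()

  assert results is outputs  # the wrapper returns the list it wrote into


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
```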