From 2e861ffca187f9eeefe33c53c32cabc73827e710 Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Wed, 18 Oct 2023 05:35:50 +0000 Subject: [PATCH] Add all-gather and reduce-scatter coalescence support for FSDP. Also allow using reduce-scatter's scale param in FSDP. (revived https://github.com/pytorch/xla/pull/4145) --- test/test_torch_distributed_xla_backend.py | 44 +++++ torch_xla/core/xla_model.py | 85 ++++++++-- torch_xla/csrc/cross_replica_reduces.cpp | 151 +++++++++++------- torch_xla/csrc/cross_replica_reduces.h | 31 ++-- torch_xla/csrc/init_python_bindings.cpp | 87 ++++++++++ torch_xla/csrc/ops/all_gather.cpp | 42 +++-- torch_xla/csrc/ops/all_gather.h | 6 +- torch_xla/csrc/ops/all_reduce.cpp | 9 -- torch_xla/csrc/ops/reduce_scatter.cpp | 49 +++--- torch_xla/csrc/ops/reduce_scatter.h | 3 +- torch_xla/csrc/tensor_methods.cpp | 51 ++++-- torch_xla/csrc/tensor_methods.h | 14 ++ torch_xla/distributed/fsdp/utils.py | 65 +++++++- .../fsdp/xla_fully_sharded_data_parallel.py | 149 ++++++++++------- torch_xla/distributed/xla_backend.py | 38 ++++- 15 files changed, 613 insertions(+), 211 deletions(-) diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py index 478ab4dd4986..7c344775e318 100644 --- a/test/test_torch_distributed_xla_backend.py +++ b/test/test_torch_distributed_xla_backend.py @@ -87,6 +87,26 @@ def test_allgather(self): hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors) hlo_matches(hlo, all_gather_pattern) + def test_allgather_coalesced(self): + device = xm.xla_device() + tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() + tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank() + pg_xla = get_process_group_xla(rank=3, size=8) + output_tensors = [torch.zeros_like(tensor)] * 8 + output_tensors2 = [torch.zeros_like(tensor2)] * 8 + # because we set os.environ[xenv.WORLD_SIZE] = '1', here the outputs' + # shapes will be same as the inputs' shapes. + all_gather_pattern = ( + r'%all-gather\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}, s64\[]\) ' + r'all-gather\(s64\[2]\{0} %.+\.\d+, s64\[5]\{0} %.+\.\d+, ' + r's64\[] %.+\.\d+\)') + pg_xla.allgather_coalesced([output_tensors, output_tensors2], + [tensor, tensor2]) + hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors) + hlo_matches(hlo, all_gather_pattern) + # purge all computations attached the device. + xm.mark_step() + def test_broadcast(self): device = xm.xla_device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() @@ -106,6 +126,30 @@ def test_reduce_scatter(self): hlo = torch_xla._XLAC._get_xla_tensors_hlo([output]) hlo_matches(hlo, reduce_scatter_pattern) + def test_reduce_scatter_coalesced(self): + device = xm.xla_device() + tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() + tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank() + input_tensors_list = [[tensor, tensor], [tensor2, tensor2]] + output_list = [torch.zeros_like(tensor), torch.zeros_like(tensor2)] + pg_xla = get_process_group_xla(rank=0, size=len(input_tensors_list[0])) + opts = dist.ReduceScatterOptions() + opts.reduceOp = dist.ReduceOp.SUM + reduce_scatter_pattern = ( + r'%reduce\-scatter\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}, s64\[]\) ' + r'reduce\-scatter\(s64\[4]\{0} %.+\.\d+, s64\[10]\{0} %.+\.\d+, ' + r's64\[] %.+\.\d+\)') + with self.assertRaises(RuntimeError) as cm: + pg_xla.reduce_scatter_coalesced(output_list, input_tensors_list, opts) + hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_list) + hlo_matches(hlo, reduce_scatter_pattern) + # purge all computations attached the device. + xm.mark_step() + assert 'UNIMPLEMENTED: ReduceScatter is not implemented on CPU.' in str( + cm.exception), str(cm.exception) + # reset token to clean up the mess after the RuntimeError. + xm.set_replication(device, []) + @patch_world(0, 6) def test_send(self): device = xm.xla_device() diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index ff6d015b2b22..42d4d8858b53 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -485,6 +485,8 @@ def _all_gather_using_all_reduce(value, dim=0, groups=None, pin_layout=True): Args: value (torch.Tensor): The input tensor. + value (torch.Tensor or a list of torch.Tensor): The input. If it's a list, then + it will also be the output. dim (int): The gather dimension. Default: 0 groups (list, optional): A list of list, representing the replica groups for @@ -562,6 +564,7 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): shard_count = xrt_world_size() token, devctx = _get_all_reduce_token() + if output != None: # Call the out of place version of the all_gather new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim, @@ -574,6 +577,31 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): [], pin_layout) return result + if isinstance(value, torch.Tensor): + if output != None: + # Call the out of place version of the all_gather + new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim, + shard_count, groups or [], + pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) + return output + + result = torch_xla._XLAC._xla_all_gather(value, token, dim, shard_count, + groups or [], pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1]) + return result[0] + + # Now the input should be a list of Tensors. + if not isinstance(value, list) or any( + not isinstance(v, torch.Tensor) for v in value): + raise TypeError("`value` needs to be a Tensor or a list of Tensors, but " + f"given {type(value)}.") + result = torch_xla._XLAC._xla_all_gather_coalesced(value, token, dim, + shard_count, groups or [], + pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) + return result[:-1] + def all_to_all(value, split_dimension, @@ -718,16 +746,19 @@ def reduce_scatter(reduce_type, reduce_type (string): One of ``xm.REDUCE_SUM``, ``xm.REDUCE_MUL``, ``xm.REDUCE_AND``, ``xm.REDUCE_OR``, ``xm.REDUCE_MIN`` and ``xm.REDUCE_MAX``. - input: A single `torch.Tensor` all reduce + scatter op to. + input: (torch.Tensor or a list of torch.Tensor): The input. If it's a list, then + it will also be the output. scale (float): A default scaling value to be applied after the reduce. scatter_dim (int): Dimension number to which apply scatter operation. shard_count (int): The number of ways to split up the scatter_dim in. groups (list): A list of list, representing the replica groups for - the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]` + the `reduce_scatter()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]` defines two groups, one with the `[0, 1, 2, 3]` replicas and one with the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with all the replicas in it. output: Optional output tensor + output: Optional output tensor if `input` is a torch.Tensor or a list of + torch.Tensor if `input` is a list of torch.Tensor. pin_layout (bool, optional): whether to pin the layout for this communication op. Layout pining can prevent potential data corruption when each process that participate in the communication has slightly different program, but it might @@ -740,21 +771,43 @@ def reduce_scatter(reduce_type, the same as the input. """ token, devctx = _get_all_reduce_token() - if output != None: - # Call the out of place version of the reduce_scatter - new_token = torch_xla._XLAC._xla_reduce_scatter_out(reduce_type, output, - input, token, scale, - scatter_dim, - shard_count, groups or - [], pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) - return output - result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token, scale, - scatter_dim, shard_count, - groups or [], pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1]) - return result[0] + if isinstance(input, torch.Tensor): + if output != None: + # Call the out of place version of the reduce_scatter + new_token = torch_xla._XLAC._xla_reduce_scatter_out( + reduce_type, output, input, token, scale, scatter_dim, shard_count, + groups or [], pin_layout) + devctx.all_reduce_token = new_token + return output + + result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token, + scale, scatter_dim, + shard_count, groups or [], + pin_layout) + devctx.all_reduce_token = result[1] + return result[0] + + # Now the input should be a list of Tensors. + if not isinstance(input, list) or any( + not isinstance(v, torch.Tensor) for v in input): + raise TypeError("`input` needs to be a Tensor or a list of Tensors, but " + f"given {type(input)}.") + if output != None: + if not isinstance(output, list) or any( + not isinstance(v, torch.Tensor) for v in output): + raise TypeError( + f"`output` needs to be a list of Tensors, but given {type(output)}." + ) + if len(output) != len(input): + raise ValueError("`output` length doesn't match `input` length: " + f"{len(output)} vs {len(input)}.") + + result = torch_xla._XLAC._xla_reduce_scatter_coalesced( + reduce_type, output or [], input, token, scale, scatter_dim, shard_count, + groups or [], pin_layout) + devctx.all_reduce_token = result[-1] + return result[:-1] def add_step_closure(closure, args=(), run_async=False): diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp index 200a750f856c..5fae1168ad23 100644 --- a/torch_xla/csrc/cross_replica_reduces.cpp +++ b/torch_xla/csrc/cross_replica_reduces.cpp @@ -210,31 +210,46 @@ AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token, return {reduce_result, token_handler.GetNewToken(reduce_result)}; } -AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim, - int64_t shard_count, - const std::vector>& groups, - bool pin_layout) { - std::vector reduce_groups = CreateReduceGroups(groups); - const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); - TokenHandler token_handler(token); - xla::XlaOp all_gather_result; - if (pin_layout) { - torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice(); - xla::Shape reduce_shape = MakeArrayShapeFromDimensions( - input_shape.dimensions(), input_shape.dynamic_dimensions(), - input_shape.element_type(), - static_cast(xla_device.type())); - all_gather_result = - xla::AllGather(token_handler.GetInput(input, &input_shape), dim, - shard_count, reduce_groups, /*channel_id=*/absl::nullopt, - /*layout=*/reduce_shape.layout()); - } else { - all_gather_result = - xla::AllGather(token_handler.GetInput(input, &input_shape), dim, - shard_count, reduce_groups); +std::vector BuildAllGather( + absl::Span inputs, xla::XlaOp token, int64_t dim, + int64_t shard_count, const std::vector>& groups, + bool pin_layout) { + std::vector cc_groups = CreateReduceGroups(groups); + // TODO: We use pseudo-tokens ATM, which are real values. This need to be + // switched to use the real XLA Token once support has been added to XLA + // AllGather(). + xla::XlaOp chained_token = token; + ReduceContext cc_ctx = GetReduceContext(inputs); + std::vector result(inputs.size()); + for (auto& type_ctx : cc_ctx.contexts) { + xla::XlaOp token_op = MaybeConvertTo(chained_token, type_ctx.first); + type_ctx.second.ops.push_back(token_op); + type_ctx.second.operand_shapes.push_back( + ShapeHelper::ShapeOfXlaOp(token_op)); + + xla::XlaOp all_gather_result; + if (pin_layout) { + all_gather_result = xla::AllGather( + xla::Tuple(inputs[0].builder(), type_ctx.second.ops), dim, + shard_count, cc_groups, /*channel_id=*/absl::nullopt, + /*layout=*/ + MakeReduceShape(type_ctx.second.operand_shapes).layout()); + } else { + all_gather_result = + xla::AllGather(xla::Tuple(inputs[0].builder(), type_ctx.second.ops), + dim, shard_count, cc_groups); + } + for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) { + size_t op_idx = type_ctx.second.indices[i]; + result[op_idx] = xla::GetTupleElement(all_gather_result, i); + } + chained_token = + xla::GetTupleElement(all_gather_result, type_ctx.second.indices.size()); } - return {all_gather_result, token_handler.GetNewToken(all_gather_result)}; -} + result.push_back( + MaybeConvertTo(chained_token, XlaHelpers::TypeOfXlaOp(token))); + return result; + } CollectivePermuteResult BuildCollectivePermute( xla::XlaOp input, xla::XlaOp token, @@ -274,39 +289,63 @@ RecvResult BuildRecvWithToken(xla::XlaOp token, const xla::Shape& recv_shape, return {result, new_token}; } -ReduceScatterResult BuildReduceScatter( - AllReduceType reduce_type, xla::XlaOp input, xla::XlaOp token, double scale, - int64_t scatter_dim, int64_t shard_count, - const std::vector>& groups, bool pin_layout) { - std::vector reduce_groups = CreateReduceGroups(groups); - TokenHandler token_handler(token); - const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); - xla::XlaOp reduce_result; - if (pin_layout) { - torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice(); - xla::Shape reduce_shape = MakeArrayShapeFromDimensions( - input_shape.dimensions(), input_shape.dynamic_dimensions(), - input_shape.element_type(), - static_cast(xla_device.type())); - reduce_result = xla::ReduceScatter( - token_handler.GetInput(input, &input_shape), - GetReduceComutation(reduce_type, input_shape.element_type()), - scatter_dim, shard_count, reduce_groups, /*channel_id=*/absl::nullopt, - /*layout=*/reduce_shape.layout()); - } else { - reduce_result = xla::ReduceScatter( - token_handler.GetInput(input, &input_shape), - GetReduceComutation(reduce_type, input_shape.element_type()), - scatter_dim, shard_count, reduce_groups); - } - - if (scale != 1.0) { - xla::XlaOp scaling_value = XlaHelpers::ScalarValue( - scale, input_shape.element_type(), input.builder()); - reduce_result = reduce_result * scaling_value; - } +std::vector BuildReduceScatter( + AllReduceType reduce_type, absl::Span inputs, + xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count, + const std::vector>& groups, bool pin_layout) { + std::vector cc_groups = CreateReduceGroups(groups); + // TODO: We use pseudo-tokens ATM, which are real values. This need to be + // switched to use the real XLA Token once support has been added to XLA + // ReduceScatter(). + xla::XlaOp chained_token = token; + ReduceContext cc_ctx = GetReduceContext(inputs); + std::vector result(inputs.size()); + for (auto& type_ctx : cc_ctx.contexts) { + xla::XlaOp token_op = MaybeConvertTo(chained_token, type_ctx.first); + type_ctx.second.ops.push_back(token_op); + type_ctx.second.operand_shapes.push_back( + ShapeHelper::ShapeOfXlaOp(token_op)); + xla::XlaOp reduce_result; + if (pin_layout) { + reduce_result = xla::ReduceScatter( + xla::Tuple(inputs[0].builder(), type_ctx.second.ops), + GetReduceComutation(reduce_type, type_ctx.first), scatter_dim, + shard_count, cc_groups, /*channel_id=*/absl::nullopt, + /*layout=*/ + MakeReduceShape(type_ctx.second.operand_shapes).layout()); + } else { + reduce_result = xla::ReduceScatter( + xla::Tuple(inputs[0].builder(), type_ctx.second.ops), + GetReduceComutation(reduce_type, type_ctx.first), scatter_dim, + shard_count, cc_groups); + } + for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) { + size_t op_idx = type_ctx.second.indices[i]; + xla::XlaOp gte = xla::GetTupleElement(reduce_result, i); + if (scale != 1.0) { + xla::XlaOp scaling_value = XlaHelpers::ScalarValue( + scale, type_ctx.second.operand_shapes[i].element_type(), + gte.builder()); + gte = gte * scaling_value; + } + result[op_idx] = gte; + } + chained_token = + xla::GetTupleElement(reduce_result, type_ctx.second.indices.size()); + } + result.push_back( + MaybeConvertTo(chained_token, XlaHelpers::TypeOfXlaOp(token))); + return result; +} - return {reduce_result, token_handler.GetNewToken(reduce_result)}; +// moved from torch_xla/csrc/ops/all_reduce.cpp +std::vector GetOperandList( + c10::ArrayRef operands, + const torch::lazy::Value& token) { + std::vector operand_list(operands.begin(), + operands.end()); + operand_list.push_back(token); + return operand_list; } const torch::lazy::Value& GetAllReduceToken( diff --git a/torch_xla/csrc/cross_replica_reduces.h b/torch_xla/csrc/cross_replica_reduces.h index c35560c6b390..4fc7b9a45622 100644 --- a/torch_xla/csrc/cross_replica_reduces.h +++ b/torch_xla/csrc/cross_replica_reduces.h @@ -6,6 +6,7 @@ #include "absl/types/span.h" #include "torch/csrc/lazy/core/ir.h" #include "torch_xla/csrc/device.h" +#include "torch_xla/csrc/ir.h" #include "xla/client/xla_builder.h" namespace torch_xla { @@ -24,11 +25,6 @@ struct AllToAllResult { xla::XlaOp token; }; -struct AllGatherResult { - xla::XlaOp result; - xla::XlaOp token; -}; - struct CollectivePermuteResult { xla::XlaOp result; xla::XlaOp token; @@ -44,11 +40,6 @@ struct RecvResult { xla::XlaOp token; }; -struct ReduceScatterResult { - xla::XlaOp result; - xla::XlaOp token; -}; - std::vector BuildAllReduce( AllReduceType reduce_type, absl::Span operands, xla::XlaOp token, double scale, @@ -60,10 +51,10 @@ AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token, const std::vector>& groups, bool pin_layout); -AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim, - int64_t shard_count, - const std::vector>& groups, - bool pin_layout); +std::vector BuildAllGather( + absl::Span, xla::XlaOp token, int64_t dim, + int64_t shard_count, const std::vector>& groups, + bool pin_layout); CollectivePermuteResult BuildCollectivePermute( xla::XlaOp input, xla::XlaOp token, @@ -75,11 +66,15 @@ SendResult BuildSendWithToken(xla::XlaOp input, xla::XlaOp token, RecvResult BuildRecvWithToken(xla::XlaOp token, const xla::Shape& recv_shape, int64_t channel_id); -ReduceScatterResult BuildReduceScatter( - AllReduceType reduce_type, xla::XlaOp input, xla::XlaOp token, double scale, - int64_t scatter_dim, int64_t shard_count, +std::vector BuildReduceScatter( + AllReduceType reduce_type, absl::Span inputs, + xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count, const std::vector>& groups, bool pin_layout); +std::vector GetOperandList( + c10::ArrayRef operands, + const torch::lazy::Value& token); + const torch::lazy::Value& GetAllReduceToken( const torch::lazy::BackendDevice& device); void SetAllReduceToken(const torch::lazy::BackendDevice& device, @@ -89,4 +84,4 @@ AllReduceType GetReduceType(c10::string_view reduce_type); } // namespace torch_xla -#endif // XLA_TORCH_XLA_CSRC_CROSS_REPLICA_REDUCES_H_ \ No newline at end of file +#endif // XLA_TORCH_XLA_CSRC_CROSS_REPLICA_REDUCES_H_ diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 421066ba72cd..90381add12da 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -213,6 +213,29 @@ std::pair> ReduceScatter( std::make_shared(new_token)); } +std::pair, std::shared_ptr> +ReduceScatterCoalesced(const std::string& reduce_type, + const std::vector& outputs, + const std::vector& inputs, + const std::shared_ptr& token, + double scale, int64_t scatter_dim, int64_t shard_count, + const std::vector>& replica_groups, + bool pin_layout) { + std::vector xtensors_out = + GetXlaTensors(outputs, /*want_all=*/true); + std::vector xtensors = GetXlaTensors(inputs, /*want_all=*/true); + std::vector result; + torch::lazy::Value new_token; + std::tie(result, new_token) = tensor_methods::reduce_scatter_coalesced( + xtensors_out, xtensors, *token, GetReduceType(reduce_type), scale, + scatter_dim, shard_count, replica_groups, pin_layout); + std::vector aten_result; + for (auto& xt : result) { + aten_result.emplace_back(bridge::AtenFromXlaTensor(std::move(xt))); + } + return {aten_result, std::make_shared(new_token)}; +} + std::shared_ptr ReduceScatterOut( const std::string& reduce_type, at::Tensor& output, const at::Tensor& input, const std::shared_ptr& token, double scale, @@ -235,6 +258,25 @@ at::Tensor AllGather(const at::Tensor& input, int64_t dim, int64_t shard_count, return bridge::AtenFromXlaTensor(std::move(result)); } +std::pair, std::shared_ptr> +AllGatherCoalesced(const std::vector& tensors, + const std::shared_ptr& token, + int64_t dim, int64_t shard_count, + const std::vector>& replica_groups, + bool pin_layout) { + std::vector xtensors = + GetXlaTensors(tensors, /*want_all=*/true); + std::vector result; + torch::lazy::Value new_token; + std::tie(result, new_token) = tensor_methods::all_gather( + xtensors, *token, dim, shard_count, replica_groups, pin_layout); + std::vector aten_result; + for (auto& xt : result) { + aten_result.emplace_back(bridge::AtenFromXlaTensor(std::move(xt))); + } + return {aten_result, std::make_shared(new_token)}; +} + std::shared_ptr AllGatherOut( at::Tensor& output, const at::Tensor& input, const std::shared_ptr& token, int64_t dim, @@ -1151,6 +1193,27 @@ void InitXlaModuleBindings(py::module m) { } return new_token; }); + m.def("_xla_all_gather_coalesced", + [](const std::vector& tensors, + const std::shared_ptr& token, int64_t dim, + int64_t shard_count, const py::list& groups, bool pin_layout) { + std::vector> replica_groups = + CreateReduceGroups(groups); + std::vector result; + std::shared_ptr new_token; + { + NoGilSection nogil; + std::tie(result, new_token) = AllGatherCoalesced( + tensors, token, dim, shard_count, replica_groups, pin_layout); + } + auto result_list = py::list(result.size() + 1); + for (int i = 0; i < result.size(); ++i) { + result_list[i] = torch::autograd::make_variable( + result[i], /*requires_grad=*/result[i].requires_grad()); + } + result_list[result.size()] = new_token; + return result_list; + }); m.def("_xla_collective_permute", [](const at::Tensor& input, const std::shared_ptr& token, @@ -1222,6 +1285,30 @@ void InitXlaModuleBindings(py::module m) { result_tuple[1] = new_token; return result_tuple; }); + m.def("_xla_reduce_scatter_coalesced", + [](const std::string& reduce_type, std::vector& outputs, + const std::vector& inputs, + const std::shared_ptr& token, double scale, + int64_t scatter_dim, int64_t shard_count, const py::list& groups, + bool pin_layout) { + std::vector> replica_groups = + CreateReduceGroups(groups); + std::vector result; + std::shared_ptr new_token; + { + NoGilSection nogil; + std::tie(result, new_token) = ReduceScatterCoalesced( + reduce_type, outputs, inputs, token, scale, scatter_dim, + shard_count, replica_groups, pin_layout); + } + auto result_list = py::list(result.size() + 1); + for (int i = 0; i < result.size(); ++i) { + result_list[i] = torch::autograd::make_variable( + result[i], /*requires_grad=*/result[i].requires_grad()); + } + result_list[result.size()] = new_token; + return result_list; + }); m.def("_xla_reduce_scatter_out", [](const std::string& reduce_type, at::Tensor& output, const at::Tensor& input, diff --git a/torch_xla/csrc/ops/all_gather.cpp b/torch_xla/csrc/ops/all_gather.cpp index 4ea1bf714df8..6f7a0c1d9604 100644 --- a/torch_xla/csrc/ops/all_gather.cpp +++ b/torch_xla/csrc/ops/all_gather.cpp @@ -10,31 +10,37 @@ namespace torch_xla { namespace { -xla::Shape NodeOutputShape(const torch::lazy::Value& input, +xla::Shape NodeOutputShape(c10::ArrayRef inputs, const torch::lazy::Value& token, int64_t dim, int64_t shard_count, const std::vector>& groups, bool pin_layout) { auto shape_fn = [&](absl::Span operands) -> xla::XlaOp { - AllGatherResult result = BuildAllGather(operands[0], operands[1], dim, - shard_count, groups, pin_layout); - return xla::Tuple(operands[0].builder(), {result.result, result.token}); + std::vector result = + BuildAllGather(operands.subspan(0, operands.size() - 1), + operands.back(), dim, shard_count, groups, pin_layout); + return xla::Tuple(operands[0].builder(), result); }; - return InferOutputShape({GetXlaShape(input), GetXlaShape(token)}, shape_fn); + std::vector input_shapes; + for (const auto& input : inputs) { + input_shapes.emplace_back(GetXlaShape(input)); + } + input_shapes.emplace_back(GetXlaShape(token)); + return InferOutputShape(input_shapes, shape_fn); } } // namespace -AllGather::AllGather(const torch::lazy::Value& input, +AllGather::AllGather(c10::ArrayRef inputs, const torch::lazy::Value& token, int64_t dim, int64_t shard_count, std::vector> groups, bool pin_layout) - : XlaNode(xla_all_gather, {input, token}, + : XlaNode(xla_all_gather, GetOperandList(inputs, token), [&]() { - return NodeOutputShape(input, token, dim, shard_count, groups, + return NodeOutputShape(inputs, token, dim, shard_count, groups, pin_layout); }, - /*num_outputs=*/2, + /*num_outputs=*/inputs.size() + 1, torch::lazy::MHash(dim, shard_count, groups, pin_layout)), dim_(dim), shard_count_(shard_count), @@ -42,16 +48,22 @@ AllGather::AllGather(const torch::lazy::Value& input, pin_layout_(pin_layout) {} torch::lazy::NodePtr AllGather::Clone(torch::lazy::OpList operands) const { - return torch::lazy::MakeNode(operands.at(0), operands.at(1), dim_, + std::vector inputs(operands.begin(), operands.end() - 1); + return torch::lazy::MakeNode(inputs, operands.back(), dim_, shard_count_, groups_, pin_layout_); } XlaOpVector AllGather::Lower(LoweringContext* loctx) const { - xla::XlaOp input = loctx->GetOutputOp(operand(0)); - xla::XlaOp token = loctx->GetOutputOp(operand(1)); - AllGatherResult result = - BuildAllGather(input, token, dim_, shard_count_, groups_, pin_layout_); - return ReturnOps({result.result, result.token}, loctx); + auto& operand_list = operands(); + std::vector inputs; + inputs.reserve(operand_list.size()); + for (size_t i = 0; i + 1 < operand_list.size(); ++i) { + inputs.push_back(loctx->GetOutputOp(operand_list[i])); + } + xla::XlaOp token = loctx->GetOutputOp(operand_list.back()); + return ReturnOps( + BuildAllGather(inputs, token, dim_, shard_count_, groups_, pin_layout_), + loctx); } std::string AllGather::ToString() const { diff --git a/torch_xla/csrc/ops/all_gather.h b/torch_xla/csrc/ops/all_gather.h index c5ade3b1d804..fe57bc44c109 100644 --- a/torch_xla/csrc/ops/all_gather.h +++ b/torch_xla/csrc/ops/all_gather.h @@ -8,8 +8,8 @@ namespace torch_xla { class AllGather : public XlaNode { public: - AllGather(const torch::lazy::Value& input, const torch::lazy::Value& token, - int64_t dim, int64_t shard_count, + AllGather(c10::ArrayRef inputs, + const torch::lazy::Value& token, int64_t dim, int64_t shard_count, std::vector> groups, bool pin_layout); std::string ToString() const override; @@ -35,4 +35,4 @@ class AllGather : public XlaNode { } // namespace torch_xla -#endif // XLA_TORCH_XLA_CSRC_OPS_ALL_GATHER_H_ \ No newline at end of file +#endif // XLA_TORCH_XLA_CSRC_OPS_ALL_GATHER_H_ diff --git a/torch_xla/csrc/ops/all_reduce.cpp b/torch_xla/csrc/ops/all_reduce.cpp index e3fa4f57c9b7..afa25b1ca9dd 100644 --- a/torch_xla/csrc/ops/all_reduce.cpp +++ b/torch_xla/csrc/ops/all_reduce.cpp @@ -22,15 +22,6 @@ xla::Shape NodeOutputShape(c10::ArrayRef operands, return xla::ShapeUtil::MakeTupleShape(tuple_shapes); } -std::vector GetOperandList( - c10::ArrayRef operands, - const torch::lazy::Value& token) { - std::vector operand_list(operands.begin(), - operands.end()); - operand_list.push_back(token); - return operand_list; -} - } // namespace AllReduce::AllReduce(AllReduceType reduce_type, diff --git a/torch_xla/csrc/ops/reduce_scatter.cpp b/torch_xla/csrc/ops/reduce_scatter.cpp index 91c0f5d66e2e..941888939f64 100644 --- a/torch_xla/csrc/ops/reduce_scatter.cpp +++ b/torch_xla/csrc/ops/reduce_scatter.cpp @@ -12,37 +12,40 @@ namespace torch_xla { namespace { xla::Shape NodeOutputShape(AllReduceType reduce_type, - const torch::lazy::Value input, + c10::ArrayRef inputs, const torch::lazy::Value& token, double scale, int64_t scatter_dim, int64_t shard_count, const std::vector>& groups, bool pin_layout) { auto shape_fn = [&](absl::Span operands) -> xla::XlaOp { - xla::XlaOp inputOp = operands[0]; - xla::XlaOp tokenOp = operands[1]; - ReduceScatterResult result = - BuildReduceScatter(reduce_type, inputOp, tokenOp, scale, scatter_dim, - shard_count, groups, pin_layout); - return xla::Tuple(operands[0].builder(), {result.result, result.token}); + std::vector result = BuildReduceScatter( + reduce_type, operands.subspan(0, operands.size() - 1), operands.back(), + scale, scatter_dim, shard_count, groups, pin_layout); + return xla::Tuple(operands[0].builder(), result); }; - return InferOutputShape({GetXlaShape(input), GetXlaShape(token)}, shape_fn); + std::vector input_shapes; + for (const auto& input : inputs) { + input_shapes.emplace_back(GetXlaShape(input)); + } + input_shapes.emplace_back(GetXlaShape(token)); + return InferOutputShape(input_shapes, shape_fn); } } // namespace ReduceScatter::ReduceScatter(AllReduceType reduce_type, - const torch::lazy::Value& input, + c10::ArrayRef inputs, const torch::lazy::Value& token, double scale, int64_t scatter_dim, int64_t shard_count, std::vector> groups, bool pin_layout) - : XlaNode(xla_reduce_scatter, {input, token}, + : XlaNode(xla_reduce_scatter, GetOperandList(inputs, token), [&]() { - return NodeOutputShape(reduce_type, input, token, scale, + return NodeOutputShape(reduce_type, inputs, token, scale, scatter_dim, shard_count, groups, pin_layout); }, - /*num_outputs=*/2, + /*num_outputs=*/inputs.size() + 1, torch::lazy::MHash(torch::lazy::GetEnumValue(reduce_type), scale, scatter_dim, shard_count, groups, pin_layout)), reduce_type_(reduce_type), @@ -53,18 +56,24 @@ ReduceScatter::ReduceScatter(AllReduceType reduce_type, pin_layout_(pin_layout) {} torch::lazy::NodePtr ReduceScatter::Clone(torch::lazy::OpList operands) const { + std::vector inputs(operands.begin(), operands.end() - 1); return torch::lazy::MakeNode( - reduce_type_, operands.at(0), operands.at(1), scale_, scatter_dim_, - shard_count_, groups_, pin_layout_); + reduce_type_, inputs, operands.back(), scale_, scatter_dim_, shard_count_, + groups_, pin_layout_); } XlaOpVector ReduceScatter::Lower(LoweringContext* loctx) const { - xla::XlaOp input = loctx->GetOutputOp(operand(0)); - xla::XlaOp token = loctx->GetOutputOp(operand(1)); - ReduceScatterResult result = - BuildReduceScatter(reduce_type_, input, token, scale_, scatter_dim_, - shard_count_, groups_, pin_layout_); - return ReturnOps({result.result, result.token}, loctx); + auto& operand_list = operands(); + std::vector inputs; + inputs.reserve(operand_list.size()); + for (size_t i = 0; i + 1 < operand_list.size(); ++i) { + inputs.push_back(loctx->GetOutputOp(operand_list[i])); + } + xla::XlaOp token = loctx->GetOutputOp(operand_list.back()); + return ReturnOps( + BuildReduceScatter(reduce_type_, inputs, token, scale_, scatter_dim_, + shard_count_, groups_, pin_layout_), + loctx); } std::string ReduceScatter::ToString() const { diff --git a/torch_xla/csrc/ops/reduce_scatter.h b/torch_xla/csrc/ops/reduce_scatter.h index 0c888ce0fde8..8e4a9e97275e 100644 --- a/torch_xla/csrc/ops/reduce_scatter.h +++ b/torch_xla/csrc/ops/reduce_scatter.h @@ -8,7 +8,8 @@ namespace torch_xla { class ReduceScatter : public XlaNode { public: - ReduceScatter(AllReduceType reduce_type, const torch::lazy::Value& input, + ReduceScatter(AllReduceType reduce_type, + c10::ArrayRef inputs, const torch::lazy::Value& token, double scale, int64_t scatter_dim, int64_t shard_count, std::vector> groups, bool pin_layout); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index e036c0e70778..ea04450d7a5c 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -389,6 +389,33 @@ torch::lazy::Value reduce_scatter_out(XLATensorPtr& output, return torch::lazy::Value(node, 1); } +std::pair, torch::lazy::Value> + reduce_scatter_coalesced(const std::vector& outputs, + const std::vector& inputs, + const torch::lazy::Value& token, + AllReduceType reduce_type, double scale, + int64_t scatter_dim, int64_t shard_count, + std::vector> groups, + bool pin_layout) { + XLA_CHECK(outputs.empty() || outputs.size() == inputs.size()); + std::vector input_values; + input_values.reserve(inputs.size()); + for (auto& input : inputs) { + input_values.push_back(input->GetIrValue()); + } + torch::lazy::NodePtr node = torch::lazy::MakeNode( + reduce_type, input_values, token, scale, scatter_dim, shard_count, + std::move(groups), pin_layout); + std::vector result; + for (size_t i = 0; i < inputs.size(); ++i) { + result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i))); + if (!outputs.empty()) { + outputs[i]->SetIrValue(torch::lazy::Value(node, i)); + } + } + return {result, torch::lazy::Value(node, inputs.size())}; +} + std::pair all_to_all( const XLATensorPtr& input, const torch::lazy::Value& token, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, @@ -400,16 +427,22 @@ std::pair all_to_all( torch::lazy::Value(node, 1)}; } -XLATensorPtr all_gather(const XLATensorPtr& input, int64_t dim, - int64_t shard_count, - std::vector> groups, - bool pin_layout) { +std::pair, torch::lazy::Value> all_gather( + const std::vector& inputs, const torch::lazy::Value& token, + int64_t dim, int64_t shard_count, std::vector> groups, + bool pin_layout) { + std::vector input_values; + input_values.reserve(inputs.size()); + for (auto& input : inputs) { + input_values.push_back(input->GetIrValue()); + } torch::lazy::NodePtr node = torch::lazy::MakeNode( - input->GetIrValue(), GetAllReduceToken(input->GetDevice()), dim, - shard_count, std::move(groups), pin_layout); - SetAllReduceToken(input->GetDevice(), - std::make_shared(node, 1)); - return input->CreateFrom(torch::lazy::Value(node, 0)); + input_values, token, dim, shard_count, std::move(groups), pin_layout); + std::vector result; + for (size_t i = 0; i < inputs.size(); ++i) { + result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i))); + } + return {result, torch::lazy::Value(node, inputs.size())}; } torch::lazy::Value all_gather_out(XLATensorPtr& output, diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 88d6e8b44965..dc257c858d99 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -33,11 +33,25 @@ torch::lazy::Value reduce_scatter_out(XLATensorPtr& output, std::vector> groups, bool pin_layout); +std::pair, torch::lazy::Value> + reduce_scatter_coalesced(const std::vector& outputs, + const std::vector& inputs, + const torch::lazy::Value& token, + AllReduceType reduce_type, double scale, + int64_t scatter_dim, int64_t shard_count, + std::vector> groups, + bool pin_layout); + std::pair all_to_all( const XLATensorPtr& input, const torch::lazy::Value& token, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, std::vector> groups, bool pin_layout); +std::pair, torch::lazy::Value> all_gather( + const std::vector& inputs, const torch::lazy::Value& token, + int64_t dim, int64_t shard_count, + std::vector> groups, bool pin_layout); + XLATensorPtr all_gather(const XLATensorPtr& input, int64_t dim, int64_t shard_count, std::vector> groups, diff --git a/torch_xla/distributed/fsdp/utils.py b/torch_xla/distributed/fsdp/utils.py index ee79bc4a9668..548cdf5b3fa9 100644 --- a/torch_xla/distributed/fsdp/utils.py +++ b/torch_xla/distributed/fsdp/utils.py @@ -59,12 +59,7 @@ def dummy_all_reduce(reduce_type, inputs, scale=1.0, groups=None): return [t.mul_(scale) for t in inputs] -def dummy_reduce_scatter(reduce_type, - input, - scale, - scatter_dim, - shard_count, - groups=None): +class DummyReduceScatter: """A dummy op for debugging with the same output shape as reduce_scatter""" assert shard_count == xm.xrt_world_size() full_size = input.size(scatter_dim) @@ -75,6 +70,64 @@ def dummy_reduce_scatter(reduce_type, slices[scatter_dim] = slice(begin, end) return input[tuple(slices)] * scale + def __init__(self, shard_count): + assert shard_count == xm.xrt_world_size() + self.scale = 1.0 + + def __call__(self, input, callback): + full_size = input.size(0) + shard_size = full_size // xm.xrt_world_size() + begin = shard_size * xm.get_ordinal() + end = begin + shard_size + slices = [None] * input.dim() + slices[0] = slice(begin, end) + callback(input[tuple(slices)]) + + def flush(self): + pass + + +class BucketizedReduceScatter: + """A reduce_scatter op that group input tensors before reduce-scattering them.""" + + def __init__(self, bucket_size_mb, shard_count, groups, pin_layout) -> None: + self.bucket_size_bytes = bucket_size_mb * 1024 * 1024 + self.shard_count = shard_count + self.groups = groups + self.pin_layout = pin_layout + self.scale = 1.0 + + self.callbacks = [] + self.bucket = [] + self.bucket_watermark = 0 + + def __call__(self, input, callback): + input_byte_size = input.element_size() * input.numel() + self.bucket.append(input) + self.callbacks.append(callback) + self.bucket_watermark += input_byte_size + if self.bucket_watermark > self.bucket_size_bytes: + self.flush() + + def flush(self): + if not self.bucket: + return + + results = xm.reduce_scatter( + xm.REDUCE_SUM, + self.bucket, + scale=self.scale, + scatter_dim=0, + shard_count=self.shard_count, + groups=self.groups, + pin_layout=self.pin_layout) + for cb, result in zip(self.callbacks, results): + cb(result) + + self.bucket.clear() + self.callbacks.clear() + self.bucket_watermark = 0 + class XLAPatchedLinear(torch.autograd.Function): """ diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py index f1b62d1700b6..506804316cd0 100644 --- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py +++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py @@ -35,7 +35,14 @@ import torch_xla.core.xla_model as xm from .xla_flatten_params_wrapper import XlaFlattenParamsWrapper -from .utils import dummy_all_gather, dummy_all_reduce, dummy_reduce_scatter, apply_xla_patch_to_nn_linear +from .utils import ( + BucketizedReduceScatter, + DummyReduceScatter, + dummy_all_gather, + dummy_all_reduce, + apply_xla_patch_to_nn_linear, +) + from .wrap import recursive_wrap from ._init_utils import _materialize_module @@ -295,6 +302,8 @@ def __init__( sharding_world_size: Optional[int] = None, shard_param_on_dim_0: bool = False, pin_layout_in_collective_ops: bool = True, + coalesce_all_gather_ops: bool = False, + reduce_scatter_bucket_size_mb: Optional[int] = 0, auto_wrap_policy: Optional[Callable] = None, auto_wrapper_callable: Optional[Callable] = None, param_init_fn: Optional[Callable[[nn.Module], None]] = None, @@ -397,6 +406,7 @@ def __init__( # When `_shard_param_on_dim_0` is True, we shard and all-gather model parameter tensors # only along their dim 0 without flattening the parameter self._shard_param_on_dim_0 = shard_param_on_dim_0 and not flatten_parameters + self.coalesce_all_gather_ops = coalesce_all_gather_ops # Set layout pinning to False in all_gather, all_reduce, and reduce_scatter so that they can work together # TODO (ronghanghu): change the default layout pinning to True after it's supported simultaneously # on all collective ops (see https://github.com/pytorch/xla/pull/3511 for details) @@ -411,10 +421,13 @@ def __init__( self.all_reduce_op = functools.partial( xm.all_reduce, pin_layout=pin_layout_in_collective_ops) if _debug_dummy_reduce_scatter_op: - self.reduce_scatter_op = dummy_reduce_scatter + self.reduce_scatter_op = DummyReduceScatter(shard_count=self.world_size) else: - self.reduce_scatter_op = functools.partial( - xm.reduce_scatter, pin_layout=pin_layout_in_collective_ops) + self.reduce_scatter_op = BucketizedReduceScatter( + reduce_scatter_bucket_size_mb, + shard_count=self.world_size, + groups=self.sharding_groups, + pin_layout=pin_layout_in_collective_ops) if _debug_dummy_optimization_barrier_op: self.optimization_barrier_op = lambda *args: None else: @@ -552,6 +565,10 @@ def set_gradient_divide_factors(self, pre: float, post: float, module.set_gradient_divide_factors(pre, post, False) self.gradient_predivide_factor = pre self.gradient_postdivide_factor = post + if (pre, post) == (1, 1): + self.reduce_scatter_op.scale = 1.0 / self.world_size + else: + self.reduce_scatter_op.scale = 1.0 @property def module(self) -> XlaFlattenParamsWrapper: @@ -1142,6 +1159,7 @@ def _register_post_backward_hooks(self) -> None: """ if not torch.is_grad_enabled(): return # don't register grad hooks if grad isn't enabled + self._post_backward_hooks_to_call = 0 for p in self.full_params: if p.requires_grad: if hasattr(p, "_shard_bwd_hook"): @@ -1155,6 +1173,7 @@ def _register_post_backward_hooks(self) -> None: handle = grad_acc.register_hook( functools.partial(self._post_backward_hook, p)) p._shard_bwd_hook = (grad_acc, handle) + self._post_backward_hooks_to_call += 1 @torch.no_grad() def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: @@ -1181,7 +1200,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # then subsequent hook callbacks will see POST state. self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST]) self.training_state = TrainingState.BACKWARD_POST + self._post_backward_hooks_to_call -= 1 if param.grad is None: + if self._post_backward_hooks_to_call == 0: + self.reduce_scatter_op.flush() return assert param.grad is not None, param.shape @@ -1202,6 +1224,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: apply_opt_barrier=self.optimization_barrier_in_backward) if not self._require_backward_grad_sync: + if self._post_backward_hooks_to_call == 0: + self.reduce_scatter_op.flush() return if self.gradient_predivide_factor > 1: @@ -1217,38 +1241,37 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: self.optimization_barrier_op([grad_flat]) if grad_flat.dtype != torch.float32 and self.fp32_reduce_scatter: grad_flat = grad_flat.to(torch.float32) - reduced_grad = self.reduce_scatter_op( - xm.REDUCE_SUM, - grad_flat.detach(), - scale=1.0, - scatter_dim=0, - shard_count=self.world_size, - groups=self.sharding_groups) - if reduced_grad.dtype != torch.float32: - reduced_grad = reduced_grad.to(torch.float32) - if self.optimization_barrier_in_backward: - self.optimization_barrier_op([reduced_grad]) - if self.gradient_postdivide_factor > 1: - # Average grad by world_size for consistency with PyTorch DDP. - reduced_grad.div_(self.gradient_postdivide_factor) - - grad._has_full_param = True - grad_flat._has_full_param = True - self._free_full_params( - [grad, grad_flat], - dependency_tensors=[reduced_grad], - apply_opt_barrier=self.optimization_barrier_in_backward) - self._try_adding_to_backward_opt_barrier_lists(reduced_grad) - - # Accumulate into the gradient shard. - assert hasattr(param, "_sharded_param") - p_shard = param._sharded_param - if p_shard.grad is None: - p_shard.grad = reduced_grad - else: - assert p_shard.grad.shape == reduced_grad.shape - assert p_shard.grad.device == reduced_grad.device - p_shard.grad += reduced_grad + + def reduce_scatter_done(reduced_grad): + if reduced_grad.dtype != torch.float32: + reduced_grad = reduced_grad.to(torch.float32) + if self.optimization_barrier_in_backward: + self.optimization_barrier_op([reduced_grad]) + if self.gradient_postdivide_factor > 1: + # Average grad by world_size for consistency with PyTorch DDP. + reduced_grad.data.div_(self.gradient_postdivide_factor) + + grad._has_full_param = True + grad_flat._has_full_param = True + self._free_full_params( + [grad, grad_flat], + dependency_tensors=[reduced_grad], + apply_opt_barrier=self.optimization_barrier_in_backward) + self._try_adding_to_backward_opt_barrier_lists(reduced_grad) + + # Accumulate into the gradient shard. + assert hasattr(param, "_sharded_param") + p_shard = param._sharded_param + if p_shard.grad is None: + p_shard.grad = reduced_grad.data + else: + assert p_shard.grad.shape == reduced_grad.shape + assert p_shard.grad.device == reduced_grad.device + p_shard.grad.data += reduced_grad.data + + self.reduce_scatter_op(grad_flat.detach(), reduce_scatter_done) + if self._post_backward_hooks_to_call == 0: + self.reduce_scatter_op.flush() def _queue_wait_for_post_backward(self) -> None: """ @@ -1402,6 +1425,8 @@ def _rebuild_full_params(self, [p for p in self.full_params if p._has_full_param], self.sharded_params, dependency_tensors) + if self.coalesce_all_gather_ops: + p_to_rebuild, shards_to_all_gather = [], [] for p, p_shard in zip(self.full_params, self.sharded_params): if not p._has_full_param: p_shard_data = p_shard @@ -1410,8 +1435,12 @@ def _rebuild_full_params(self, if p_shard_data.dtype != self.compute_dtype: p_shard_data = p_shard_data.to(self.compute_dtype) if self._shard_param_on_dim_0 or self._shard_size_multiple == 1: - p_padded = self.all_gather_op( - p_shard_data, groups=self.sharding_groups) + if self.coalesce_all_gather_ops: + p_to_rebuild.append((p, p_shard)) + shards_to_all_gather.append(p_shard_data) + else: + p_padded = self.all_gather_op( + p_shard_data, groups=self.sharding_groups) else: # gather full parameter from shards # reshape sharded parameters to 2d tensors for efficient gathering on @@ -1419,26 +1448,36 @@ def _rebuild_full_params(self, p_shard_2d = p_shard_data.view(-1, self._shard_size_multiple) p_padded = self.all_gather_op( p_shard_2d, groups=self.sharding_groups).flatten() - if apply_opt_barrier: - self.optimization_barrier_op([p_padded]) - with torch.autograd._unsafe_preserve_version_counter(p): - if self._shard_param_on_dim_0: - if XLA_DISABLE_FUNCTIONALIZATION: - p.data = p_padded[:p_shard._orig_size[ - 0]] # Old behavior before Functionalization. + if not self.coalesce_all_gather_ops: + if apply_opt_barrier: + self.optimization_barrier_op([p_padded]) + with torch.autograd._unsafe_preserve_version_counter(p): + if self._shard_param_on_dim_0: + if XLA_DISABLE_FUNCTIONALIZATION: + p.data = p_padded[:p_shard._orig_size[ + 0]] # Old behavior before Functionalization. + else: + torch_xla._XLAC._replace_xla_tensor( + p, p_padded[:p_shard._orig_size[0]]) else: - torch_xla._XLAC._replace_xla_tensor( - p, p_padded[:p_shard._orig_size[0]]) - else: - if XLA_DISABLE_FUNCTIONALIZATION: - p.data = p_padded[:p_shard._orig_size.numel()].view( - p_shard._orig_size) # Old behavior before Functionalization. - else: - torch_xla._XLAC._replace_xla_tensor( - p, p_padded[:p_shard._orig_size.numel()].view( - p_shard._orig_size)) + if XLA_DISABLE_FUNCTIONALIZATION: + p.data = p_padded[:p_shard._orig_size.numel()].view( + p_shard._orig_size + ) # Old behavior before Functionalization. + else: + torch_xla._XLAC._replace_xla_tensor( + p, p_padded[:p_shard._orig_size.numel()].view( + p_shard._orig_size)) p._has_full_param = True + if self.coalesce_all_gather_ops: + p_padded_list = self.all_gather_op( + shards_to_all_gather, groups=self.sharding_groups) + if apply_opt_barrier: + self.optimization_barrier_op(p_padded_list) + for (p, p_shard), p_padded in zip(p_to_rebuild, p_padded_list): + p.data = p_padded[:p_shard._orig_size[0]] + self.has_full_params = True @torch.no_grad() diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py index aa2769cb94dc..9e980997f8d2 100644 --- a/torch_xla/distributed/xla_backend.py +++ b/torch_xla/distributed/xla_backend.py @@ -80,6 +80,14 @@ def allgather(self, output_tensors_list, input_tensors, opts=None): return _ret_work([t for sublist in output_tensors_list for t in sublist]) + def allgather_coalesced(self, output_tensors_list, input_tensors, opts=None): + results = xm.all_gather(input_tensors, groups=self._mesh, pin_layout=False) + for i, result in enumerate(results): + for j, slice in enumerate(torch.split(result, input_tensors[i].shape[0])): + output_tensors_list[i][j].copy_(slice) + + return _ret_work([t for sublist in output_tensors_list for t in sublist]) + # Call site: # https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L1129 def broadcast(self, tensors, opts): @@ -117,6 +125,33 @@ def reduce_scatter(self, output_tensors, input_tensors_list, opts): return _ret_work(output_tensors) + def reduce_scatter_coalesced(self, output_tensors, input_tensors_list, opts): + input_tensor_list = [] + for input_tensors in input_tensors_list: + # Ensure all inputs have the same shape. + first_shape = input_tensors[0].shape + for i, t in enumerate(input_tensors[1:]): + if first_shape != t.shape: + raise ValueError(f"Input {i+1}'s shape is different from input 0: " + f"{t.shape} vs {first_shape}") + input_tensor = torch.cat(input_tensors) + input_tensor_list.append(input_tensor) + + reduce_type = self._get_reduce_type(opts.reduceOp) + groups = self._mesh + shard_count = len(groups[0]) if groups else self.size() + xm.reduce_scatter( + reduce_type, + input_tensor_list, + scatter_dim=0, + shard_count=shard_count, + scale=1, + groups=groups, + output=output_tensors, + pin_layout=False) + + return _ret_work(output_tensors) + # Call site: # https://github.com/pytorch/pytorch/blob/70f57bcb1e45d21532bdb1c44d3aab018d1cbe88/torch/distributed/distributed_c10d.py#L2683 def barrier(self, opts): @@ -128,9 +163,6 @@ def barrier(self, opts): def reduce(self, *args): raise NotImplementedError - def allgather_coalesced(self, *args): - raise NotImplementedError - def allreduce_coalesced(self, *args): raise NotImplementedError