diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py
index 478ab4dd498..991e4669d1d 100644
--- a/test/test_torch_distributed_xla_backend.py
+++ b/test/test_torch_distributed_xla_backend.py
@@ -2,7 +2,7 @@
 import functools
 import os
 import re
-from unittest import mock
+from unittest import mock, skipIf
 
 from absl.testing import absltest, parameterized
 import torch
@@ -12,6 +12,15 @@
 import torch_xla.distributed.xla_backend
 from torch_xla import runtime as xr
 
+from datetime import timedelta
+
+
+def get_process_group_xla(rank, size):
+  pg_xla_creator = dist.Backend._plugins['XLA'].creator_fn
+  pg_xla = pg_xla_creator(
+      prefix_store=None, rank=rank, size=size, timeout=timedelta(minutes=1))
+  return pg_xla
+
 
 def hlo_matches(hlo, expected_pattern, match_times=1):
   matches = re.findall(expected_pattern, hlo)
@@ -87,6 +96,25 @@ def test_allgather(self):
     hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
     hlo_matches(hlo, all_gather_pattern)
 
+  @patch_world(rank=3, size=8)
+  def test_allgather_coalesced(self):
+    device = xm.xla_device()
+    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
+    tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank()
+    pg_xla = get_process_group_xla(rank=3, size=8)
+    output_tensors = [torch.zeros_like(tensor)] * 8
+    output_tensors2 = [torch.zeros_like(tensor2)] * 8
+    # because we set os.environ[xenv.WORLD_SIZE] = '1', here the outputs'
+    # shapes will be the same as the inputs' shapes.
+    # Ex: %all-gather.26 = (s64[2]{0}, s64[5]{0}) all-gather(s64[2]{0} %get-tuple-element.24, s64[5]{0} %get-tuple-element.25), replica_groups={}, dimensions={0}
+    all_gather_pattern = (
+        r'%all-gather\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}\) '
+        r'all-gather\(s64\[2]\{0} %.+\.\d+, s64\[5]\{0} %.+\.\d+\)')
+    pg_xla.allgather_coalesced([output_tensors, output_tensors2],
+                               [tensor, tensor2])
+    hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
+    hlo_matches(hlo, all_gather_pattern)
+
   def test_broadcast(self):
     device = xm.xla_device()
     tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
@@ -291,7 +319,6 @@ def test_barrier(self):
 
   @parameterized.parameters(
       'reduce',
-      'allgather_coalesced',
       'allreduce_coalesced',
       'alltoall',
       'alltoall_base',
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 7f9682c1000..c080889777c 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -560,17 +560,31 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
   shard_count = xrt_world_size()
 
   token, devctx = _get_all_reduce_token()
-  if output != None:
-    # Call the out of place version of the all_gather
-    new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim,
-                                                    shard_count, groups or [],
-                                                    pin_layout)
-    torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
-    return output
-
-  result = torch_xla._XLAC._xla_all_gather(value, dim, shard_count, groups or
-                                           [], pin_layout)
-  return result
+
+  if isinstance(value, torch.Tensor):
+    if output != None:
+      # Call the out of place version of the all_gather
+      new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim,
+                                                      shard_count, groups or [],
+                                                      pin_layout)
+      torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
+      return output
+
+    result = torch_xla._XLAC._xla_all_gather(value, dim, shard_count, groups or
+                                             [], pin_layout)
+    return result
+
+  # Now the input should be a list of Tensors.
+  elif isinstance(value, list) and all(
+      isinstance(v, torch.Tensor) for v in value):
+    result = torch_xla._XLAC._xla_all_gather_coalesced(value, token, dim,
+                                                       shard_count, groups or
+                                                       [], pin_layout)
+    torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
+    return result[:-1]
+  else:
+    raise TypeError("`value` needs to be a Tensor or a list of Tensors, but "
+                    f"given {type(value)}.")
 
 
 def all_to_all(value,
diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp
index c7bcdab4596..a06d916d1e1 100644
--- a/torch_xla/csrc/cross_replica_reduces.cpp
+++ b/torch_xla/csrc/cross_replica_reduces.cpp
@@ -236,6 +236,45 @@ AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim,
   return {all_gather_result, token_handler.GetNewToken(all_gather_result)};
 }
 
+AllGatherResultCoalesced BuildAllGatherCoalesced(
+    absl::Span<const xla::XlaOp> inputs, xla::XlaOp token, int64_t dim,
+    int64_t shard_count, const std::vector<std::vector<int64_t>>& groups,
+    bool pin_layout) {
+  std::vector<xla::ReplicaGroup> cc_groups = CreateReduceGroups(groups);
+  TokenHandler token_handler(token);
+  // TODO: We use pseudo-tokens ATM, which are real values. This needs to be
+  // switched to use the real XLA Token once support has been added to XLA
+  // AllGather().
+  ReduceContext cc_ctx = GetReduceContext(inputs);
+  std::vector<xla::XlaOp> result(inputs.size());
+
+  for (auto& type_ctx : cc_ctx.contexts) {
+    xla::XlaOp all_gather_result;
+    type_ctx.second.ops[0] = token_handler.GetInput(
+        type_ctx.second.ops[0], &type_ctx.second.operand_shapes[0]);
+    if (pin_layout) {
+      all_gather_result = xla::AllGather(
+          xla::Tuple(inputs[0].builder(), type_ctx.second.ops), dim,
+          shard_count, cc_groups, /*channel_id=*/absl::nullopt,
+          /*layout=*/
+          MakeReduceShape(type_ctx.second.operand_shapes).layout());
+    } else {
+      all_gather_result =
+          xla::AllGather(xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
+                         dim, shard_count, cc_groups);
+    }
+    if (ShapeHelper::ShapeOfXlaOp(all_gather_result).rank() == 0) {
+      for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
+        size_t op_idx = type_ctx.second.indices[i];
+        result[op_idx] = xla::GetTupleElement(all_gather_result, i);
+      }
+    } else {
+      result[0] = all_gather_result;
+    }
+  }
+  return {result, token_handler.GetNewToken(result[0])};
+}
+
 CollectivePermuteResult BuildCollectivePermute(
     xla::XlaOp input, xla::XlaOp token,
     const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {
@@ -309,6 +348,15 @@ ReduceScatterResult BuildReduceScatter(
   return {reduce_result, token_handler.GetNewToken(reduce_result)};
 }
 
+std::vector<torch::lazy::Value> GetOperandListWithToken(
+    c10::ArrayRef<torch::lazy::Value> operands,
+    const torch::lazy::Value& token) {
+  std::vector<torch::lazy::Value> operand_list(operands.begin(),
+                                               operands.end());
+  operand_list.push_back(token);
+  return operand_list;
+}
+
 const torch::lazy::Value& GetAllReduceToken(
     const torch::lazy::BackendDevice& device) {
   auto it = g_all_reduce_tokens.find(device.ordinal());
diff --git a/torch_xla/csrc/cross_replica_reduces.h b/torch_xla/csrc/cross_replica_reduces.h
index c35560c6b39..715363ea7a5 100644
--- a/torch_xla/csrc/cross_replica_reduces.h
+++ b/torch_xla/csrc/cross_replica_reduces.h
@@ -6,6 +6,7 @@
 #include "absl/types/span.h"
 #include "torch/csrc/lazy/core/ir.h"
 #include "torch_xla/csrc/device.h"
+#include "torch_xla/csrc/ir.h"
 #include "xla/client/xla_builder.h"
 
 namespace torch_xla {
@@ -29,6 +30,11 @@ struct AllGatherResult {
   xla::XlaOp token;
 };
 
+struct AllGatherResultCoalesced {
+  std::vector<xla::XlaOp> result;
+  xla::XlaOp token;
+};
+
 struct CollectivePermuteResult {
   xla::XlaOp result;
   xla::XlaOp token;
@@ -65,6 +71,11 @@ AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim,
                                const std::vector<std::vector<int64_t>>& groups,
                                bool pin_layout);
 
+AllGatherResultCoalesced BuildAllGatherCoalesced(
+    absl::Span<const xla::XlaOp> inputs, xla::XlaOp token, int64_t dim,
+    int64_t shard_count, const std::vector<std::vector<int64_t>>& groups,
+    bool pin_layout);
+
 CollectivePermuteResult BuildCollectivePermute(
     xla::XlaOp input, xla::XlaOp token,
     const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs);
@@ -80,6 +91,10 @@ ReduceScatterResult BuildReduceScatter(
     int64_t scatter_dim, int64_t shard_count,
     const std::vector<std::vector<int64_t>>& groups, bool pin_layout);
 
+std::vector<torch::lazy::Value> GetOperandListWithToken(
+    c10::ArrayRef<torch::lazy::Value> operands,
+    const torch::lazy::Value& token);
+
 const torch::lazy::Value& GetAllReduceToken(
     const torch::lazy::BackendDevice& device);
 void SetAllReduceToken(const torch::lazy::BackendDevice& device,
@@ -89,4 +104,4 @@ AllReduceType GetReduceType(c10::string_view reduce_type);
 
 }  // namespace torch_xla
 
-#endif  // XLA_TORCH_XLA_CSRC_CROSS_REPLICA_REDUCES_H_
\ No newline at end of file
+#endif  // XLA_TORCH_XLA_CSRC_CROSS_REPLICA_REDUCES_H_
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 2484642aad0..8594f28230f 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -233,6 +233,25 @@ at::Tensor AllGather(const at::Tensor& input, int64_t dim, int64_t shard_count,
   return bridge::AtenFromXlaTensor(std::move(result));
 }
 
+std::pair<std::vector<at::Tensor>, std::shared_ptr<torch::lazy::Value>>
+AllGatherCoalesced(const std::vector<at::Tensor>& tensors,
+                   const std::shared_ptr<torch::lazy::Value>& token,
+                   int64_t dim, int64_t shard_count,
+                   const std::vector<std::vector<int64_t>>& replica_groups,
+                   bool pin_layout) {
+  std::vector<XLATensorPtr> xtensors =
+      GetXlaTensors(tensors, /*want_all=*/true);
+  std::vector<XLATensorPtr> result;
+  torch::lazy::Value new_token;
+  std::tie(result, new_token) = tensor_methods::all_gather(
+      xtensors, *token, dim, shard_count, replica_groups, pin_layout);
+  std::vector<at::Tensor> aten_result;
+  for (auto& xt : result) {
+    aten_result.emplace_back(bridge::AtenFromXlaTensor(std::move(xt)));
+  }
+  return {aten_result, std::make_shared<torch::lazy::Value>(new_token)};
+}
+
 std::shared_ptr<torch::lazy::Value> AllGatherOut(
     at::Tensor& output, const at::Tensor& input,
     const std::shared_ptr<torch::lazy::Value>& token, int64_t dim,
@@ -1145,6 +1164,27 @@ void InitXlaModuleBindings(py::module m) {
           }
           return new_token;
         });
+  m.def("_xla_all_gather_coalesced",
+        [](const std::vector<at::Tensor>& tensors,
+           const std::shared_ptr<torch::lazy::Value>& token, int64_t dim,
+           int64_t shard_count, const py::list& groups, bool pin_layout) {
+          std::vector<std::vector<int64_t>> replica_groups =
+              CreateReduceGroups(groups);
+          std::vector<at::Tensor> results;
+          std::shared_ptr<torch::lazy::Value> new_token;
+          {
+            NoGilSection nogil;
+            std::tie(results, new_token) = AllGatherCoalesced(
+                tensors, token, dim, shard_count, replica_groups, pin_layout);
+          }
+          auto result_list = py::list(results.size() + 1);
+          for (int i = 0; i < results.size(); ++i) {
+            result_list[i] = torch::autograd::make_variable(
+                results[i], /*requires_grad=*/results[i].requires_grad());
+          }
+          result_list[results.size()] = new_token;
+          return result_list;
+        });
   m.def("_xla_collective_permute",
         [](const at::Tensor& input,
           const std::shared_ptr<torch::lazy::Value>& token,
diff --git a/torch_xla/csrc/ops/all_gather.cpp b/torch_xla/csrc/ops/all_gather.cpp
index 4ea1bf714df..5bbd2bf2464 100644
--- a/torch_xla/csrc/ops/all_gather.cpp
+++ b/torch_xla/csrc/ops/all_gather.cpp
@@ -23,6 +23,25 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& input,
   return InferOutputShape({GetXlaShape(input), GetXlaShape(token)}, shape_fn);
 }
 
+xla::Shape NodeOutputShapeCoalesced(
+    c10::ArrayRef<torch::lazy::Value> inputs, const torch::lazy::Value& token,
+    int64_t dim, int64_t shard_count,
+    const std::vector<std::vector<int64_t>>& groups, bool pin_layout) {
+  auto shape_fn = [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    AllGatherResultCoalesced result = BuildAllGatherCoalesced(
+        operands.subspan(0, operands.size() - 1), operands.back(), dim,
+        shard_count, groups, pin_layout);
+    result.result.emplace_back(result.token);
+    return xla::Tuple(operands[0].builder(), result.result);
+  };
+  std::vector<xla::Shape> input_shapes;
+  for (const auto& input : inputs) {
+    input_shapes.emplace_back(GetXlaShape(input));
+  }
+  input_shapes.emplace_back(GetXlaShape(token));
+  return InferOutputShape(input_shapes, shape_fn);
+}
+
 }  // namespace
 
 AllGather::AllGather(const torch::lazy::Value& input,
@@ -41,11 +60,35 @@ AllGather::AllGather(const torch::lazy::Value& input,
       groups_(std::move(groups)),
       pin_layout_(pin_layout) {}
 
+AllGatherCoalesced::AllGatherCoalesced(c10::ArrayRef<torch::lazy::Value> inputs,
+                                       const torch::lazy::Value& token,
+                                       int64_t dim, int64_t shard_count,
+                                       std::vector<std::vector<int64_t>> groups,
+                                       bool pin_layout)
+    : XlaNode(xla_all_gather, GetOperandListWithToken(inputs, token),
+              [&]() {
+                return NodeOutputShapeCoalesced(inputs, token, dim, shard_count,
+                                                groups, pin_layout);
+              },
+              /*num_outputs=*/inputs.size() + 1,
+              torch::lazy::MHash(dim, shard_count, groups, pin_layout)),
+      dim_(dim),
+      shard_count_(shard_count),
+      groups_(std::move(groups)),
+      pin_layout_(pin_layout) {}
+
 torch::lazy::NodePtr AllGather::Clone(torch::lazy::OpList operands) const {
   return torch::lazy::MakeNode<AllGather>(operands.at(0), operands.at(1), dim_,
                                           shard_count_, groups_, pin_layout_);
 }
 
+torch::lazy::NodePtr AllGatherCoalesced::Clone(
+    torch::lazy::OpList operands) const {
+  std::vector<torch::lazy::Value> inputs(operands.begin(), operands.end() - 1);
+  return torch::lazy::MakeNode<AllGatherCoalesced>(
+      inputs, operands.back(), dim_, shard_count_, groups_, pin_layout_);
+}
+
 XlaOpVector AllGather::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
   xla::XlaOp token = loctx->GetOutputOp(operand(1));
@@ -54,6 +97,20 @@ XlaOpVector AllGather::Lower(LoweringContext* loctx) const {
   return ReturnOps({result.result, result.token}, loctx);
 }
 
+XlaOpVector AllGatherCoalesced::Lower(LoweringContext* loctx) const {
+  auto& operand_list = operands();
+  std::vector<xla::XlaOp> inputs;
+  inputs.reserve(operand_list.size());
+  for (size_t i = 0; i + 1 < operand_list.size(); ++i) {
+    inputs.push_back(loctx->GetOutputOp(operand_list[i]));
+  }
+  xla::XlaOp token = loctx->GetOutputOp(operand_list.back());
+  AllGatherResultCoalesced result = BuildAllGatherCoalesced(
+      inputs, token, dim_, shard_count_, groups_, pin_layout_);
+  result.result.push_back(result.token);
+  return ReturnOps(result.result, loctx);
+}
+
 std::string AllGather::ToString() const {
   std::stringstream ss;
   ss << XlaNode::ToString() << ", dim=" << dim_
@@ -67,4 +124,17 @@ std::string AllGather::ToString() const {
   return ss.str();
 }
 
+std::string AllGatherCoalesced::ToString() const {
+  std::stringstream ss;
+  ss << XlaNode::ToString() << ", dim=" << dim_
+     << ", shard_count=" << shard_count_ << ", pin_layout=" << pin_layout_
+     << ", groups=(";
+  for (size_t i = 0; i < groups_.size(); ++i) {
+    ss << (i == 0 ? "(" : ",(");
+    ss << absl::StrJoin(groups_[i], ", ") << ")";
+  }
+  ss << ")";
+  return ss.str();
+}
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/all_gather.h b/torch_xla/csrc/ops/all_gather.h
index c5ade3b1d80..78707656d1a 100644
--- a/torch_xla/csrc/ops/all_gather.h
+++ b/torch_xla/csrc/ops/all_gather.h
@@ -33,6 +33,34 @@ class AllGather : public XlaNode {
   bool pin_layout_;
 };
 
+class AllGatherCoalesced : public XlaNode {
+ public:
+  AllGatherCoalesced(c10::ArrayRef<torch::lazy::Value> inputs,
+                     const torch::lazy::Value& token, int64_t dim,
+                     int64_t shard_count,
+                     std::vector<std::vector<int64_t>> groups, bool pin_layout);
+
+  std::string ToString() const override;
+
+  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+  int64_t dim() const { return dim_; }
+
+  int64_t shard_count() const { return shard_count_; }
+
+  const std::vector<std::vector<int64_t>>& groups() const { return groups_; }
+
+  bool pin_layout() const { return pin_layout_; }
+
+ private:
+  int64_t dim_;
+  int64_t shard_count_;
+  std::vector<std::vector<int64_t>> groups_;
+  bool pin_layout_;
+};
+
 }  // namespace torch_xla
 
-#endif  // XLA_TORCH_XLA_CSRC_OPS_ALL_GATHER_H_
\ No newline at end of file
+#endif  // XLA_TORCH_XLA_CSRC_OPS_ALL_GATHER_H_
diff --git a/torch_xla/csrc/ops/all_reduce.cpp b/torch_xla/csrc/ops/all_reduce.cpp
index e3fa4f57c9b..c68d26b1e1b 100644
--- a/torch_xla/csrc/ops/all_reduce.cpp
+++ b/torch_xla/csrc/ops/all_reduce.cpp
@@ -22,22 +22,13 @@ xla::Shape NodeOutputShape(c10::ArrayRef<torch::lazy::Value> operands,
   return xla::ShapeUtil::MakeTupleShape(tuple_shapes);
 }
 
-std::vector<torch::lazy::Value> GetOperandList(
-    c10::ArrayRef<torch::lazy::Value> operands,
-    const torch::lazy::Value& token) {
-  std::vector<torch::lazy::Value> operand_list(operands.begin(),
-                                               operands.end());
-  operand_list.push_back(token);
-  return operand_list;
-}
-
 }  // namespace
 
 AllReduce::AllReduce(AllReduceType reduce_type,
                      c10::ArrayRef<torch::lazy::Value> operands,
                      const torch::lazy::Value& token, double scale,
                      std::vector<std::vector<int64_t>> groups, bool pin_layout)
-    : XlaNode(xla_cross_replica_sum, GetOperandList(operands, token),
+    : XlaNode(xla_cross_replica_sum, GetOperandListWithToken(operands, token),
              [&]() { return NodeOutputShape(operands, token); },
              /*num_outputs=*/operands.size() + 1,
              torch::lazy::MHash(torch::lazy::GetEnumValue(reduce_type), scale,
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 81073d22fb2..f4bf90be4a2 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -402,6 +402,24 @@ std::pair<XLATensorPtr, torch::lazy::Value> all_to_all(
                              torch::lazy::Value(node, 1)};
 }
 
+std::pair<std::vector<XLATensorPtr>, torch::lazy::Value> all_gather(
+    const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
+    int64_t dim, int64_t shard_count, std::vector<std::vector<int64_t>> groups,
+    bool pin_layout) {
+  std::vector<torch::lazy::Value> input_values;
+  input_values.reserve(inputs.size());
+  for (auto& input : inputs) {
+    input_values.push_back(input->GetIrValue());
+  }
+  torch::lazy::NodePtr node = torch::lazy::MakeNode<AllGatherCoalesced>(
+      input_values, token, dim, shard_count, std::move(groups), pin_layout);
+  std::vector<XLATensorPtr> result;
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i)));
+  }
+  return {result, torch::lazy::Value(node, inputs.size())};
+}
+
 XLATensorPtr all_gather(const XLATensorPtr& input, int64_t dim,
                         int64_t shard_count,
                         std::vector<std::vector<int64_t>> groups,
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 30c1b0eca3c..330b185e206 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -38,6 +38,11 @@ std::pair<XLATensorPtr, torch::lazy::Value> all_to_all(
     int64_t split_dimension, int64_t concat_dimension, int64_t split_count,
     std::vector<std::vector<int64_t>> groups, bool pin_layout);
 
+std::pair<std::vector<XLATensorPtr>, torch::lazy::Value> all_gather(
+    const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
+    int64_t dim, int64_t shard_count, std::vector<std::vector<int64_t>> groups,
+    bool pin_layout);
+
 XLATensorPtr all_gather(const XLATensorPtr& input, int64_t dim,
                         int64_t shard_count,
                         std::vector<std::vector<int64_t>> groups,
diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
index f1b62d1700b..dae259b6fb7 100644
--- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
+++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
@@ -295,6 +295,7 @@ def __init__(
       sharding_world_size: Optional[int] = None,
       shard_param_on_dim_0: bool = False,
       pin_layout_in_collective_ops: bool = True,
+      coalesce_all_gather_ops: bool = False,
       auto_wrap_policy: Optional[Callable] = None,
       auto_wrapper_callable: Optional[Callable] = None,
       param_init_fn: Optional[Callable[[nn.Module], None]] = None,
@@ -397,6 +398,7 @@ def __init__(
     # When `_shard_param_on_dim_0` is True, we shard and all-gather model parameter tensors
     # only along their dim 0 without flattening the parameter
     self._shard_param_on_dim_0 = shard_param_on_dim_0 and not flatten_parameters
+    self.coalesce_all_gather_ops = coalesce_all_gather_ops
    # Set layout pinning to False in all_gather, all_reduce, and reduce_scatter so that they can work together
    # TODO (ronghanghu): change the default layout pinning to True after it's supported simultaneously
    # on all collective ops (see https://github.com/pytorch/xla/pull/3511 for details)
@@ -1402,6 +1404,8 @@ def _rebuild_full_params(self,
           [p for p in self.full_params if p._has_full_param],
           self.sharded_params, dependency_tensors)
 
+    if self.coalesce_all_gather_ops:
+      p_to_rebuild, shards_to_all_gather = [], []
     for p, p_shard in zip(self.full_params, self.sharded_params):
       if not p._has_full_param:
         p_shard_data = p_shard
@@ -1410,8 +1414,12 @@ def _rebuild_full_params(self,
         if p_shard_data.dtype != self.compute_dtype:
           p_shard_data = p_shard_data.to(self.compute_dtype)
         if self._shard_param_on_dim_0 or self._shard_size_multiple == 1:
-          p_padded = self.all_gather_op(
-              p_shard_data, groups=self.sharding_groups)
+          if self.coalesce_all_gather_ops:
+            p_to_rebuild.append((p, p_shard))
+            shards_to_all_gather.append(p_shard_data)
+          else:
+            p_padded = self.all_gather_op(
+                p_shard_data, groups=self.sharding_groups)
         else:
           # gather full parameter from shards
           # reshape sharded parameters to 2d tensors for efficient gathering on
@@ -1419,26 +1427,36 @@ def _rebuild_full_params(self,
           p_shard_2d = p_shard_data.view(-1, self._shard_size_multiple)
           p_padded = self.all_gather_op(
              p_shard_2d, groups=self.sharding_groups).flatten()
-        if apply_opt_barrier:
-          self.optimization_barrier_op([p_padded])
-        with torch.autograd._unsafe_preserve_version_counter(p):
-          if self._shard_param_on_dim_0:
-            if XLA_DISABLE_FUNCTIONALIZATION:
-              p.data = p_padded[:p_shard._orig_size[
-                  0]]  # Old behavior before Functionalization.
+        if not self.coalesce_all_gather_ops:
+          if apply_opt_barrier:
+            self.optimization_barrier_op([p_padded])
+          with torch.autograd._unsafe_preserve_version_counter(p):
+            if self._shard_param_on_dim_0:
+              if XLA_DISABLE_FUNCTIONALIZATION:
+                p.data = p_padded[:p_shard._orig_size[
+                    0]]  # Old behavior before Functionalization.
+ else: + torch_xla._XLAC._replace_xla_tensor( + p, p_padded[:p_shard._orig_size[0]]) else: - torch_xla._XLAC._replace_xla_tensor( - p, p_padded[:p_shard._orig_size[0]]) - else: - if XLA_DISABLE_FUNCTIONALIZATION: - p.data = p_padded[:p_shard._orig_size.numel()].view( - p_shard._orig_size) # Old behavior before Functionalization. - else: - torch_xla._XLAC._replace_xla_tensor( - p, p_padded[:p_shard._orig_size.numel()].view( - p_shard._orig_size)) + if XLA_DISABLE_FUNCTIONALIZATION: + p.data = p_padded[:p_shard._orig_size.numel()].view( + p_shard._orig_size + ) # Old behavior before Functionalization. + else: + torch_xla._XLAC._replace_xla_tensor( + p, p_padded[:p_shard._orig_size.numel()].view( + p_shard._orig_size)) p._has_full_param = True + if self.coalesce_all_gather_ops: + p_padded_list = self.all_gather_op( + shards_to_all_gather, groups=self.sharding_groups) + if apply_opt_barrier: + self.optimization_barrier_op(p_padded_list) + for (p, p_shard), p_padded in zip(p_to_rebuild, p_padded_list): + p.data = p_padded[:p_shard._orig_size[0]] + self.has_full_params = True @torch.no_grad() diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py index d448b09dd84..75b909cb848 100644 --- a/torch_xla/distributed/xla_backend.py +++ b/torch_xla/distributed/xla_backend.py @@ -77,6 +77,14 @@ def allgather(self, output_tensors_list, input_tensors, opts=None): return _ret_work([t for sublist in output_tensors_list for t in sublist]) + def allgather_coalesced(self, output_tensors_list, input_tensors, opts=None): + results = xm.all_gather(input_tensors, groups=self._mesh, pin_layout=False) + for i, result in enumerate(results): + for j, slice in enumerate(torch.split(result, input_tensors[i].shape[0])): + output_tensors_list[i][j].copy_(slice) + + return _ret_work([t for sublist in output_tensors_list for t in sublist]) + # Call site: # https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L1129 def broadcast(self, tensors, opts): @@ -125,9 +133,6 @@ def barrier(self, opts): def reduce(self, *args): raise NotImplementedError - def allgather_coalesced(self, *args): - raise NotImplementedError - def allreduce_coalesced(self, *args): raise NotImplementedError
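
Usage sketch (not part of the patch): the snippet below illustrates how the coalesced all-gather introduced here could be exercised from Python, both through `xm.all_gather` with a list of tensors and through the new FSDP `coalesce_all_gather_ops` flag. It assumes an initialized XLA device and process group; `MyModel` is a hypothetical placeholder module, and `pin_layout=False` mirrors the choice made in `allgather_coalesced` above.

# Illustrative usage only -- not part of the diff above.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

device = xm.xla_device()

# Coalesced all-gather: passing a list of tensors returns a list of gathered
# tensors, lowered as a single tuple-shaped all-gather in the HLO graph.
t1 = torch.arange(2, device=device)
t2 = torch.arange(5, device=device)
gathered = xm.all_gather([t1, t2], dim=0, pin_layout=False)

# FSDP can coalesce its parameter all-gathers with the flag added in this PR.
# model = FSDP(MyModel().to(device), coalesce_all_gather_ops=True)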