Add all-gather coalescing for FSDP/ZeRO1 (#5950)
* Add all-gather and reduce-scatter coalescing support for FSDP.

Also allow using reduce-scatter's scale param in FSDP.
(revived #4145; see the usage sketch after this list)

* clang-format-7 and python lint fixes

* Fix "SyntaxError: 'return' outside function" error

* Code/test fixes to get run_tests.sh to run on CPU

* Fix all-gather to be compatible with the OpenXLA all-gather tuple change without token

* Fix reduce-scatter-coalesce to be compatible with the OpenXLA reduce-scatter tuple change without token

* Separate out the reduce-scatter-coalesce changes into a separate PR

* Some cleanups

* Add separate BuildAllGatherCoalesced builder and AllGatherCoalesced class

* Use token_handler.GetInput to capture token

* Clean up

* Clean up

* Switch to GetOperandListWithToken naming for func GetOperandList
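
A minimal usage sketch of the new coalesced path (assuming an initialized XLA
device and runtime; tensor names are illustrative, not from this diff):

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    t1 = torch.arange(2, device=device)
    t2 = torch.arange(5, device=device)

    gathered_t1 = xm.all_gather(t1, dim=0)  # existing tensor path, unchanged
    # New: a list input emits a single coalesced all-gather and returns one
    # gathered tensor per input.
    gathered_t1, gathered_t2 = xm.all_gather([t1, t2], dim=0)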
jeffhataws authored Dec 2, 2023
1 parent 2c4983d commit 1271964
Showing 12 changed files with 325 additions and 46 deletions.
31 changes: 29 additions & 2 deletions test/test_torch_distributed_xla_backend.py
@@ -2,7 +2,7 @@
import functools
import os
import re
from unittest import mock
from unittest import mock, skipIf

from absl.testing import absltest, parameterized
import torch
@@ -12,6 +12,15 @@
import torch_xla.distributed.xla_backend
from torch_xla import runtime as xr

from datetime import timedelta


def get_process_group_xla(rank, size):
pg_xla_creator = dist.Backend._plugins['XLA'].creator_fn
pg_xla = pg_xla_creator(
prefix_store=None, rank=rank, size=size, timeout=timedelta(minutes=1))
return pg_xla


def hlo_matches(hlo, expected_pattern, match_times=1):
matches = re.findall(expected_pattern, hlo)
@@ -87,6 +96,25 @@ def test_allgather(self):
hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
hlo_matches(hlo, all_gather_pattern)

@patch_world(rank=3, size=8)
def test_allgather_coalesced(self):
device = xm.xla_device()
tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank()
pg_xla = get_process_group_xla(rank=3, size=8)
output_tensors = [torch.zeros_like(tensor)] * 8
output_tensors2 = [torch.zeros_like(tensor2)] * 8
# because we set os.environ[xenv.WORLD_SIZE] = '1', here the outputs'
    # shapes will be the same as the inputs' shapes.
# Ex: %all-gather.26 = (s64[2]{0}, s64[5]{0}) all-gather(s64[2]{0} %get-tuple-element.24, s64[5]{0} %get-tuple-element.25), replica_groups={}, dimensions={0}
all_gather_pattern = (
r'%all-gather\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}\) '
r'all-gather\(s64\[2]\{0} %.+\.\d+, s64\[5]\{0} %.+\.\d+\)')
pg_xla.allgather_coalesced([output_tensors, output_tensors2],
[tensor, tensor2])
hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
hlo_matches(hlo, all_gather_pattern)

def test_broadcast(self):
device = xm.xla_device()
tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
@@ -291,7 +319,6 @@ def test_barrier(self):

@parameterized.parameters(
'reduce',
'allgather_coalesced',
'allreduce_coalesced',
'alltoall',
'alltoall_base',
34 changes: 24 additions & 10 deletions torch_xla/core/xla_model.py
@@ -570,17 +570,31 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
shard_count = xrt_world_size()

token, devctx = _get_all_reduce_token()
if output != None:
# Call the out of place version of the all_gather
new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim,
shard_count, groups or [],
pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_all_gather(value, dim, shard_count, groups or
[], pin_layout)
return result
if isinstance(value, torch.Tensor):
if output != None:
# Call the out of place version of the all_gather
new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim,
shard_count, groups or [],
pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_all_gather(value, dim, shard_count, groups or
[], pin_layout)
return result

# Now the input should be a list of Tensors.
elif isinstance(value, list) and all(
isinstance(v, torch.Tensor) for v in value):
result = torch_xla._XLAC._xla_all_gather_coalesced(value, token, dim,
shard_count, groups or
[], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
return result[:-1]
else:
raise TypeError("`value` needs to be a Tensor or a list of Tensors, but "
f"given {type(value)}.")


def all_to_all(value,
48 changes: 48 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
@@ -236,6 +236,45 @@ AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim,
return {all_gather_result, token_handler.GetNewToken(all_gather_result)};
}

AllGatherResultCoalesced BuildAllGatherCoalesced(
absl::Span<const xla::XlaOp> inputs, xla::XlaOp token, int64_t dim,
int64_t shard_count, const std::vector<std::vector<int64_t>>& groups,
bool pin_layout) {
std::vector<xla::ReplicaGroup> cc_groups = CreateReduceGroups(groups);
TokenHandler token_handler(token);
  // TODO: We use pseudo-tokens ATM, which are real values. This needs to be
// switched to use the real XLA Token once support has been added to XLA
// AllGather().
ReduceContext cc_ctx = GetReduceContext(inputs);
std::vector<xla::XlaOp> result(inputs.size());

for (auto& type_ctx : cc_ctx.contexts) {
xla::XlaOp all_gather_result;
type_ctx.second.ops[0] = token_handler.GetInput(
type_ctx.second.ops[0], &type_ctx.second.operand_shapes[0]);
if (pin_layout) {
all_gather_result = xla::AllGather(
xla::Tuple(inputs[0].builder(), type_ctx.second.ops), dim,
shard_count, cc_groups, /*channel_id=*/absl::nullopt,
/*layout=*/
MakeReduceShape(type_ctx.second.operand_shapes).layout());
} else {
all_gather_result =
xla::AllGather(xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
dim, shard_count, cc_groups);
}
if (ShapeHelper::ShapeOfXlaOp(all_gather_result).rank() == 0) {
for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
size_t op_idx = type_ctx.second.indices[i];
result[op_idx] = xla::GetTupleElement(all_gather_result, i);
}
} else {
result[0] = all_gather_result;
}
}
return {result, token_handler.GetNewToken(result[0])};
}

CollectivePermuteResult BuildCollectivePermute(
xla::XlaOp input, xla::XlaOp token,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {
@@ -309,6 +348,15 @@ ReduceScatterResult BuildReduceScatter(
return {reduce_result, token_handler.GetNewToken(reduce_result)};
}

std::vector<torch::lazy::Value> GetOperandListWithToken(
c10::ArrayRef<torch::lazy::Value> operands,
const torch::lazy::Value& token) {
std::vector<torch::lazy::Value> operand_list(operands.begin(),
operands.end());
operand_list.push_back(token);
return operand_list;
}

const torch::lazy::Value& GetAllReduceToken(
const torch::lazy::BackendDevice& device) {
auto it = g_all_reduce_tokens.find(device.ordinal());
17 changes: 16 additions & 1 deletion torch_xla/csrc/cross_replica_reduces.h
@@ -6,6 +6,7 @@
#include "absl/types/span.h"
#include "torch/csrc/lazy/core/ir.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/ir.h"
#include "xla/client/xla_builder.h"

namespace torch_xla {
@@ -29,6 +30,11 @@ struct AllGatherResult {
xla::XlaOp token;
};

struct AllGatherResultCoalesced {
std::vector<xla::XlaOp> result;
xla::XlaOp token;
};

struct CollectivePermuteResult {
xla::XlaOp result;
xla::XlaOp token;
@@ -65,6 +71,11 @@ AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim,
const std::vector<std::vector<int64_t>>& groups,
bool pin_layout);

AllGatherResultCoalesced BuildAllGatherCoalesced(
absl::Span<const xla::XlaOp> inputs, xla::XlaOp token, int64_t dim,
int64_t shard_count, const std::vector<std::vector<int64_t>>& groups,
bool pin_layout);

CollectivePermuteResult BuildCollectivePermute(
xla::XlaOp input, xla::XlaOp token,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs);
@@ -80,6 +91,10 @@ ReduceScatterResult BuildReduceScatter(
int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups, bool pin_layout);

std::vector<torch::lazy::Value> GetOperandListWithToken(
c10::ArrayRef<torch::lazy::Value> operands,
const torch::lazy::Value& token);

const torch::lazy::Value& GetAllReduceToken(
const torch::lazy::BackendDevice& device);
void SetAllReduceToken(const torch::lazy::BackendDevice& device,
@@ -89,4 +104,4 @@ AllReduceType GetReduceType(c10::string_view reduce_type);

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_CROSS_REPLICA_REDUCES_H_
#endif // XLA_TORCH_XLA_CSRC_CROSS_REPLICA_REDUCES_H_
40 changes: 40 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -258,6 +258,25 @@ at::Tensor AllGather(const at::Tensor& input, int64_t dim, int64_t shard_count,
return bridge::AtenFromXlaTensor(std::move(result));
}

std::pair<std::vector<at::Tensor>, std::shared_ptr<torch::lazy::Value>>
AllGatherCoalesced(const std::vector<at::Tensor>& tensors,
const std::shared_ptr<torch::lazy::Value>& token,
int64_t dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& replica_groups,
bool pin_layout) {
std::vector<XLATensorPtr> xtensors =
GetXlaTensors(tensors, /*want_all=*/true);
std::vector<XLATensorPtr> result;
torch::lazy::Value new_token;
std::tie(result, new_token) = tensor_methods::all_gather(
xtensors, *token, dim, shard_count, replica_groups, pin_layout);
std::vector<at::Tensor> aten_result;
for (auto& xt : result) {
aten_result.emplace_back(bridge::AtenFromXlaTensor(std::move(xt)));
}
return {aten_result, std::make_shared<torch::lazy::Value>(new_token)};
}

std::shared_ptr<torch::lazy::Value> AllGatherOut(
at::Tensor& output, const at::Tensor& input,
const std::shared_ptr<torch::lazy::Value>& token, int64_t dim,
@@ -1209,6 +1228,27 @@ void InitXlaModuleBindings(py::module m) {
}
return new_token;
});
m.def("_xla_all_gather_coalesced",
[](const std::vector<at::Tensor>& tensors,
const std::shared_ptr<torch::lazy::Value>& token, int64_t dim,
int64_t shard_count, const py::list& groups, bool pin_layout) {
std::vector<std::vector<int64_t>> replica_groups =
CreateReduceGroups(groups);
std::vector<at::Tensor> results;
std::shared_ptr<torch::lazy::Value> new_token;
{
NoGilSection nogil;
std::tie(results, new_token) = AllGatherCoalesced(
tensors, token, dim, shard_count, replica_groups, pin_layout);
}
auto result_list = py::list(results.size() + 1);
for (int i = 0; i < results.size(); ++i) {
result_list[i] = torch::autograd::make_variable(
results[i], /*requires_grad=*/results[i].requires_grad());
}
result_list[results.size()] = new_token;
return result_list;
});
m.def("_xla_collective_permute",
[](const at::Tensor& input,
const std::shared_ptr<torch::lazy::Value>& token,
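For reference, _xla_all_gather_coalesced returns the gathered tensors with the
fresh token appended as the last list element; the caller strips and stores the
token, as in the xla_model.py diff above. A minimal sketch of that call pattern
(the wrapper name and its defaults are illustrative, not part of this change):

    import torch_xla
    import torch_xla.core.xla_model as xm

    def all_gather_coalesced(tensors, dim=0, groups=None, pin_layout=True):
      # Thread the device's pseudo all-reduce token through the collective.
      token, devctx = xm._get_all_reduce_token()
      results = torch_xla._XLAC._xla_all_gather_coalesced(
          tensors, token, dim, xm.xrt_world_size(), groups or [], pin_layout)
      # Last element is the new token; persist it for subsequent collectives.
      torch_xla._XLAC._set_all_reduce_token(devctx.device, results[-1])
      return results[:-1]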
70 changes: 70 additions & 0 deletions torch_xla/csrc/ops/all_gather.cpp
@@ -23,6 +23,25 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& input,
return InferOutputShape({GetXlaShape(input), GetXlaShape(token)}, shape_fn);
}

xla::Shape NodeOutputShapeCoalesced(
c10::ArrayRef<torch::lazy::Value> inputs, const torch::lazy::Value& token,
int64_t dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups, bool pin_layout) {
auto shape_fn = [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
AllGatherResultCoalesced result = BuildAllGatherCoalesced(
operands.subspan(0, operands.size() - 1), operands.back(), dim,
shard_count, groups, pin_layout);
result.result.emplace_back(result.token);
return xla::Tuple(operands[0].builder(), result.result);
};
std::vector<xla::Shape> input_shapes;
for (const auto& input : inputs) {
input_shapes.emplace_back(GetXlaShape(input));
}
input_shapes.emplace_back(GetXlaShape(token));
return InferOutputShape(input_shapes, shape_fn);
}

} // namespace

AllGather::AllGather(const torch::lazy::Value& input,
@@ -41,11 +60,35 @@ AllGather::AllGather(const torch::lazy::Value& input,
groups_(std::move(groups)),
pin_layout_(pin_layout) {}

AllGatherCoalesced::AllGatherCoalesced(c10::ArrayRef<torch::lazy::Value> inputs,
const torch::lazy::Value& token,
int64_t dim, int64_t shard_count,
std::vector<std::vector<int64_t>> groups,
bool pin_layout)
: XlaNode(xla_all_gather, GetOperandListWithToken(inputs, token),
[&]() {
return NodeOutputShapeCoalesced(inputs, token, dim, shard_count,
groups, pin_layout);
},
/*num_outputs=*/inputs.size() + 1,
torch::lazy::MHash(dim, shard_count, groups, pin_layout)),
dim_(dim),
shard_count_(shard_count),
groups_(std::move(groups)),
pin_layout_(pin_layout) {}

torch::lazy::NodePtr AllGather::Clone(torch::lazy::OpList operands) const {
return torch::lazy::MakeNode<AllGather>(operands.at(0), operands.at(1), dim_,
shard_count_, groups_, pin_layout_);
}

torch::lazy::NodePtr AllGatherCoalesced::Clone(
torch::lazy::OpList operands) const {
std::vector<torch::lazy::Value> inputs(operands.begin(), operands.end() - 1);
return torch::lazy::MakeNode<AllGatherCoalesced>(
inputs, operands.back(), dim_, shard_count_, groups_, pin_layout_);
}

XlaOpVector AllGather::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp token = loctx->GetOutputOp(operand(1));
@@ -54,6 +97,20 @@ XlaOpVector AllGather::Lower(LoweringContext* loctx) const {
return ReturnOps({result.result, result.token}, loctx);
}

XlaOpVector AllGatherCoalesced::Lower(LoweringContext* loctx) const {
auto& operand_list = operands();
std::vector<xla::XlaOp> inputs;
inputs.reserve(operand_list.size());
for (size_t i = 0; i + 1 < operand_list.size(); ++i) {
inputs.push_back(loctx->GetOutputOp(operand_list[i]));
}
xla::XlaOp token = loctx->GetOutputOp(operand_list.back());
AllGatherResultCoalesced result = BuildAllGatherCoalesced(
inputs, token, dim_, shard_count_, groups_, pin_layout_);
result.result.push_back(result.token);
return ReturnOps(result.result, loctx);
}

std::string AllGather::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString() << ", dim=" << dim_
@@ -67,4 +124,17 @@ std::string AllGather::ToString() const {
return ss.str();
}

std::string AllGatherCoalesced::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString() << ", dim=" << dim_
<< ", shard_count=" << shard_count_ << ", pin_layout=" << pin_layout_
<< ", groups=(";
for (size_t i = 0; i < groups_.size(); ++i) {
ss << (i == 0 ? "(" : ",(");
ss << absl::StrJoin(groups_[i], ", ") << ")";
}
ss << ")";
return ss.str();
}

} // namespace torch_xla
(Diffs for the remaining 6 changed files did not load and are not shown.)
