Skip to content

Commit

Permalink
Add reduce-scatter coalescing for FSDP/ZeRO1 (#5956)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffhataws authored and bhavya01 committed Apr 22, 2024
1 parent 3cd2516 commit ae438f3
Show file tree
Hide file tree
Showing 11 changed files with 366 additions and 19 deletions.
30 changes: 29 additions & 1 deletion test/test_mp_reduce_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def _mp_fn(index):
scale = 1 / world_size
scatter_dim = 1
shard_size = 2
input_list_size = 5

if xm.xla_device_hw(device) in ['TPU', 'CUDA']:
rand = torch.rand((32, shard_size * world_size, 32))
Expand All @@ -25,8 +26,35 @@ def _mp_fn(index):
expected = expected_world.cpu().index_select(scatter_dim, slice_idx)

assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter')

# Testing reduce-scatter with list input
rand_list = [
torch.rand((32, shard_size * world_size, 32))
for _ in range(input_list_size)
]
xrand_list = [rand.to(device) for rand in rand_list]

# TODO: fix the broken case with pin_layout=True
res_list = xm.reduce_scatter(
xm.REDUCE_SUM,
xrand_list,
scale,
scatter_dim,
world_size,
pin_layout=False)

for i, res in enumerate(res_list):
expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
xm.mark_step()

slice_idx = torch.tensor(
list(range(index * shard_size, (index + 1) * shard_size)))
expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter_list_input')

else:
print(
'Default device {} is not a TPU device'.format(device), file=sys.stderr)
Expand Down
21 changes: 21 additions & 0 deletions test/test_torch_distributed_xla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,27 @@ def test_reduce_scatter(self):
hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
hlo_matches(hlo, reduce_scatter_pattern)

@skipIf(xr.device_type() == 'CPU',
"UNIMPLEMENTED: ReduceScatter is not implemented on CPU.")
def test_reduce_scatter_coalesced(self):
device = xm.xla_device()
tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank()
input_tensors_list = [[tensor, tensor], [tensor2, tensor2]]
output_list = [torch.zeros_like(tensor), torch.zeros_like(tensor2)]
pg_xla = get_process_group_xla(rank=0, size=len(input_tensors_list[0]))
opts = dist.ReduceScatterOptions()
opts.reduceOp = dist.ReduceOp.SUM
reduce_scatter_pattern = (
r'%reduce\-scatter\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}, s64\[]\) '
r'reduce\-scatter\(s64\[4]\{0} %.+\.\d+, s64\[10]\{0} %.+\.\d+, '
r's64\[] %.+\.\d+\)')
pg_xla.reduce_scatter_coalesced(output_list, input_tensors_list, opts)
hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_list)
hlo_matches(hlo, reduce_scatter_pattern)
# purge all computations attached the device.
xm.mark_step()

@patch_world(0, 6)
def test_send(self):
device = xm.xla_device()
Expand Down
56 changes: 38 additions & 18 deletions torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,16 +740,18 @@ def reduce_scatter(reduce_type,
reduce_type (string): One of ``xm.REDUCE_SUM``, ``xm.REDUCE_MUL``,
``xm.REDUCE_AND``, ``xm.REDUCE_OR``, ``xm.REDUCE_MIN`` and
``xm.REDUCE_MAX``.
input: A single `torch.Tensor` all reduce + scatter op to.
input: (torch.Tensor or a list of torch.Tensor): The input. If it's a list, then
it will also be the output.
scale (float): A default scaling value to be applied after the reduce.
scatter_dim (int): Dimension number to which apply scatter operation.
shard_count (int): The number of ways to split up the scatter_dim in.
groups (list): A list of list, representing the replica groups for
the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
the `reduce_scatter()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
all the replicas in it.
output: Optional output tensor
output: Optional output tensor if `input` is a torch.Tensor, or a list of
torch.Tensor if `input` is a list of torch.Tensor.
pin_layout (bool, optional): whether to pin the layout for this communication op.
Layout pining can prevent potential data corruption when each process that
participate in the communication has slightly different program, but it might
Expand All @@ -762,21 +764,39 @@ def reduce_scatter(reduce_type,
the same as the input.
"""
token, devctx = _get_all_reduce_token()
if output != None:
# Call the out of place version of the reduce_scatter
new_token = torch_xla._XLAC._xla_reduce_scatter_out(reduce_type, output,
input, token, scale,
scatter_dim,
shard_count, groups or
[], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token, scale,
scatter_dim, shard_count,
groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
return result[0]

if isinstance(input, torch.Tensor):
if output != None:
# Call the out of place version of the reduce_scatter
new_token = torch_xla._XLAC._xla_reduce_scatter_out(
reduce_type, output, input, token, scale, scatter_dim, shard_count,
groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token,
scale, scatter_dim,
shard_count, groups or [],
pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
return result[0]

# Now the input should be a list of Tensors.
elif isinstance(input, list) and all(
isinstance(v, torch.Tensor) for v in input):
if output != None:
raise RuntimeError(
"For xm.reduce_scatter with list of tensors input, output != None is not yet supported."
)

result = torch_xla._XLAC._xla_reduce_scatter_coalesced(
reduce_type, output or [], input, token, scale, scatter_dim,
shard_count, groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
return result[:-1]
else:
raise TypeError("`input` needs to be a Tensor or a list of Tensors, but "
f"given {type(input)}.")


def add_step_closure(closure, args=(), run_async=False):
Expand Down
48 changes: 48 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,54 @@ ReduceScatterResult BuildReduceScatter(
return {reduce_result, token_handler.GetNewToken(reduce_result)};
}

ReduceScatterResultCoalesced BuildReduceScatterCoalesced(
AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups, bool pin_layout) {
std::vector<xla::ReplicaGroup> cc_groups = CreateReduceGroups(groups);
TokenHandler token_handler(token);
// TODO: We use pseudo-tokens ATM, which are real values. This need to be
// switched to use the real XLA Token once support has been added to XLA
// ReduceScatter().
ReduceContext cc_ctx = GetReduceContext(inputs);
std::vector<xla::XlaOp> result(inputs.size());
for (auto& type_ctx : cc_ctx.contexts) {
xla::XlaOp reduce_result;
type_ctx.second.ops[0] = token_handler.GetInput(
type_ctx.second.ops[0], &type_ctx.second.operand_shapes[0]);
if (pin_layout) {
reduce_result = xla::ReduceScatter(
xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
GetReduceComutation(reduce_type, type_ctx.first), scatter_dim,
shard_count, cc_groups, /*channel_id=*/absl::nullopt,
/*layout=*/
MakeReduceShape(type_ctx.second.operand_shapes).layout());
} else {
reduce_result = xla::ReduceScatter(
xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
GetReduceComutation(reduce_type, type_ctx.first), scatter_dim,
shard_count, cc_groups);
}
for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
size_t op_idx = type_ctx.second.indices[i];
xla::XlaOp gte;
if (ShapeHelper::ShapeOfXlaOp(reduce_result).rank() == 0) {
gte = xla::GetTupleElement(reduce_result, i);
} else {
gte = reduce_result;
}
if (scale != 1.0) {
xla::XlaOp scaling_value = XlaHelpers::ScalarValue<float>(
scale, type_ctx.second.operand_shapes[i].element_type(),
gte.builder());
gte = gte * scaling_value;
}
result[op_idx] = gte;
}
}
return {result, token_handler.GetNewToken(result[0])};
}

std::vector<torch::lazy::Value> GetOperandListWithToken(
c10::ArrayRef<torch::lazy::Value> operands,
const torch::lazy::Value& token) {
Expand Down
10 changes: 10 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ struct ReduceScatterResult {
xla::XlaOp token;
};

struct ReduceScatterResultCoalesced {
std::vector<xla::XlaOp> result;
xla::XlaOp token;
};

std::vector<xla::XlaOp> BuildAllReduce(
AllReduceType reduce_type, absl::Span<const xla::XlaOp> operands,
xla::XlaOp token, double scale,
Expand Down Expand Up @@ -91,6 +96,11 @@ ReduceScatterResult BuildReduceScatter(
int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups, bool pin_layout);

ReduceScatterResultCoalesced BuildReduceScatterCoalesced(
AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups, bool pin_layout);

std::vector<torch::lazy::Value> GetOperandListWithToken(
c10::ArrayRef<torch::lazy::Value> operands,
const torch::lazy::Value& token);
Expand Down
47 changes: 47 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,29 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> ReduceScatter(
std::make_shared<torch::lazy::Value>(new_token));
}

std::pair<std::vector<at::Tensor>, std::shared_ptr<torch::lazy::Value>>
ReduceScatterCoalesced(const std::string& reduce_type,
const std::vector<at::Tensor>& outputs,
const std::vector<at::Tensor>& inputs,
const std::shared_ptr<torch::lazy::Value>& token,
double scale, int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& replica_groups,
bool pin_layout) {
std::vector<XLATensorPtr> xtensors_out =
GetXlaTensors(outputs, /*want_all=*/true);
std::vector<XLATensorPtr> xtensors = GetXlaTensors(inputs, /*want_all=*/true);
std::vector<XLATensorPtr> result;
torch::lazy::Value new_token;
std::tie(result, new_token) = tensor_methods::reduce_scatter_coalesced(
xtensors_out, xtensors, *token, GetReduceType(reduce_type), scale,
scatter_dim, shard_count, replica_groups, pin_layout);
std::vector<at::Tensor> aten_result;
for (auto& xt : result) {
aten_result.emplace_back(bridge::AtenFromXlaTensor(std::move(xt)));
}
return {aten_result, std::make_shared<torch::lazy::Value>(new_token)};
}

std::shared_ptr<torch::lazy::Value> ReduceScatterOut(
const std::string& reduce_type, at::Tensor& output, const at::Tensor& input,
const std::shared_ptr<torch::lazy::Value>& token, double scale,
Expand Down Expand Up @@ -1320,6 +1343,30 @@ void InitXlaModuleBindings(py::module m) {
result_tuple[1] = new_token;
return result_tuple;
});
m.def("_xla_reduce_scatter_coalesced",
[](const std::string& reduce_type, std::vector<at::Tensor>& outputs,
const std::vector<at::Tensor>& inputs,
const std::shared_ptr<torch::lazy::Value>& token, double scale,
int64_t scatter_dim, int64_t shard_count, const py::list& groups,
bool pin_layout) {
std::vector<std::vector<int64_t>> replica_groups =
CreateReduceGroups(groups);
std::vector<at::Tensor> result;
std::shared_ptr<torch::lazy::Value> new_token;
{
NoGilSection nogil;
std::tie(result, new_token) = ReduceScatterCoalesced(
reduce_type, outputs, inputs, token, scale, scatter_dim,
shard_count, replica_groups, pin_layout);
}
auto result_list = py::list(result.size() + 1);
for (int i = 0; i < result.size(); ++i) {
result_list[i] = torch::autograd::make_variable(
result[i], /*requires_grad=*/result[i].requires_grad());
}
result_list[result.size()] = new_token;
return result_list;
});
m.def("_xla_reduce_scatter_out",
[](const std::string& reduce_type, at::Tensor& output,
const at::Tensor& input,
Expand Down
78 changes: 78 additions & 0 deletions torch_xla/csrc/ops/reduce_scatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,26 @@ xla::Shape NodeOutputShape(AllReduceType reduce_type,
return InferOutputShape({GetXlaShape(input), GetXlaShape(token)}, shape_fn);
}

xla::Shape NodeOutputShapeCoalesced(
AllReduceType reduce_type, c10::ArrayRef<torch::lazy::Value> inputs,
const torch::lazy::Value& token, double scale, int64_t scatter_dim,
int64_t shard_count, const std::vector<std::vector<int64_t>>& groups,
bool pin_layout) {
auto shape_fn = [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
ReduceScatterResultCoalesced result = BuildReduceScatterCoalesced(
reduce_type, operands.subspan(0, operands.size() - 1), operands.back(),
scale, scatter_dim, shard_count, groups, pin_layout);
result.result.emplace_back(result.token);
return xla::Tuple(operands[0].builder(), result.result);
};
std::vector<xla::Shape> input_shapes;
for (const auto& input : inputs) {
input_shapes.emplace_back(GetXlaShape(input));
}
input_shapes.emplace_back(GetXlaShape(token));
return InferOutputShape(input_shapes, shape_fn);
}

} // namespace

ReduceScatter::ReduceScatter(AllReduceType reduce_type,
Expand All @@ -52,12 +72,41 @@ ReduceScatter::ReduceScatter(AllReduceType reduce_type,
groups_(std::move(groups)),
pin_layout_(pin_layout) {}

ReduceScatterCoalesced::ReduceScatterCoalesced(
AllReduceType reduce_type, c10::ArrayRef<torch::lazy::Value> inputs,
const torch::lazy::Value& token, double scale, int64_t scatter_dim,
int64_t shard_count, std::vector<std::vector<int64_t>> groups,
bool pin_layout)
: XlaNode(xla_reduce_scatter, GetOperandListWithToken(inputs, token),
[&]() {
return NodeOutputShapeCoalesced(reduce_type, inputs, token,
scale, scatter_dim, shard_count,
groups, pin_layout);
},
/*num_outputs=*/inputs.size() + 1,
torch::lazy::MHash(torch::lazy::GetEnumValue(reduce_type), scale,
scatter_dim, shard_count, groups, pin_layout)),
reduce_type_(reduce_type),
scale_(scale),
scatter_dim_(scatter_dim),
shard_count_(shard_count),
groups_(std::move(groups)),
pin_layout_(pin_layout) {}

torch::lazy::NodePtr ReduceScatter::Clone(torch::lazy::OpList operands) const {
return torch::lazy::MakeNode<ReduceScatter>(
reduce_type_, operands.at(0), operands.at(1), scale_, scatter_dim_,
shard_count_, groups_, pin_layout_);
}

torch::lazy::NodePtr ReduceScatterCoalesced::Clone(
torch::lazy::OpList operands) const {
std::vector<torch::lazy::Value> inputs(operands.begin(), operands.end() - 1);
return torch::lazy::MakeNode<ReduceScatterCoalesced>(
reduce_type_, inputs, operands.back(), scale_, scatter_dim_, shard_count_,
groups_, pin_layout_);
}

XlaOpVector ReduceScatter::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp token = loctx->GetOutputOp(operand(1));
Expand All @@ -67,6 +116,21 @@ XlaOpVector ReduceScatter::Lower(LoweringContext* loctx) const {
return ReturnOps({result.result, result.token}, loctx);
}

XlaOpVector ReduceScatterCoalesced::Lower(LoweringContext* loctx) const {
auto& operand_list = operands();
std::vector<xla::XlaOp> inputs;
inputs.reserve(operand_list.size());
for (size_t i = 0; i + 1 < operand_list.size(); ++i) {
inputs.push_back(loctx->GetOutputOp(operand_list[i]));
}
xla::XlaOp token = loctx->GetOutputOp(operand_list.back());
ReduceScatterResultCoalesced result = BuildReduceScatterCoalesced(
reduce_type_, inputs, token, scale_, scatter_dim_, shard_count_, groups_,
pin_layout_);
result.result.push_back(result.token);
return ReturnOps(result.result, loctx);
}

std::string ReduceScatter::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString()
Expand All @@ -82,4 +146,18 @@ std::string ReduceScatter::ToString() const {
return ss.str();
}

std::string ReduceScatterCoalesced::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString()
<< ", reduce_type=" << torch::lazy::GetEnumValue(reduce_type_)
<< ", scale=" << scale_ << ", scatter_dim=" << scatter_dim_
<< ", shard_count=" << shard_count_ << ", pin_layout=" << pin_layout_
<< ", groups=(";
for (size_t i = 0; i < groups_.size(); ++i) {
ss << (i == 0 ? "(" : ",(");
ss << absl::StrJoin(groups_[i], ", ") << ")";
}
ss << ")";
return ss.str();
}
} // namespace torch_xla
Loading

0 comments on commit ae438f3

Please sign in to comment.