Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reduce-scatter coalescing for FSDP/ZeRO1 #5956

Merged
merged 4 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion test/test_mp_reduce_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def _mp_fn(index):
scale = 1 / world_size
scatter_dim = 1
shard_size = 2
input_list_size = 5

if xm.xla_device_hw(device) in ['TPU', 'CUDA']:
rand = torch.rand((32, shard_size * world_size, 32))
Expand All @@ -25,8 +26,35 @@ def _mp_fn(index):
expected = expected_world.cpu().index_select(scatter_dim, slice_idx)

assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter')

# Testing reduce-scatter with list input
rand_list = [
torch.rand((32, shard_size * world_size, 32))
for _ in range(input_list_size)
]
xrand_list = [rand.to(device) for rand in rand_list]

# TODO: fix the broken case with pin_layout=True
res_list = xm.reduce_scatter(
xm.REDUCE_SUM,
xrand_list,
scale,
scatter_dim,
world_size,
pin_layout=False)

for i, res in enumerate(res_list):
expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
xm.mark_step()

slice_idx = torch.tensor(
list(range(index * shard_size, (index + 1) * shard_size)))
expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter_list_input')

else:
print(
'Default device {} is not a TPU device'.format(device), file=sys.stderr)
Expand Down
21 changes: 21 additions & 0 deletions test/test_torch_distributed_xla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,27 @@ def test_reduce_scatter(self):
hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
hlo_matches(hlo, reduce_scatter_pattern)

@skipIf(xr.device_type() == 'CPU',
"UNIMPLEMENTED: ReduceScatter is not implemented on CPU.")
def test_reduce_scatter_coalesced(self):
device = xm.xla_device()
tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank()
input_tensors_list = [[tensor, tensor], [tensor2, tensor2]]
output_list = [torch.zeros_like(tensor), torch.zeros_like(tensor2)]
pg_xla = get_process_group_xla(rank=0, size=len(input_tensors_list[0]))
opts = dist.ReduceScatterOptions()
opts.reduceOp = dist.ReduceOp.SUM
reduce_scatter_pattern = (
r'%reduce\-scatter\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}, s64\[]\) '
r'reduce\-scatter\(s64\[4]\{0} %.+\.\d+, s64\[10]\{0} %.+\.\d+, '
r's64\[] %.+\.\d+\)')
pg_xla.reduce_scatter_coalesced(output_list, input_tensors_list, opts)
hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_list)
hlo_matches(hlo, reduce_scatter_pattern)
# purge all computations attached the device.
xm.mark_step()

@patch_world(0, 6)
def test_send(self):
device = xm.xla_device()
Expand Down
60 changes: 42 additions & 18 deletions torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,16 +740,18 @@ def reduce_scatter(reduce_type,
reduce_type (string): One of ``xm.REDUCE_SUM``, ``xm.REDUCE_MUL``,
``xm.REDUCE_AND``, ``xm.REDUCE_OR``, ``xm.REDUCE_MIN`` and
``xm.REDUCE_MAX``.
input: A single `torch.Tensor` all reduce + scatter op to.
input: (torch.Tensor or a list of torch.Tensor): The input. If it's a list, then
it will also be the output.
scale (float): A default scaling value to be applied after the reduce.
scatter_dim (int): Dimension number to which apply scatter operation.
shard_count (int): The number of ways to split up the scatter_dim in.
groups (list): A list of list, representing the replica groups for
the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
the `reduce_scatter()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
all the replicas in it.
output: Optional output tensor
output: Optional output tensor if `input` is a torch.Tensor, or a list of
torch.Tensor if `input` is a list of torch.Tensor.
pin_layout (bool, optional): whether to pin the layout for this communication op.
Layout pining can prevent potential data corruption when each process that
participate in the communication has slightly different program, but it might
Expand All @@ -762,21 +764,43 @@ def reduce_scatter(reduce_type,
the same as the input.
"""
token, devctx = _get_all_reduce_token()
if output != None:
# Call the out of place version of the reduce_scatter
new_token = torch_xla._XLAC._xla_reduce_scatter_out(reduce_type, output,
input, token, scale,
scatter_dim,
shard_count, groups or
[], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token, scale,
scatter_dim, shard_count,
groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
return result[0]

if isinstance(input, torch.Tensor):
if output != None:
# Call the out of place version of the reduce_scatter
new_token = torch_xla._XLAC._xla_reduce_scatter_out(
reduce_type, output, input, token, scale, scatter_dim, shard_count,
groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token,
scale, scatter_dim,
shard_count, groups or [],
pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
return result[0]

# Now the input should be a list of Tensors.
elif isinstance(input, list) and all(
isinstance(v, torch.Tensor) for v in input):
if output != None:
if not isinstance(output, list) or any(
not isinstance(v, torch.Tensor) for v in output):
raise TypeError(
f"`output` needs to be a list of Tensors, but given {type(output)}."
)
if len(output) != len(input):
raise ValueError("`output` length doesn't match `input` length: "
f"{len(output)} vs {len(input)}.")
jeffhataws marked this conversation as resolved.
Show resolved Hide resolved
result = torch_xla._XLAC._xla_reduce_scatter_coalesced(
reduce_type, output or [], input, token, scale, scatter_dim,
shard_count, groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
return result[:-1]
else:
raise TypeError("`input` needs to be a Tensor or a list of Tensors, but "
f"given {type(input)}.")
jeffhataws marked this conversation as resolved.
Show resolved Hide resolved


def add_step_closure(closure, args=(), run_async=False):
Expand Down
48 changes: 48 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,54 @@ ReduceScatterResult BuildReduceScatter(
return {reduce_result, token_handler.GetNewToken(reduce_result)};
}

ReduceScatterResultCoalesced BuildReduceScatterCoalesced(
AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
jeffhataws marked this conversation as resolved.
Show resolved Hide resolved
xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups, bool pin_layout) {
std::vector<xla::ReplicaGroup> cc_groups = CreateReduceGroups(groups);
TokenHandler token_handler(token);
// TODO: We use pseudo-tokens ATM, which are real values. This need to be
// switched to use the real XLA Token once support has been added to XLA
// ReduceScatter().
ReduceContext cc_ctx = GetReduceContext(inputs);
std::vector<xla::XlaOp> result(inputs.size());
for (auto& type_ctx : cc_ctx.contexts) {
xla::XlaOp reduce_result;
type_ctx.second.ops[0] = token_handler.GetInput(
type_ctx.second.ops[0], &type_ctx.second.operand_shapes[0]);
if (pin_layout) {
reduce_result = xla::ReduceScatter(
xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
GetReduceComutation(reduce_type, type_ctx.first), scatter_dim,
shard_count, cc_groups, /*channel_id=*/absl::nullopt,
/*layout=*/
MakeReduceShape(type_ctx.second.operand_shapes).layout());
} else {
reduce_result = xla::ReduceScatter(
xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
GetReduceComutation(reduce_type, type_ctx.first), scatter_dim,
shard_count, cc_groups);
}
for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
size_t op_idx = type_ctx.second.indices[i];
xla::XlaOp gte;
if (ShapeHelper::ShapeOfXlaOp(reduce_result).rank() == 0) {
gte = xla::GetTupleElement(reduce_result, i);
} else {
gte = reduce_result;
}
if (scale != 1.0) {
xla::XlaOp scaling_value = XlaHelpers::ScalarValue<float>(
scale, type_ctx.second.operand_shapes[i].element_type(),
gte.builder());
gte = gte * scaling_value;
}
result[op_idx] = gte;
}
}
JackCaoG marked this conversation as resolved.
Show resolved Hide resolved
return {result, token_handler.GetNewToken(result[0])};
}

std::vector<torch::lazy::Value> GetOperandListWithToken(
c10::ArrayRef<torch::lazy::Value> operands,
const torch::lazy::Value& token) {
Expand Down
10 changes: 10 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ struct ReduceScatterResult {
xla::XlaOp token;
};

struct ReduceScatterResultCoalesced {
std::vector<xla::XlaOp> result;
xla::XlaOp token;
};

std::vector<xla::XlaOp> BuildAllReduce(
AllReduceType reduce_type, absl::Span<const xla::XlaOp> operands,
xla::XlaOp token, double scale,
Expand Down Expand Up @@ -91,6 +96,11 @@ ReduceScatterResult BuildReduceScatter(
int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups, bool pin_layout);

ReduceScatterResultCoalesced BuildReduceScatterCoalesced(
AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups, bool pin_layout);

std::vector<torch::lazy::Value> GetOperandListWithToken(
c10::ArrayRef<torch::lazy::Value> operands,
const torch::lazy::Value& token);
Expand Down
47 changes: 47 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,29 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> ReduceScatter(
std::make_shared<torch::lazy::Value>(new_token));
}

std::pair<std::vector<at::Tensor>, std::shared_ptr<torch::lazy::Value>>
ReduceScatterCoalesced(const std::string& reduce_type,
const std::vector<at::Tensor>& outputs,
const std::vector<at::Tensor>& inputs,
const std::shared_ptr<torch::lazy::Value>& token,
double scale, int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& replica_groups,
bool pin_layout) {
std::vector<XLATensorPtr> xtensors_out =
GetXlaTensors(outputs, /*want_all=*/true);
Comment on lines +247 to +248
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you want to handle the outputs cases, it is better to define a ReduceScatterCoalescedOut instead of merging the logic into a single function.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it ok if I work on this in another change?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's OK, but can we explictly error out in python api side if output is not None? I'd rather throw a explictly error on cases that we don't test/support yet.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can work on the reduce_scatter_out in a separate pr

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Jack. Here's the change to error out if output!=None: 9baadf5

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#6058 is for adding out-of-place version for reduce-scatter, along with #6059 for out-of-place all-gather.

std::vector<XLATensorPtr> xtensors = GetXlaTensors(inputs, /*want_all=*/true);
std::vector<XLATensorPtr> result;
torch::lazy::Value new_token;
std::tie(result, new_token) = tensor_methods::reduce_scatter_coalesced(
xtensors_out, xtensors, *token, GetReduceType(reduce_type), scale,
scatter_dim, shard_count, replica_groups, pin_layout);
std::vector<at::Tensor> aten_result;
for (auto& xt : result) {
aten_result.emplace_back(bridge::AtenFromXlaTensor(std::move(xt)));
}
return {aten_result, std::make_shared<torch::lazy::Value>(new_token)};
}

std::shared_ptr<torch::lazy::Value> ReduceScatterOut(
const std::string& reduce_type, at::Tensor& output, const at::Tensor& input,
const std::shared_ptr<torch::lazy::Value>& token, double scale,
Expand Down Expand Up @@ -1320,6 +1343,30 @@ void InitXlaModuleBindings(py::module m) {
result_tuple[1] = new_token;
return result_tuple;
});
m.def("_xla_reduce_scatter_coalesced",
[](const std::string& reduce_type, std::vector<at::Tensor>& outputs,
const std::vector<at::Tensor>& inputs,
const std::shared_ptr<torch::lazy::Value>& token, double scale,
int64_t scatter_dim, int64_t shard_count, const py::list& groups,
bool pin_layout) {
std::vector<std::vector<int64_t>> replica_groups =
CreateReduceGroups(groups);
std::vector<at::Tensor> result;
std::shared_ptr<torch::lazy::Value> new_token;
{
NoGilSection nogil;
std::tie(result, new_token) = ReduceScatterCoalesced(
reduce_type, outputs, inputs, token, scale, scatter_dim,
shard_count, replica_groups, pin_layout);
}
auto result_list = py::list(result.size() + 1);
for (int i = 0; i < result.size(); ++i) {
result_list[i] = torch::autograd::make_variable(
result[i], /*requires_grad=*/result[i].requires_grad());
}
result_list[result.size()] = new_token;
return result_list;
});
m.def("_xla_reduce_scatter_out",
[](const std::string& reduce_type, at::Tensor& output,
const at::Tensor& input,
Expand Down
Loading