Skip to content

Commit

Permalink
Add out-of-place reduce-scatter coalescing (#6058)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffhataws authored and golechwierowicz committed Jan 12, 2024
1 parent 873c030 commit 7fdb15f
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 45 deletions.
28 changes: 28 additions & 0 deletions test/test_mp_reduce_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,34 @@ def _mp_fn(index):

xm.rendezvous('test_reduce_scatter_list_input')

# Testing reduce-scatter with list input and output
output_list = [
torch.rand((32, shard_size * world_size, 32))
for _ in range(input_list_size)
]
xoutput_list = [output.to(device) for output in output_list]

# TODO: fix the broken case with pin_layout=True
res_list = xm.reduce_scatter(
xm.REDUCE_SUM,
xrand_list,
scale,
scatter_dim,
world_size,
output=xoutput_list,
pin_layout=False)

assert (xoutput_list == res_list)
for i, res in enumerate(xoutput_list):
expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
xm.mark_step()

slice_idx = torch.tensor(
list(range(index * shard_size, (index + 1) * shard_size)))
expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter_list_input_output')
else:
print(
'Default device {} is not a TPU device'.format(device), file=sys.stderr)
Expand Down
21 changes: 16 additions & 5 deletions torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,13 +785,24 @@ def reduce_scatter(reduce_type,
elif isinstance(input, list) and all(
isinstance(v, torch.Tensor) for v in input):
if output != None:
raise RuntimeError(
"For xm.reduce_scatter with list of tensors input, output != None is not yet supported."
)
if not isinstance(output, list) or any(
not isinstance(v, torch.Tensor) for v in output):
raise TypeError(
f"`output` needs to be a list of Tensors, but given {type(output)}."
)
if len(output) != len(input):
raise ValueError("`output` length doesn't match `input` length: "
f"{len(output)} vs {len(input)}.")
# Call the out of place version of the reduce_scatter
new_token = torch_xla._XLAC._xla_reduce_scatter_coalesced_out(
reduce_type, output, input, token, scale, scatter_dim, shard_count,
groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_reduce_scatter_coalesced(
reduce_type, output or [], input, token, scale, scatter_dim,
shard_count, groups or [], pin_layout)
reduce_type, input, token, scale, scatter_dim, shard_count, groups or
[], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
return result[:-1]
else:
Expand Down
94 changes: 62 additions & 32 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,39 +236,52 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> ReduceScatter(
std::make_shared<torch::lazy::Value>(new_token));
}

// Single-tensor, in-place reduce-scatter wrapper: forwards to
// tensor_methods::reduce_scatter_out with `output` as the destination tensor
// and returns the new ordering token for chaining subsequent collectives.
std::shared_ptr<torch::lazy::Value> ReduceScatterOut(
    const std::string& reduce_type, at::Tensor& output, const at::Tensor& input,
    const std::shared_ptr<torch::lazy::Value>& token, double scale,
    int64_t scatter_dim, int64_t shard_count,
    const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
  XLATensorPtr destination = bridge::GetXlaTensor(output);
  // Declare-and-initialize in one step; the destination tensor is updated as
  // a side effect of tensor_methods::reduce_scatter_out.
  torch::lazy::Value new_token = tensor_methods::reduce_scatter_out(
      destination, bridge::GetXlaTensor(input), *token,
      GetReduceType(reduce_type), scale, scatter_dim, shard_count,
      replica_groups, pin_layout);
  return std::make_shared<torch::lazy::Value>(new_token);
}

// Coalesced (out-of-place) reduce-scatter over a list of input tensors:
// builds one lazy IR node covering all `inputs` via
// tensor_methods::reduce_scatter_coalesced and converts the per-input XLA
// results back to ATen tensors.
// Returns {per-input results, new ordering token}.
//
// NOTE: the scraped diff interleaved removed lines here (a stale `outputs`
// parameter, `xtensors_out`, and duplicate call arguments); this is the
// coherent post-commit version with those removed-line leftovers dropped.
std::pair<std::vector<at::Tensor>, std::shared_ptr<torch::lazy::Value>>
ReduceScatterCoalesced(const std::string& reduce_type,
                       const std::vector<at::Tensor>& inputs,
                       const std::shared_ptr<torch::lazy::Value>& token,
                       double scale, int64_t scatter_dim, int64_t shard_count,
                       const std::vector<std::vector<int64_t>>& replica_groups,
                       bool pin_layout) {
  std::vector<XLATensorPtr> xtensors = GetXlaTensors(inputs, /*want_all=*/true);
  std::vector<XLATensorPtr> result;
  torch::lazy::Value new_token;
  std::tie(result, new_token) = tensor_methods::reduce_scatter_coalesced(
      xtensors, *token, GetReduceType(reduce_type), scale, scatter_dim,
      shard_count, replica_groups, pin_layout);
  std::vector<at::Tensor> aten_result;
  aten_result.reserve(result.size());  // one result per input; avoid regrowth
  for (auto& xt : result) {
    aten_result.emplace_back(bridge::AtenFromXlaTensor(std::move(xt)));
  }
  return {aten_result, std::make_shared<torch::lazy::Value>(new_token)};
}

std::shared_ptr<torch::lazy::Value> ReduceScatterOut(
const std::string& reduce_type, at::Tensor& output, const at::Tensor& input,
std::shared_ptr<torch::lazy::Value> ReduceScatterCoalescedOut(
const std::string& reduce_type, std::vector<at::Tensor>& outputs,
const std::vector<at::Tensor>& inputs,
const std::shared_ptr<torch::lazy::Value>& token, double scale,
int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
XLATensorPtr out = bridge::GetXlaTensor(output);
std::vector<XLATensorPtr> xtensors_out =
GetXlaTensors(outputs, /*want_all=*/true);
std::vector<XLATensorPtr> xtensors = GetXlaTensors(inputs, /*want_all=*/true);
torch::lazy::Value new_token;
new_token = tensor_methods::reduce_scatter_out(
out, bridge::GetXlaTensor(input), *token, GetReduceType(reduce_type),
scale, scatter_dim, shard_count, replica_groups, pin_layout);
new_token = tensor_methods::reduce_scatter_coalesced_out(
xtensors_out, xtensors, *token, GetReduceType(reduce_type), scale,
scatter_dim, shard_count, replica_groups, pin_layout);
return std::make_shared<torch::lazy::Value>(new_token);
}

Expand Down Expand Up @@ -1346,45 +1359,62 @@ void InitXlaModuleBindings(py::module m) {
result_tuple[1] = new_token;
return result_tuple;
});
// Python binding for single-tensor, in-place reduce-scatter. Releases the
// GIL while the lazy IR is recorded and returns only the new ordering token.
//
// NOTE: the scraped diff interleaved the old coalesced-binding header and its
// result-list packing code here; this is the coherent post-commit binding
// (the unused `at::Tensor result;` leftover is also dropped).
m.def("_xla_reduce_scatter_out",
      [](const std::string& reduce_type, at::Tensor& output,
         const at::Tensor& input,
         const std::shared_ptr<torch::lazy::Value>& token, double scale,
         int64_t scatter_dim, int64_t shard_count, const py::list& groups,
         bool pin_layout) {
        std::vector<std::vector<int64_t>> replica_groups =
            CreateReduceGroups(groups);
        std::shared_ptr<torch::lazy::Value> new_token;
        {
          // IR construction does not touch Python state; drop the GIL.
          NoGilSection nogil;
          new_token = ReduceScatterOut(reduce_type, output, input, token,
                                       scale, scatter_dim, shard_count,
                                       replica_groups, pin_layout);
        }
        return new_token;
      });
m.def("_xla_reduce_scatter_out",
[](const std::string& reduce_type, at::Tensor& output,
const at::Tensor& input,
// Python binding for coalesced (list-input) reduce-scatter. Releases the GIL
// while the lazy IR is recorded, then returns a py::list of
// [result_0, ..., result_{n-1}, new_token] for the Python caller.
m.def(
    "_xla_reduce_scatter_coalesced",
    [](const std::string& reduce_type, const std::vector<at::Tensor>& inputs,
       const std::shared_ptr<torch::lazy::Value>& token, double scale,
       int64_t scatter_dim, int64_t shard_count, const py::list& groups,
       bool pin_layout) {
      std::vector<std::vector<int64_t>> replica_groups =
          CreateReduceGroups(groups);
      std::vector<at::Tensor> result;
      std::shared_ptr<torch::lazy::Value> new_token;
      {
        // IR construction does not touch Python state; drop the GIL.
        NoGilSection nogil;
        std::tie(result, new_token) = ReduceScatterCoalesced(
            reduce_type, inputs, token, scale, scatter_dim, shard_count,
            replica_groups, pin_layout);
      }
      auto result_list = py::list(result.size() + 1);
      // size_t index: avoids the signed/unsigned comparison the original
      // `int i` loop produced against result.size().
      for (size_t i = 0; i < result.size(); ++i) {
        result_list[i] = torch::autograd::make_variable(
            result[i], /*requires_grad=*/result[i].requires_grad());
      }
      result_list[result.size()] = new_token;
      return result_list;
    });
// Python binding for coalesced, in-place reduce-scatter: results are written
// into the caller-provided `outputs` tensors, so only the new ordering token
// is returned.
//
// NOTE: the scraped diff left a stale `at::Tensor result;` and a duplicate
// old ReduceScatterOut call (referencing undefined `output`/`input`)
// interleaved here; this is the coherent post-commit binding.
m.def("_xla_reduce_scatter_coalesced_out",
      [](const std::string& reduce_type, std::vector<at::Tensor>& outputs,
         const std::vector<at::Tensor>& inputs,
         const std::shared_ptr<torch::lazy::Value>& token, double scale,
         int64_t scatter_dim, int64_t shard_count, const py::list& groups,
         bool pin_layout) {
        std::vector<std::vector<int64_t>> replica_groups =
            CreateReduceGroups(groups);
        std::shared_ptr<torch::lazy::Value> new_token;
        {
          // IR construction does not touch Python state; drop the GIL.
          NoGilSection nogil;
          new_token = ReduceScatterCoalescedOut(
              reduce_type, outputs, inputs, token, scale, scatter_dim,
              shard_count, replica_groups, pin_layout);
        }
        return new_token;
      });
Expand Down
27 changes: 21 additions & 6 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,14 +393,12 @@ torch::lazy::Value reduce_scatter_out(XLATensorPtr& output,
}

std::pair<std::vector<XLATensorPtr>, torch::lazy::Value>
reduce_scatter_coalesced(const std::vector<XLATensorPtr>& outputs,
const std::vector<XLATensorPtr>& inputs,
reduce_scatter_coalesced(const std::vector<XLATensorPtr>& inputs,
const torch::lazy::Value& token,
AllReduceType reduce_type, double scale,
int64_t scatter_dim, int64_t shard_count,
std::vector<std::vector<int64_t>> groups,
bool pin_layout) {
XLA_CHECK(outputs.empty() || outputs.size() == inputs.size());
std::vector<torch::lazy::Value> input_values;
input_values.reserve(inputs.size());
for (auto& input : inputs) {
Expand All @@ -412,13 +410,30 @@ reduce_scatter_coalesced(const std::vector<XLATensorPtr>& outputs,
std::vector<XLATensorPtr> result;
for (size_t i = 0; i < inputs.size(); ++i) {
result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i)));
if (!outputs.empty()) {
outputs[i]->SetIrValue(torch::lazy::Value(node, i));
}
}
return {result, torch::lazy::Value(node, inputs.size())};
}

// In-place coalesced reduce-scatter: emits a single ReduceScatterCoalesced IR
// node over all `inputs` and redirects each tensor in `outputs` to the
// corresponding node output via SetIrValue.
// Returns the new ordering token (node output index inputs.size()).
torch::lazy::Value reduce_scatter_coalesced_out(
    const std::vector<XLATensorPtr>& outputs,
    const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
    AllReduceType reduce_type, double scale, int64_t scatter_dim,
    int64_t shard_count, std::vector<std::vector<int64_t>> groups,
    bool pin_layout) {
  // Guard before the write loop below indexes `outputs[i]` for every input;
  // the out-of-place sibling carried an equivalent size check.
  XLA_CHECK(outputs.size() == inputs.size());
  std::vector<torch::lazy::Value> input_values;
  input_values.reserve(inputs.size());
  for (auto& input : inputs) {
    input_values.push_back(input->GetIrValue());
  }
  torch::lazy::NodePtr node = torch::lazy::MakeNode<ReduceScatterCoalesced>(
      reduce_type, input_values, token, scale, scatter_dim, shard_count,
      std::move(groups), pin_layout);
  for (size_t i = 0; i < inputs.size(); ++i) {
    // Node outputs 0..n-1 are the per-input reduce-scatter results.
    outputs[i]->SetIrValue(torch::lazy::Value(node, i));
  }
  return torch::lazy::Value(node, inputs.size());
}

std::pair<XLATensorPtr, torch::lazy::Value> all_to_all(
const XLATensorPtr& input, const torch::lazy::Value& token,
int64_t split_dimension, int64_t concat_dimension, int64_t split_count,
Expand Down
10 changes: 8 additions & 2 deletions torch_xla/csrc/tensor_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,20 @@ torch::lazy::Value reduce_scatter_out(XLATensorPtr& output,
bool pin_layout);

std::pair<std::vector<XLATensorPtr>, torch::lazy::Value>
reduce_scatter_coalesced(const std::vector<XLATensorPtr>& outputs,
const std::vector<XLATensorPtr>& inputs,
reduce_scatter_coalesced(const std::vector<XLATensorPtr>& inputs,
const torch::lazy::Value& token,
AllReduceType reduce_type, double scale,
int64_t scatter_dim, int64_t shard_count,
std::vector<std::vector<int64_t>> groups,
bool pin_layout);

// In-place variant of reduce_scatter_coalesced: each tensor in `outputs` is
// redirected to the corresponding result of a coalesced reduce-scatter over
// `inputs`. Returns the new ordering token.
// NOTE(review): presumably requires outputs.size() == inputs.size() — confirm
// callers validate this before invoking.
torch::lazy::Value reduce_scatter_coalesced_out(
const std::vector<XLATensorPtr>& outputs,
const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
AllReduceType reduce_type, double scale, int64_t scatter_dim,
int64_t shard_count, std::vector<std::vector<int64_t>> groups,
bool pin_layout);

std::pair<XLATensorPtr, torch::lazy::Value> all_to_all(
const XLATensorPtr& input, const torch::lazy::Value& token,
int64_t split_dimension, int64_t concat_dimension, int64_t split_count,
Expand Down

0 comments on commit 7fdb15f

Please sign in to comment.