Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add out-of-place all-gather coalesced #6059

Merged
merged 3 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion test/test_mp_all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def all_gather(tensor, dim):
def _mp_fn(index):
device = xm.xla_device()
world_size = xm.xrt_world_size()
if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
input_list_size = 5
if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM', 'NEURON'):
# Testing with a single replica group
ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
result = xm.all_gather(ordinal_tensor, dim=0)
Expand Down Expand Up @@ -57,6 +58,48 @@ def _mp_fn(index):
f'Failed to create two replica groups with {world_size} replicas',
file=sys.stderr)

# Testing with a single replica group and tensor list as input.
# List item i on this replica holds the single value `i * 1000 + index`, so
# the gathered item i is expected to be `i * 1000 + [0, 1, ..., world_size-1]`.
ordinal_tensors = [
    torch.tensor([i * 1000 + index], dtype=torch.float).to(device)
    for i in range(input_list_size)
]
# TODO: add support for list input with pin_layout=True and output=None
result_list = xm.all_gather(ordinal_tensors, dim=0, pin_layout=False)

for i, result in enumerate(result_list):
  cpu_result = result.cpu()
  expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
  if not cpu_result.allclose(expected):
    # The message must be an f-string, otherwise the literal text "{i}" is
    # printed instead of the index of the failing list item.
    print(
        f'xm.all_gather() produced wrong reductions for item {i} in result list',
        file=sys.stderr)
    print(f'[{index}] {cpu_result}', file=sys.stderr)
    sys.exit(1)

# Testing with a single replica group and tensor list as input and output!=None (out-of-place)
ordinal_tensors = [
jeffhataws marked this conversation as resolved.
Show resolved Hide resolved
torch.tensor([i * 1000 + index], dtype=torch.float).to(device)
for i in range(input_list_size)
]
# One pre-allocated output per input tensor; each has `world_size` elements,
# matching the length of the `expected` tensor it is compared against below.
output_tensors = [
    torch.zeros([world_size], dtype=torch.float).to(device)
    for i in range(input_list_size)
]
# TODO: add support for list input with pin_layout=True and output!=None
result_list = xm.all_gather(
    ordinal_tensors, dim=0, output=output_tensors, pin_layout=False)

for i, result in enumerate(result_list):
  cpu_result = result.cpu()
  expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
  if not cpu_result.allclose(expected):
    # The message must be an f-string, otherwise the literal text "{i}" is
    # printed instead of the index of the failing list item.
    print(
        f'xm.all_gather() produced wrong reductions for item {i} in result list',
        file=sys.stderr)
    print(f'[{index}] {cpu_result}', file=sys.stderr)
    sys.exit(1)
# TODO: add test for torch.compile when support for list input is ready

else:
print(f'{device} is not a TPU or GPU device', file=sys.stderr)

Expand Down
21 changes: 20 additions & 1 deletion torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
A tensor which has, in the ``dim`` dimension, all the values from the
participating replicas.
"""
if pin_layout and (output == None or xla_device_hw(value.device) == 'NEURON'):
if pin_layout and output == None and isinstance(value, torch.Tensor):
jeffhataws marked this conversation as resolved.
Show resolved Hide resolved
# There is not an easy way to pin the all_gather layout on TPU, GPU and NEURON,
# use all_reduce based all_gather for this purpose.
return _all_gather_using_all_reduce(
Expand Down Expand Up @@ -587,6 +587,25 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
# Now the input should be a list of Tensors.
elif isinstance(value, list) and all(
isinstance(v, torch.Tensor) for v in value):
if pin_layout:
raise RuntimeError(
"For xm.all_gather with list of tensors input, pin_layout=True is not yet supported."
)
if output != None:
if not isinstance(output, list) or any(
not isinstance(v, torch.Tensor) for v in output):
raise TypeError(
f"`output` needs to be a list of Tensors, but given {type(output)}."
)
if len(output) != len(value):
raise ValueError("`output` length doesn't match `input` length: "
f"{len(output)} vs {len(input)}.")
# Call the out of place version of the reduce_scatter
new_token = torch_xla._XLAC._xla_all_gather_coalesced_out(
output, value, token, dim, shard_count, groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_all_gather_coalesced(value, token, dim,
shard_count, groups or
[], pin_layout)
Expand Down
45 changes: 38 additions & 7 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,19 @@ at::Tensor AllGather(const at::Tensor& input, int64_t dim, int64_t shard_count,
return bridge::AtenFromXlaTensor(std::move(result));
}

// Out-of-place all-gather for a single tensor. Resolves `output` and `input`
// to their XLA tensors, forwards to tensor_methods::all_gather_out, and returns
// the token value it produces wrapped in a shared_ptr.
std::shared_ptr<torch::lazy::Value> AllGatherOut(
    at::Tensor& output, const at::Tensor& input,
    const std::shared_ptr<torch::lazy::Value>& token, int64_t dim,
    int64_t shard_count,
    const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
  // Kept as a named lvalue: all_gather_out takes the output tensor by
  // non-const reference.
  XLATensorPtr xla_output = bridge::GetXlaTensor(output);
  const torch::lazy::Value result_token = tensor_methods::all_gather_out(
      xla_output, bridge::GetXlaTensor(input), *token, dim, shard_count,
      replica_groups, pin_layout);
  return std::make_shared<torch::lazy::Value>(result_token);
}

std::pair<std::vector<at::Tensor>, std::shared_ptr<torch::lazy::Value>>
AllGatherCoalesced(const std::vector<at::Tensor>& tensors,
const std::shared_ptr<torch::lazy::Value>& token,
Expand All @@ -291,7 +304,7 @@ AllGatherCoalesced(const std::vector<at::Tensor>& tensors,
GetXlaTensors(tensors, /*want_all=*/true);
std::vector<XLATensorPtr> result;
torch::lazy::Value new_token;
std::tie(result, new_token) = tensor_methods::all_gather(
std::tie(result, new_token) = tensor_methods::all_gather_coalesced(
xtensors, *token, dim, shard_count, replica_groups, pin_layout);
std::vector<at::Tensor> aten_result;
for (auto& xt : result) {
Expand All @@ -300,16 +313,18 @@ AllGatherCoalesced(const std::vector<at::Tensor>& tensors,
return {aten_result, std::make_shared<torch::lazy::Value>(new_token)};
}

std::shared_ptr<torch::lazy::Value> AllGatherOut(
at::Tensor& output, const at::Tensor& input,
std::shared_ptr<torch::lazy::Value> AllGatherCoalescedOut(
std::vector<at::Tensor>& outputs, const std::vector<at::Tensor>& inputs,
const std::shared_ptr<torch::lazy::Value>& token, int64_t dim,
int64_t shard_count,
const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
XLATensorPtr out = bridge::GetXlaTensor(output);
std::vector<XLATensorPtr> xtensors_out =
GetXlaTensors(outputs, /*want_all=*/true);
std::vector<XLATensorPtr> xtensors = GetXlaTensors(inputs, /*want_all=*/true);
torch::lazy::Value new_token;
new_token = tensor_methods::all_gather_out(out, bridge::GetXlaTensor(input),
*token, dim, shard_count,
replica_groups, pin_layout);
new_token = tensor_methods::all_gather_coalesced_out(
xtensors_out, xtensors, *token, dim, shard_count, replica_groups,
pin_layout);
return std::make_shared<torch::lazy::Value>(new_token);
}

Expand Down Expand Up @@ -1275,6 +1290,22 @@ void InitXlaModuleBindings(py::module m) {
result_list[results.size()] = new_token;
return result_list;
});
// Registers `_xla_all_gather_coalesced_out` on module `m`: the out-of-place
// coalesced all-gather. Takes the output tensors, the input tensors, the
// current token, the gather dimension, the shard count, the replica groups
// (as a Python list) and the pin_layout flag, and returns the new token
// produced by AllGatherCoalescedOut.
m.def("_xla_all_gather_coalesced_out",
[](std::vector<at::Tensor>& outputs,
const std::vector<at::Tensor>& inputs,
const std::shared_ptr<torch::lazy::Value>& token, int64_t dim,
int64_t shard_count, const py::list& groups, bool pin_layout) {
// The Python list is converted before entering the NoGilSection scope
// below.
std::vector<std::vector<int64_t>> replica_groups =
CreateReduceGroups(groups);
std::shared_ptr<torch::lazy::Value> new_token;
// NOTE(review): NoGilSection presumably releases the Python GIL for the
// duration of this scope -- confirm against its definition.
{
NoGilSection nogil;
new_token =
AllGatherCoalescedOut(outputs, inputs, token, dim, shard_count,
replica_groups, pin_layout);
}
return new_token;
});
m.def("_xla_collective_permute",
[](const at::Tensor& input,
const std::shared_ptr<torch::lazy::Value>& token,
Expand Down
53 changes: 35 additions & 18 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,24 +430,6 @@ std::pair<XLATensorPtr, torch::lazy::Value> all_to_all(
torch::lazy::Value(node, 1)};
}

// Coalesced all-gather over a list of tensors. Builds a single
// AllGatherCoalesced node from the IR values of all `inputs` plus `token`.
// Returns one result tensor per input (output i of the node, created from
// inputs[i]) together with the node output at index inputs.size(), which is
// returned as the new token value.
std::pair<std::vector<XLATensorPtr>, torch::lazy::Value> all_gather(
const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
int64_t dim, int64_t shard_count, std::vector<std::vector<int64_t>> groups,
bool pin_layout) {
// Collect the IR value of every input tensor, in order.
std::vector<torch::lazy::Value> input_values;
input_values.reserve(inputs.size());
for (auto& input : inputs) {
input_values.push_back(input->GetIrValue());
}
torch::lazy::NodePtr node = torch::lazy::MakeNode<AllGatherCoalesced>(
input_values, token, dim, shard_count, std::move(groups), pin_layout);
// Node output i corresponds to input i.
std::vector<XLATensorPtr> result;
for (size_t i = 0; i < inputs.size(); ++i) {
result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i)));
}
return {result, torch::lazy::Value(node, inputs.size())};
}

XLATensorPtr all_gather(const XLATensorPtr& input, int64_t dim,
int64_t shard_count,
std::vector<std::vector<int64_t>> groups,
Expand All @@ -473,6 +455,41 @@ torch::lazy::Value all_gather_out(XLATensorPtr& output,
return torch::lazy::Value(node, 1);
}

// Coalesced all-gather over a list of tensors. One AllGatherCoalesced node is
// created from the IR values of every tensor in `inputs` plus `token`. Output
// k of that node becomes the k-th result tensor (created from inputs[k]); the
// output at index inputs.size() is returned as the new token value.
std::pair<std::vector<XLATensorPtr>, torch::lazy::Value> all_gather_coalesced(
    const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
    int64_t dim, int64_t shard_count, std::vector<std::vector<int64_t>> groups,
    bool pin_layout) {
  const size_t num_inputs = inputs.size();

  // Gather the IR value of each input, preserving order.
  std::vector<torch::lazy::Value> ir_values;
  ir_values.reserve(num_inputs);
  for (const auto& tensor : inputs) {
    ir_values.push_back(tensor->GetIrValue());
  }

  torch::lazy::NodePtr gather_node = torch::lazy::MakeNode<AllGatherCoalesced>(
      ir_values, token, dim, shard_count, std::move(groups), pin_layout);

  // Wrap node outputs [0, num_inputs) as tensors derived from their inputs.
  std::vector<XLATensorPtr> gathered;
  for (size_t idx = 0; idx != num_inputs; ++idx) {
    gathered.emplace_back(
        inputs[idx]->CreateFrom(torch::lazy::Value(gather_node, idx)));
  }
  return {gathered, torch::lazy::Value(gather_node, num_inputs)};
}

// Out-of-place coalesced all-gather. One AllGatherCoalesced node is created
// from the IR values of every tensor in `inputs` plus `token`; output k of the
// node is then installed as the IR value of outputs[k]. Returns the node
// output at index inputs.size() as the new token value.
//
// Precondition (not checked here): `outputs` must hold at least
// inputs.size() tensors, since outputs[k] is accessed for every
// k < inputs.size().
torch::lazy::Value all_gather_coalesced_out(
    std::vector<XLATensorPtr>& outputs, const std::vector<XLATensorPtr>& inputs,
    const torch::lazy::Value& token, int64_t dim, int64_t shard_count,
    std::vector<std::vector<int64_t>> groups, bool pin_layout) {
  const size_t num_inputs = inputs.size();

  // Gather the IR value of each input, preserving order.
  std::vector<torch::lazy::Value> ir_values;
  ir_values.reserve(num_inputs);
  for (const auto& tensor : inputs) {
    ir_values.push_back(tensor->GetIrValue());
  }

  torch::lazy::NodePtr gather_node = torch::lazy::MakeNode<AllGatherCoalesced>(
      ir_values, token, dim, shard_count, std::move(groups), pin_layout);

  // Point each output tensor at the matching node output.
  for (size_t idx = 0; idx != num_inputs; ++idx) {
    outputs[idx]->SetIrValue(torch::lazy::Value(gather_node, idx));
  }
  return torch::lazy::Value(gather_node, num_inputs);
}

std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
const XLATensorPtr& input, const torch::lazy::Value& token,
std::vector<std::pair<int64_t, int64_t>> source_target_pairs) {
Expand Down
15 changes: 10 additions & 5 deletions torch_xla/csrc/tensor_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,6 @@ std::pair<XLATensorPtr, torch::lazy::Value> all_to_all(
int64_t split_dimension, int64_t concat_dimension, int64_t split_count,
std::vector<std::vector<int64_t>> groups, bool pin_layout);

// Coalesced all-gather over a list of tensors. Returns one result tensor per
// input together with the new token value.
std::pair<std::vector<XLATensorPtr>, torch::lazy::Value> all_gather(
const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
int64_t dim, int64_t shard_count, std::vector<std::vector<int64_t>> groups,
bool pin_layout);

XLATensorPtr all_gather(const XLATensorPtr& input, int64_t dim,
int64_t shard_count,
std::vector<std::vector<int64_t>> groups,
Expand All @@ -64,6 +59,16 @@ torch::lazy::Value all_gather_out(XLATensorPtr& output,
std::vector<std::vector<int64_t>> groups,
bool pin_layout);

// Coalesced all-gather over a list of tensors. Returns one result tensor per
// input (in input order) together with the new token value.
std::pair<std::vector<XLATensorPtr>, torch::lazy::Value> all_gather_coalesced(
const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
int64_t dim, int64_t shard_count, std::vector<std::vector<int64_t>> groups,
bool pin_layout);

// Out-of-place coalesced all-gather: the result for inputs[k] is written into
// outputs[k]. Returns the new token value. `outputs` is expected to hold at
// least as many tensors as `inputs`.
torch::lazy::Value all_gather_coalesced_out(
std::vector<XLATensorPtr>& outputs, const std::vector<XLATensorPtr>& inputs,
const torch::lazy::Value& token, int64_t dim, int64_t shard_count,
std::vector<std::vector<int64_t>> groups, bool pin_layout);

std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
const XLATensorPtr& input, const torch::lazy::Value& token,
std::vector<std::pair<int64_t, int64_t>> source_target_pairs);
Expand Down