From d8947026f456bbd7d10c4f3059e935b39c5d4ec1 Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Fri, 8 Dec 2023 06:14:45 +0000 Subject: [PATCH 1/3] Add out-of-place all-gather coalesced --- torch_xla/core/xla_model.py | 15 +++++++ torch_xla/csrc/init_python_bindings.cpp | 45 +++++++++++++++++---- torch_xla/csrc/tensor_methods.cpp | 53 ++++++++++++++++--------- torch_xla/csrc/tensor_methods.h | 15 ++++--- 4 files changed, 98 insertions(+), 30 deletions(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index d52d22d6e41..8bf298be2de 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -587,6 +587,21 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): # Now the input should be a list of Tensors. elif isinstance(value, list) and all( isinstance(v, torch.Tensor) for v in value): + if output != None: + if not isinstance(output, list) or any( + not isinstance(v, torch.Tensor) for v in output): + raise TypeError( + f"`output` needs to be a list of Tensors, but given {type(output)}." + ) + if len(output) != len(input): + raise ValueError("`output` length doesn't match `input` length: " + f"{len(output)} vs {len(input)}.") + # Call the out of place version of the reduce_scatter + new_token = torch_xla._XLAC._xla_all_gather_coalesced_out( + output, value, token, dim, shard_count, groups or [], pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) + return output + result = torch_xla._XLAC._xla_all_gather_coalesced(value, token, dim, shard_count, groups or [], pin_layout) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index c3ab1fb7857..99a3fd492e1 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -281,6 +281,19 @@ at::Tensor AllGather(const at::Tensor& input, int64_t dim, int64_t shard_count, return bridge::AtenFromXlaTensor(std::move(result)); } +std::shared_ptr AllGatherOut( + at::Tensor& output, const at::Tensor& input, + const std::shared_ptr& token, int64_t dim, + int64_t shard_count, + const std::vector>& replica_groups, bool pin_layout) { + XLATensorPtr out = bridge::GetXlaTensor(output); + torch::lazy::Value new_token; + new_token = tensor_methods::all_gather_out(out, bridge::GetXlaTensor(input), + *token, dim, shard_count, + replica_groups, pin_layout); + return std::make_shared(new_token); +} + std::pair, std::shared_ptr> AllGatherCoalesced(const std::vector& tensors, const std::shared_ptr& token, @@ -291,7 +304,7 @@ AllGatherCoalesced(const std::vector& tensors, GetXlaTensors(tensors, /*want_all=*/true); std::vector result; torch::lazy::Value new_token; - std::tie(result, new_token) = tensor_methods::all_gather( + std::tie(result, new_token) = tensor_methods::all_gather_coalesced( xtensors, *token, dim, shard_count, replica_groups, pin_layout); std::vector aten_result; for (auto& xt : result) { @@ -300,16 +313,18 @@ AllGatherCoalesced(const std::vector& tensors, return {aten_result, std::make_shared(new_token)}; } -std::shared_ptr AllGatherOut( - at::Tensor& output, const at::Tensor& input, +std::shared_ptr AllGatherCoalescedOut( + std::vector& outputs, const std::vector& inputs, const std::shared_ptr& token, int64_t dim, int64_t shard_count, const std::vector>& replica_groups, bool pin_layout) { - XLATensorPtr out = bridge::GetXlaTensor(output); + std::vector xtensors_out = + GetXlaTensors(outputs, /*want_all=*/true); + std::vector xtensors = GetXlaTensors(inputs, /*want_all=*/true); torch::lazy::Value new_token; - new_token = tensor_methods::all_gather_out(out, bridge::GetXlaTensor(input), - *token, dim, shard_count, - replica_groups, pin_layout); + new_token = tensor_methods::all_gather_coalesced_out( + xtensors_out, xtensors, *token, dim, shard_count, replica_groups, + pin_layout); return std::make_shared(new_token); } @@ -1275,6 +1290,22 @@ void InitXlaModuleBindings(py::module m) { result_list[results.size()] = new_token; return result_list; }); + m.def("_xla_all_gather_coalesced_out", + [](std::vector& outputs, + const std::vector& inputs, + const std::shared_ptr& token, int64_t dim, + int64_t shard_count, const py::list& groups, bool pin_layout) { + std::vector> replica_groups = + CreateReduceGroups(groups); + std::shared_ptr new_token; + { + NoGilSection nogil; + new_token = + AllGatherCoalescedOut(outputs, inputs, token, dim, shard_count, + replica_groups, pin_layout); + } + return new_token; + }); m.def("_xla_collective_permute", [](const at::Tensor& input, const std::shared_ptr& token, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 818b31e474a..4ccc83e5436 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -430,24 +430,6 @@ std::pair all_to_all( torch::lazy::Value(node, 1)}; } -std::pair, torch::lazy::Value> all_gather( - const std::vector& inputs, const torch::lazy::Value& token, - int64_t dim, int64_t shard_count, std::vector> groups, - bool pin_layout) { - std::vector input_values; - input_values.reserve(inputs.size()); - for (auto& input : inputs) { - input_values.push_back(input->GetIrValue()); - } - torch::lazy::NodePtr node = torch::lazy::MakeNode( - input_values, token, dim, shard_count, std::move(groups), pin_layout); - std::vector result; - for (size_t i = 0; i < inputs.size(); ++i) { - result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i))); - } - return {result, torch::lazy::Value(node, inputs.size())}; -} - XLATensorPtr all_gather(const XLATensorPtr& input, int64_t dim, int64_t shard_count, std::vector> groups, @@ -473,6 +455,41 @@ torch::lazy::Value all_gather_out(XLATensorPtr& output, return torch::lazy::Value(node, 1); } +std::pair, torch::lazy::Value> all_gather_coalesced( + const std::vector& inputs, const torch::lazy::Value& token, + int64_t dim, int64_t shard_count, std::vector> groups, + bool pin_layout) { + std::vector input_values; + input_values.reserve(inputs.size()); + for (auto& input : inputs) { + input_values.push_back(input->GetIrValue()); + } + torch::lazy::NodePtr node = torch::lazy::MakeNode( + input_values, token, dim, shard_count, std::move(groups), pin_layout); + std::vector result; + for (size_t i = 0; i < inputs.size(); ++i) { + result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i))); + } + return {result, torch::lazy::Value(node, inputs.size())}; +} + +torch::lazy::Value all_gather_coalesced_out( + std::vector& outputs, const std::vector& inputs, + const torch::lazy::Value& token, int64_t dim, int64_t shard_count, + std::vector> groups, bool pin_layout) { + std::vector input_values; + input_values.reserve(inputs.size()); + for (auto& input : inputs) { + input_values.push_back(input->GetIrValue()); + } + torch::lazy::NodePtr node = torch::lazy::MakeNode( + input_values, token, dim, shard_count, std::move(groups), pin_layout); + for (size_t i = 0; i < inputs.size(); ++i) { + outputs[i]->SetIrValue(torch::lazy::Value(node, i)); + } + return torch::lazy::Value(node, inputs.size()); +} + std::pair collective_permute( const XLATensorPtr& input, const torch::lazy::Value& token, std::vector> source_target_pairs) { diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 9d8226e2154..11cd8b2e3c5 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -47,11 +47,6 @@ std::pair all_to_all( int64_t split_dimension, int64_t concat_dimension, int64_t split_count, std::vector> groups, bool pin_layout); -std::pair, torch::lazy::Value> all_gather( - const std::vector& inputs, const torch::lazy::Value& token, - int64_t dim, int64_t shard_count, std::vector> groups, - bool pin_layout); - XLATensorPtr all_gather(const XLATensorPtr& input, int64_t dim, int64_t shard_count, std::vector> groups, @@ -64,6 +59,16 @@ torch::lazy::Value all_gather_out(XLATensorPtr& output, std::vector> groups, bool pin_layout); +std::pair, torch::lazy::Value> all_gather_coalesced( + const std::vector& inputs, const torch::lazy::Value& token, + int64_t dim, int64_t shard_count, std::vector> groups, + bool pin_layout); + +torch::lazy::Value all_gather_coalesced_out( + std::vector& outputs, const std::vector& inputs, + const torch::lazy::Value& token, int64_t dim, int64_t shard_count, + std::vector> groups, bool pin_layout); + std::pair collective_permute( const XLATensorPtr& input, const torch::lazy::Value& token, std::vector> source_target_pairs); From 0d6125cff87295a975302de81bbbee4ff2117285 Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Sat, 9 Dec 2023 07:00:39 +0000 Subject: [PATCH 2/3] Add all-gather coalesce tests; fix all-gather coalesce bug "len(input)" --- test/test_mp_all_gather.py | 45 ++++++++++++++++++++++++++++++++++++- torch_xla/core/xla_model.py | 8 +++++-- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py index 3ffeebc963d..626573aa6d2 100644 --- a/test/test_mp_all_gather.py +++ b/test/test_mp_all_gather.py @@ -13,7 +13,8 @@ def all_gather(tensor, dim): def _mp_fn(index): device = xm.xla_device() world_size = xm.xrt_world_size() - if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'): + input_list_size = 5 + if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM', 'NEURON'): # Testing with a single replica group ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device) result = xm.all_gather(ordinal_tensor, dim=0) @@ -57,6 +58,48 @@ def _mp_fn(index): f'Failed to create two replica groups with {world_size} replicas', file=sys.stderr) + # Testing with a single replica group and tensor list as input + ordinal_tensors = [ + torch.tensor([i * 1000 + index], dtype=torch.float).to(device) + for i in range(input_list_size) + ] + # TODO: add support for list input with pin_layout=True and output=None + result_list = xm.all_gather(ordinal_tensors, dim=0, pin_layout=False) + + for i, result in enumerate(result_list): + cpu_result = result.cpu() + expected = i * 1000 + torch.arange(world_size, dtype=torch.float) + if not cpu_result.allclose(expected): + print( + 'xm.all_gather() produced wrong reductions for item {i} in result list', + file=sys.stderr) + print(f'[{index}] {cpu_result}', file=sys.stderr) + sys.exit(1) + + # Testing with a single replica group and tensor list as input and output!=None (out-of-place) + ordinal_tensors = [ + torch.tensor([i * 1000 + index], dtype=torch.float).to(device) + for i in range(input_list_size) + ] + output_tensors = [ + torch.zeros([world_size], dtype=torch.float).to(device) + for i in range(input_list_size) + ] + # TODO: add support for list input with pin_layout=True and output!=None + result_list = xm.all_gather( + ordinal_tensors, dim=0, output=output_tensors, pin_layout=False) + + for i, result in enumerate(result_list): + cpu_result = result.cpu() + expected = i * 1000 + torch.arange(world_size, dtype=torch.float) + if not cpu_result.allclose(expected): + print( + 'xm.all_gather() produced wrong reductions for item {i} in result list', + file=sys.stderr) + print(f'[{index}] {cpu_result}', file=sys.stderr) + sys.exit(1) + # TODO: add test for torch.compile when support for list input is ready + else: print(f'{device} is not a TPU or GPU device', file=sys.stderr) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 8bf298be2de..9f7cb74e6b9 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -553,7 +553,7 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): A tensor which has, in the ``dim`` dimension, all the values from the participating replicas. """ - if pin_layout and (output == None or xla_device_hw(value.device) == 'NEURON'): + if pin_layout and output == None and isinstance(value, torch.Tensor): # There is not an easy way to pin the all_gather layout on TPU, GPU and NEURON, # use all_reduce based all_gather for this purpose. return _all_gather_using_all_reduce( @@ -587,13 +587,17 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): # Now the input should be a list of Tensors. elif isinstance(value, list) and all( isinstance(v, torch.Tensor) for v in value): + if pin_layout: + raise RuntimeError( + "For xm.all_gather with list of tensors input, pin_layout=True is not yet supported." + ) if output != None: if not isinstance(output, list) or any( not isinstance(v, torch.Tensor) for v in output): raise TypeError( f"`output` needs to be a list of Tensors, but given {type(output)}." ) - if len(output) != len(input): + if len(output) != len(value): raise ValueError("`output` length doesn't match `input` length: " f"{len(output)} vs {len(input)}.") # Call the out of place version of the reduce_scatter From 554280646ed8981831dcfed5d0bfa196e56b391b Mon Sep 17 00:00:00 2001 From: Arjun Balasubramanian Date: Mon, 11 Dec 2023 22:40:30 +0000 Subject: [PATCH 3/3] Reuse ordinal_tensors; _all_gather_using_all_reduce can't accept list of tensors --- test/test_mp_all_gather.py | 5 +---- torch_xla/core/xla_model.py | 1 + 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py index 626573aa6d2..c7de7361147 100644 --- a/test/test_mp_all_gather.py +++ b/test/test_mp_all_gather.py @@ -77,10 +77,7 @@ def _mp_fn(index): sys.exit(1) # Testing with a single replica group and tensor list as input and output!=None (out-of-place) - ordinal_tensors = [ - torch.tensor([i * 1000 + index], dtype=torch.float).to(device) - for i in range(input_list_size) - ] + # Reuse ordinal_tensors from previous test output_tensors = [ torch.zeros([world_size], dtype=torch.float).to(device) for i in range(input_list_size) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 9f7cb74e6b9..1dc3b785454 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -553,6 +553,7 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): A tensor which has, in the ``dim`` dimension, all the values from the participating replicas. """ + # _all_gather_using_all_reduce does not support list of tensors as input if pin_layout and output == None and isinstance(value, torch.Tensor): # There is not an easy way to pin the all_gather layout on TPU, GPU and NEURON, # use all_reduce based all_gather for this purpose.