From 40727e4d367e183baccd9a2ce734ca7632ca09ac Mon Sep 17 00:00:00 2001 From: Yeounoh Chung Date: Thu, 18 Jan 2024 18:12:13 -0800 Subject: [PATCH] Propagates requires_grad over to AllReduce output (#6326) --- test/pjrt/test_collective_ops_tpu.py | 6 ++++-- torch_xla/csrc/init_python_bindings.cpp | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py index f1752901661..cad47eac13f 100644 --- a/test/pjrt/test_collective_ops_tpu.py +++ b/test/pjrt/test_collective_ops_tpu.py @@ -38,11 +38,13 @@ def _all_reduce(pin_layout): device = xm.xla_device() # Prevent 0 and 1 from being converted to constants ordinal = xm.send_cpu_data_to_device( - torch.tensor(xm.get_ordinal()), device=device) + torch.tensor(xm.get_ordinal(), dtype=torch.float32, requires_grad=True), + device=device) out = xm.all_reduce(xm.REDUCE_SUM, ordinal, pin_layout=pin_layout)[0] + assert out.requires_grad xm.mark_step() - return out.cpu().numpy() + return out.cpu().detach().numpy() @parameterized.named_parameters(('pinned', True), ('unpinned', False)) def test_all_reduce(self, pin_layout): diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 86b9c896b1a..3281f0e9a67 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1215,7 +1215,8 @@ void InitXlaModuleBindings(py::module m) { NoGilSection nogil; result = AllReduce(reduce_type, input, scale, replica_groups, pin_layout); } - return result; + return torch::autograd::make_variable( + result, /*requires_grad=*/input.requires_grad()); }); m.def("_xla_quantize_tensor", [](const at::Tensor& input, const std::vector& scale_list,