Skip to content

Commit

Permalink
Merge pull request #8 from jeffhataws/v2.1.1_updated_requires_grad_fix
Browse files Browse the repository at this point in the history
Propagates requires_grad over to AllReduce output (pytorch#6326)
  • Loading branch information
jeffhataws authored Jan 22, 2024
2 parents 8cf4f69 + 23f8974 commit 619c7c8
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
6 changes: 4 additions & 2 deletions test/pjrt/test_collective_ops_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ def _all_reduce(pin_layout):
device = xm.xla_device()
# Prevent 0 and 1 from being converted to constants
ordinal = xm.send_cpu_data_to_device(
torch.tensor(xm.get_ordinal()), device=device)
torch.tensor(xm.get_ordinal(), dtype=torch.float32, requires_grad=True),
device=device)
out = xm.all_reduce(xm.REDUCE_SUM, ordinal, pin_layout=pin_layout)[0]
assert out.requires_grad
xm.mark_step()

return out.cpu().numpy()
return out.cpu().detach().numpy()

@parameterized.named_parameters(('pinned', True), ('unpinned', False))
def test_all_reduce(self, pin_layout):
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,8 @@ void InitXlaModuleBindings(py::module m) {
NoGilSection nogil;
result = AllReduce(reduce_type, input, scale, replica_groups, pin_layout);
}
return result;
return torch::autograd::make_variable(
result, /*requires_grad=*/input.requires_grad());
});
m.def("_xla_all_to_all",
[](const at::Tensor& input,
Expand Down

0 comments on commit 619c7c8

Please sign in to comment.