[Distributed] Switch all_reduce to use the new functional collective op (#6887)

PyTorch has implemented a new set of functional collective ops and is planning to remove the old ops. Migrating all_reduce to use the new op.

See context in pytorch/pytorch#93173 (comment)
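For orientation, a minimal sketch of the call-shape change (not part of the commit): only the operator names and argument order are taken from the diff below; `inputs` and the "sum" reduce op are illustrative placeholders, and an initialized XLA/distributed runtime is assumed.

    import torch

    # Old traceable collective op: (tensor, reduce_op, tag, ranks, group_size).
    result = torch.ops.c10d_functional.all_reduce(inputs, "sum", "", [], 0)

    # New functional collective op: (tensor, reduce_op, group_name).
    result = torch.ops._c10d_functional.all_reduce(inputs, "sum", "")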
yifuwang authored Apr 10, 2024
1 parent 756b0ec commit a816c42
Showing 2 changed files with 7 additions and 11 deletions.
3 changes: 1 addition & 2 deletions torch_xla/core/xla_model.py
@@ -491,8 +491,7 @@ def all_reduce(reduce_type, inputs, scale=1.0, groups=None, pin_layout=True):
   if scale == 1.0 and groups == [] and pin_layout:
     # TODO(alanwaketan): Support groups.
     # Only c10d_functional version cc ops are traceable by Dynamo.
-    result = torch.ops.c10d_functional.all_reduce(inputs, reduce_type, "", [],
-                                                  0)
+    result = torch.ops._c10d_functional.all_reduce(inputs, reduce_type, "")
   else:
     result = torch_xla._XLAC._xla_all_reduce(reduce_type, inputs, scale,
                                              groups, pin_layout)
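A hedged usage sketch of the Python path above (the device setup and values are assumptions, not part of the commit): with scale=1.0, an empty groups list, and pin_layout=True, xm.all_reduce takes the Dynamo-traceable _c10d_functional branch; other argument combinations fall back to _xla_all_reduce.

    import torch
    import torch_xla.core.xla_model as xm

    t = torch.ones(4, device=xm.xla_device())
    # Hits the `if` branch above, i.e. torch.ops._c10d_functional.all_reduce.
    reduced = xm.all_reduce(xm.REDUCE_SUM, t, scale=1.0, groups=[], pin_layout=True)
    # Any other combination (e.g. scale != 1.0 or a non-empty groups list)
    # falls back to torch_xla._XLAC._xla_all_reduce.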
15 changes: 6 additions & 9 deletions torch_xla/csrc/cross_replica_reduces.cpp
@@ -112,23 +112,20 @@ std::shared_ptr<torch::lazy::Value> CreateToken(
 // order. RFC: https://github.com/pytorch/pytorch/issues/93173
 ////////////////////////////////////////////////////////////////////////////////////
 
-// tag is ignored as it's only used in PyTorch to provide backward compatibility
-// with the traditional process group API.
-at::Tensor all_reduce(const at::Tensor& self, c10::string_view reduceOp,
-                      c10::string_view /*tag*/, at::IntArrayRef /*ranks*/,
-                      int64_t /*group_size*/) {
+at::Tensor all_reduce(const at::Tensor& self, std::string reduceOp,
+                      std::string /*group_name*/) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   auto self_tensor = bridge::GetXlaTensor(self);
-  // TODO(alanwaketan): Use ranks and group_size to generate groups. Currently
-  // we just suse {} as a workaround. Scale is always 1.0 here, and we always
-  // pin layout.
+  // TODO(alanwaketan): Use group_name to generate groups. Currently we just
+  // use {} as a workaround. Scale is always 1.0 here, and we always pin
+  // layout.
   auto result = tensor_methods::all_reduce(self_tensor, GetReduceType(reduceOp),
                                            /*scale*/ 1.0,
                                            /*groups*/ {}, /*pin_layout*/ true);
   return bridge::AtenFromXlaTensor(result);
 }
 
-TORCH_LIBRARY_IMPL(c10d_functional, XLA, m) {
+TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
   m.impl("all_reduce", all_reduce);
 }
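A hedged sketch of what the registration above enables: with the _c10d_functional library given an XLA backend impl, calling the functional op directly on an XLA tensor should dispatch to the C++ all_reduce in this file. The tensor setup and the "sum"/empty group_name values are illustrative, mirroring the Python-side call in xla_model.py.

    import torch
    import torch_xla.core.xla_model as xm

    t = torch.ones(4, device=xm.xla_device())
    # Dispatches on the XLA key to the all_reduce registered via TORCH_LIBRARY_IMPL.
    out = torch.ops._c10d_functional.all_reduce(t, "sum", "")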
