[BE][DTensor] fix DTensor equal op
ghstack-source-id: b5445b8b47af08c6d68506b9e43dadc1d1a28f60
Pull Request resolved: #99014
XilunWu committed Apr 13, 2023
1 parent 3b78485 commit b5e716a
Showing 2 changed files with 25 additions and 33 deletions.
8 changes: 2 additions & 6 deletions test/distributed/_tensor/test_tensor_ops.py
@@ -257,12 +257,8 @@ def test_equal(self):
         dist_tensor_2 = DTensor.from_local(input_tensor_2, device_mesh, shard_spec)

         eq_result = dist_tensor_1.equal(dist_tensor_2)
-        if self.rank == 0:
-            # TODO: equal op currently returns each shard's local result
-            with self.assertRaises(AssertionError):
-                self.assertFalse(eq_result)
-        else:
-            self.assertFalse(eq_result)
+        # equal op all reduces each shard's local result
+        self.assertFalse(eq_result)

     def _test_op(self, mesh, op_call, *args, **kwargs):
         out = op_call(*args, **kwargs)
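For readers who want to reproduce the behavior the updated assertion checks, below is a minimal hypothetical sketch (not part of this commit). It assumes a 2-rank job launched with torchrun using the gloo backend, and mirrors the test setup above: rank 0's local shard of t1 matches t2 while the other rank's does not, so the all-reduced result is False on every rank, whereas rank 0's shard-local comparison alone would be True.

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor, Shard

dist.init_process_group("gloo")  # torchrun provides the env:// rendezvous
rank = dist.get_rank()
mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))
shard_spec = [Shard(0)]

t1 = torch.arange(4).reshape(2, 2) + rank  # matches t2 only on rank 0
t2 = torch.arange(4).reshape(2, 2)

d1 = DTensor.from_local(t1, mesh, shard_spec)
d2 = DTensor.from_local(t2, mesh, shard_spec)

# With this commit, equal() AND-reduces the per-shard results,
# so every rank prints the same value: False.
print(f"rank {rank}: equal = {d1.equal(d2)}")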
50 changes: 23 additions & 27 deletions torch/distributed/_tensor/dispatch.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import functools
+import operator
 from typing import Callable, cast, Dict, List, Sequence, Tuple, Union

 import torch
@@ -163,9 +164,10 @@ def _operator_dispatch(

     if mesh is not None and mesh.get_coordinate() is None:
         # For a non-participating device, we do:
-        # 1. if the return type is scalar, all gather the local result
-        # from participating devices, and reduce on the list of results
-        # with appropriate operators.
+        # 1. if the return type is scalar, set the local result to None.
+        # The local results from all devices will then be all-gathered
+        # and a reduce op will be performed on the list of results
+        # with appropriate operators:
         # for bool type, we by default use AND to reduce;
         # we can extend for more ops if necessary.
         # 2. if the return type is Tensor or List[Tensor], return empty
@@ -180,21 +182,9 @@
         )

         if spec is None:
-            # return a scalar value
-            # collect local results from participating ranks
-            obj_list = [None for _ in range(dist.get_world_size())]
-            dist.all_gather_object(obj_list, None)
-            obj_list = list(filter(lambda x: x is not None, obj_list))
-            # perform reduce on the collection with AND op
-            ret_type = str(ret_list[0].type)
-            if ret_type == "bool":
-                import operator
-
-                local_results: object = functools.reduce(operator.and_, obj_list, True)
-            else:
-                raise NotImplementedError(
-                    f"return type {ret_type} in DTensor op is not supported"
-                )
+            # For a scalar return type, the non-participating device has None
+            # as its local result
+            local_results: object = None
         else:

             def default_tensor(spec: DTensorSpec) -> torch.Tensor:
@@ -255,20 +245,26 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor:
     local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
     local_tensor_kwargs = cast(Dict[str, object], local_tensor_kwargs)
     local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
-    if (
-        (mesh is not None)
-        and (mesh.mesh.numel() < dist.get_world_size())
-        and (output_sharding.output_spec is None)
-    ):
-        # communicate the result to non-participating ranks if
-        # op runs on a submesh and return type is scalar value
-        obj_list = [None for _ in range(dist.get_world_size())]
-        dist.all_gather_object(obj_list, local_results)

     # if op is a random op, adjust Philox RNG state to maintain synchronization
     if op_call in random_ops and is_rng_supported_mesh(mesh):
         set_post_op_offset(dtensor_arg._spec, old_offset)

+    # NOTE: can we move this logic into wrap() function???
+    # NOTE: operator _local_scalar_dense has return type "number" but does not
+    # have a straightforward pattern on how to reduce or even if we should reduce
+    # the local results
+    # communicate the result to all ranks for some operators that return scalar value
+    if output_sharding.output_spec is None:
+        if op_call == torch.ops.aten.equal.default:
+            obj_list = [None for _ in range(dist.get_world_size())]
+            dist.all_gather_object(obj_list, local_results)
+            obj_list = list(filter(lambda x: x is not None, obj_list))
+            ret_list = op_schema.func_schema.returns
+            ret_type = str(ret_list[0].type)
+            # perform reduce on the collection with AND op
+            local_results = functools.reduce(operator.and_, obj_list, True)
+
     if suggested_input_schema.is_inplace:
         # inplace op should return self instead of re-wrapping
         self = cast(dtensor.DTensor, args[0])
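The core of the new scalar-return path above is gather-then-reduce: every rank contributes its shard-local result (non-participating ranks contribute None), the None placeholders are dropped, and the remainder is AND-reduced so all ranks agree on one boolean. A minimal standalone sketch of that pattern follows; the helper name is hypothetical and not part of this commit.

import functools
import operator

import torch.distributed as dist


def and_reduce_local_bools(local_result) -> bool:
    # Gather every rank's local object; ranks outside the op's mesh pass None.
    obj_list = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(obj_list, local_result)
    # Drop the None placeholders, then AND-reduce so every rank
    # ends up with the same boolean.
    participating = [r for r in obj_list if r is not None]
    return functools.reduce(operator.and_, participating, True)


# e.g. each rank would call:
#   and_reduce_local_bools(local_shard_1.equal(local_shard_2))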
