Commit

Fix a bug in native::_do_all_gather related to group (#2947)

* Implemented the feature

* Remove a comment

* Fix bugs

* Remove tensor_with_different_shape

* Remove tensor_with_different_shape
sadra-barikbin authored Jun 19, 2023
1 parent 4c83da1 commit ffed9f3
Showing 3 changed files with 31 additions and 18 deletions.
17 changes: 11 additions & 6 deletions ignite/distributed/comp_models/native.py
@@ -430,15 +430,20 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
         return tensor
 
     def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
-        if group is not None and not isinstance(group, dist.ProcessGroup):
+        if group == dist.GroupMember.NON_GROUP_MEMBER:
+            return tensor
+        elif group is None:
+            group_size = self.get_world_size()
+        elif isinstance(group, dist.ProcessGroup):
+            group_size = group.size()
+        elif isinstance(group, list):
+            group_size = len(group)
+        else:
             raise ValueError("Argument group should be list of int or ProcessGroup")
         if tensor.ndimension() == 0:
             tensor = tensor.unsqueeze(0)
-        output = [torch.zeros_like(tensor) for _ in range(self.get_world_size())]
-        if group is not None:
-            dist.all_gather(output, tensor, group=group)
-        else:
-            dist.all_gather(output, tensor)
+        output = [torch.zeros_like(tensor) for _ in range(group_size)]
+        dist.all_gather(output, tensor, group=group)
         return torch.cat(output, dim=0)
 
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
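
The fix rests on two points visible in the diff above: the gather buffer must be sized to the participating group rather than the whole world, and a process outside the group (for which torch.distributed reports GroupMember.NON_GROUP_MEMBER) simply gets its input back. The following standalone sketch is not part of the commit; it restates that logic against plain torch.distributed, assuming a gloo backend and a launch via torchrun --nproc_per_node=4.

# Standalone sketch (not from this commit) of the corrected gather-with-group logic.
# Assumptions: launched with `torchrun --nproc_per_node=4 sketch.py`, gloo backend.
import torch
import torch.distributed as dist


def gather_in_group(tensor: torch.Tensor, ranks: list) -> torch.Tensor:
    # Every process must call new_group with the same ranks list.
    group = dist.new_group(ranks)
    if group == dist.GroupMember.NON_GROUP_MEMBER:
        # This process is not in `ranks`: nothing to gather, return the input as-is.
        return tensor
    # The output buffer matches the group size, not the world size (the bug being fixed).
    output = [torch.zeros_like(tensor) for _ in range(len(ranks))]
    dist.all_gather(output, tensor, group=group)
    return torch.cat(output, dim=0)


if __name__ == "__main__":
    dist.init_process_group("gloo")
    rank = dist.get_rank()
    t = torch.tensor([rank])
    res = gather_in_group(t, ranks=[0, 1])
    # Ranks 0 and 1 print tensor([0, 1]); ranks 2 and 3 get their own tensor back.
    print(f"rank {rank}: {res}")
    dist.destroy_process_group()

Before the fix, the buffer was always world_size entries long and a list-of-ranks group was rejected with a ValueError, which is exactly what the diff above corrects.
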
10 changes: 6 additions & 4 deletions ignite/distributed/utils.py
@@ -356,13 +356,15 @@ def all_gather(
     """Helper method to perform all gather operation.
 
     Args:
-        tensor: tensor or number or str to collect across participating processes.
+        tensor: tensor or number or str to collect across participating processes. If tensor, it should have the
+            same shape across processes.
         group: list of integer or the process group for each backend. If None, the default process group will be used.
 
     Returns:
-        torch.Tensor of shape ``(world_size * tensor.shape[0], tensor.shape[1], ...)`` if input is a tensor or
-        torch.Tensor of shape ``(world_size, )`` if input is a number or
-        List of strings if input is a string
+        If input is a tensor, returns a torch.Tensor of shape ``(world_size * tensor.shape[0], tensor.shape[1], ...)``.
+        If input is a number, a torch.Tensor of shape ``(world_size, )`` is returned and finally a list of strings
+        is returned if input is a string. If current process does not belong to `group`, the very ``tensor`` is
+        returned.
 
     .. versionchanged:: 0.4.11
         added ``group``
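
The updated docstring distinguishes three input kinds (tensor, number, string) plus the non-member case. Below is a hedged usage sketch, assuming the function runs inside an initialized distributed configuration with more than one process, here launched through idist.Parallel with the gloo backend.

# Hedged usage sketch of ignite.distributed.all_gather, matching the docstring above.
# Assumption: launched via idist.Parallel with gloo and 2 processes.
import torch
import ignite.distributed as idist


def gather_examples(local_rank: int) -> None:
    rank = idist.get_rank()
    world_size = idist.get_world_size()

    # Number input -> torch.Tensor of shape (world_size,) per the docstring.
    nums = idist.all_gather(10)

    # Tensor input (same shape on every process) -> shape (world_size * 2, 3).
    t = torch.full((2, 3), float(rank))
    gathered = idist.all_gather(t)
    assert gathered.shape == (world_size * 2, 3)

    # String input -> list of strings, one per process.
    names = idist.all_gather(f"proc-{rank}")
    assert len(names) == world_size

    # Restricting to a subgroup: processes outside the group get their own tensor back.
    res = idist.all_gather(torch.tensor([rank]), group=[0])


if __name__ == "__main__":
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(gather_examples)
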
22 changes: 14 additions & 8 deletions tests/ignite/distributed/utils/__init__.py
@@ -155,32 +155,34 @@ def _test_distrib_all_reduce_group(device):
 
 
 def _test_distrib_all_gather(device):
+    rank = idist.get_rank()
+
     res = torch.tensor(idist.all_gather(10), device=device)
     true_res = torch.tensor([10] * idist.get_world_size(), device=device)
     assert (res == true_res).all()
 
-    t = torch.tensor(idist.get_rank(), device=device)
+    t = torch.tensor(rank, device=device)
     res = idist.all_gather(t)
     true_res = torch.tensor([i for i in range(idist.get_world_size())], device=device)
     assert (res == true_res).all()
 
     x = "test-test"
-    if idist.get_rank() == 0:
+    if rank == 0:
         x = "abc"
     res = idist.all_gather(x)
     true_res = ["abc"] + ["test-test"] * (idist.get_world_size() - 1)
     assert res == true_res
 
     base_x = "tests/ignite/distributed/utils/test_native.py" * 2000
     x = base_x
-    if idist.get_rank() == 0:
+    if rank == 0:
         x = "abc"
 
     res = idist.all_gather(x)
     true_res = ["abc"] + [base_x] * (idist.get_world_size() - 1)
     assert res == true_res
 
-    t = torch.arange(100, device=device).reshape(4, 25) * (idist.get_rank() + 1)
+    t = torch.arange(100, device=device).reshape(4, 25) * (rank + 1)
     in_dtype = t.dtype
     res = idist.all_gather(t)
     assert res.shape == (idist.get_world_size() * 4, 25)
@@ -208,17 +210,21 @@ def _test_distrib_all_gather_group(device):
                 res = idist.all_gather(t, group=group)
         else:
             res = idist.all_gather(t, group=group)
-            assert torch.equal(res, torch.tensor(ranks, device=device))
+            if rank in ranks:
+                assert torch.equal(res, torch.tensor(ranks, device=device))
+            else:
+                assert res == t
 
         t = torch.tensor([rank], device=device)
         if bnd in ("horovod"):
             with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
                 res = idist.all_gather(t, group=ranks)
         else:
             res = idist.all_gather(t, group=ranks)
-            assert torch.equal(res, torch.tensor(ranks, device=device))
-
-        ranks = "abc"
+            if rank in ranks:
+                assert torch.equal(res, torch.tensor(ranks, device=device))
+            else:
+                assert res == t
 
         if bnd in ("nccl", "gloo", "mpi"):
             with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
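
These helpers are backend-agnostic; a backend-specific module such as tests/ignite/distributed/utils/test_native.py typically initializes a process group through a pytest fixture and then calls them. A hedged sketch of such a driver test follows; the fixture name distributed_context_single_node_gloo and the import path are assumptions, not part of this diff.

# Hedged sketch of a backend-specific driver for the helpers above.
# The pytest fixture name below is an assumption, not taken from this commit.
import pytest

import ignite.distributed as idist
from tests.ignite.distributed.utils import _test_distrib_all_gather, _test_distrib_all_gather_group


@pytest.mark.distributed
def test_idist_all_gather_gloo(distributed_context_single_node_gloo):
    device = idist.device()
    _test_distrib_all_gather(device)
    _test_distrib_all_gather_group(device)
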
