From a7e1c6c41a9d31d4dd553d64da73a146adf7c0d1 Mon Sep 17 00:00:00 2001 From: ybc Date: Mon, 22 Feb 2021 22:48:45 -0800 Subject: [PATCH 1/6] Add construct_topology function --- bluefog/torch/__init__.py | 2 ++ bluefog/torch/utility.py | 62 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/bluefog/torch/__init__.py b/bluefog/torch/__init__.py index 62695e4f..05b0182f 100644 --- a/bluefog/torch/__init__.py +++ b/bluefog/torch/__init__.py @@ -74,4 +74,6 @@ from bluefog.torch.mpi_ops import timeline_start_activity, timeline_end_activity from bluefog.torch.mpi_ops import timeline_context + from bluefog.torch.utility import broadcast_optimizer_state, broadcast_parameters, allreduce_parameters +from bluefog.torch.utility import construct_topology diff --git a/bluefog/torch/utility.py b/bluefog/torch/utility.py index b4e05238..d63b816e 100644 --- a/bluefog/torch/utility.py +++ b/bluefog/torch/utility.py @@ -14,8 +14,10 @@ # limitations under the License. # ============================================================================== +from typing import Any, List, Optional import collections +import numpy as np import torch import bluefog.torch as bf @@ -210,3 +212,63 @@ def _from_tensor(): for key, p in params: if key in callbacks: callbacks[key]() + + +def _check_ranks(rank_list: List[Any], self_rank: int, size: int) -> [bool, str]: + result = False + for rank in rank_list: + if not isinstance(rank, int): + return False, "contain element that is not integer." + if (rank < 0) or (rank >= size): + return False, "contain element that is not between 0 and size-1." + if len(set(rank_list)) != len(rank_list): + return False, "contain duplicated elements." + if self_rank in rank_list: + return False, "contain self rank." + return True, "" + + +def construct_topology( + *, + dst_list: Optional[List[int]] = None, + src_list: Optional[List[int]] = None, + construct_adjacency_matrix: bool = False +): + """ + + """ + if dst_list is None and src_list is None: + raise ValueError("Either dst_list or src_list need to be provided.") + if dst_list is not None and src_list is not None: + raise ValueError( + "Only one of two argument dst_list or src_list should be provided.") + + rank_list = dst_list or src_list + print(bf.rank() in rank_list) + is_valid, error_msg = _check_ranks(rank_list, bf.rank(), bf.size()) + assert is_valid, f"The format of dst_list or src_list is wrong: {error_msg}" + + degree = len(rank_list) + all_degree_list = bf.allgather(torch.IntTensor([degree])).numpy() + all_rank_list = bf.allgather(torch.IntTensor(rank_list)).numpy() + adjacency_dict = dict() + displacement = 0 + for i, degree in enumerate(all_degree_list): + adjacency_dict[i] = sorted(all_rank_list[displacement:displacement+degree]) + displacement += degree + + inv_adjacency_dict = collections.defaultdict(list) + for k, adj in adjacency_dict.items(): + for v in adj: + inv_adjacency_dict[v].append(k) + if not construct_adjacency_matrix: + return inv_adjacency_dict.get(bf.rank()) + + # construct_adjacency_matrix + W = np.eye(bf.size()) + for k, adj in adjacency_dict.items(): + W[k, adj] = 1 + if dst_list is None: + W = W.T + + return W / W.sum(axis=0) \ No newline at end of file From 12a734b2dbca0afc5f9ee0f9895f7deda260ea27 Mon Sep 17 00:00:00 2001 From: ybc Date: Sat, 27 Feb 2021 22:58:12 -0800 Subject: [PATCH 2/6] Add doc and test for infer_destination_source_ranks --- bluefog/torch/__init__.py | 31 ++++++++-- bluefog/torch/topology_util.py | 104 +++++++++++++++++++++++++++++++++ bluefog/torch/utility.py | 60 ------------------- test/torch_basics_test.py | 82 +++++++++++++++++++------- 4 files changed, 192 insertions(+), 85 deletions(-) create mode 100644 bluefog/torch/topology_util.py diff --git a/bluefog/torch/__init__.py b/bluefog/torch/__init__.py index 05b0182f..cf06b041 100644 --- a/bluefog/torch/__init__.py +++ b/bluefog/torch/__init__.py @@ -14,10 +14,6 @@ # limitations under the License. # ============================================================================== -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import collections import os import torch @@ -76,4 +72,29 @@ from bluefog.torch.mpi_ops import timeline_context from bluefog.torch.utility import broadcast_optimizer_state, broadcast_parameters, allreduce_parameters -from bluefog.torch.utility import construct_topology + +from bluefog.common.topology_util import ( + GetRecvWeights, + GetSendWeights, + IsRegularGraph, + IsTopologyEquivalent, +) + +from bluefog.common.topology_util import ( + ExponentialTwoGraph, + ExponentialGraph, + FullyConnectedGraph, + MeshGrid2DGraph, + RingGraph, + StarGraph, + SymmetricExponentialGraph, +) + +from bluefog.common.topology_util import ( + GetDynamicOnePeerSendRecvRanks, + GetExp2DynamicSendRecvMachineRanks, + GetInnerOuterRingDynamicSendRecvRanks, + GetInnerOuterExpo2DynamicSendRecvRanks +) + +from bluefog.torch.topology_util import infer_destination_source_ranks diff --git a/bluefog/torch/topology_util.py b/bluefog/torch/topology_util.py new file mode 100644 index 00000000..dbafc610 --- /dev/null +++ b/bluefog/torch/topology_util.py @@ -0,0 +1,104 @@ +from typing import Any, List, Optional, Union +import collections + +import numpy as np +import torch +import bluefog.torch as bf + +from bluefog.common.topology_util import ( + GetRecvWeights, + GetSendWeights, + IsRegularGraph, + IsTopologyEquivalent, +) + +from bluefog.common.topology_util import ( + ExponentialTwoGraph, + ExponentialGraph, + FullyConnectedGraph, + MeshGrid2DGraph, + RingGraph, + StarGraph, + SymmetricExponentialGraph, +) + +from bluefog.common.topology_util import ( + GetDynamicOnePeerSendRecvRanks, + GetExp2DynamicSendRecvMachineRanks, + GetInnerOuterRingDynamicSendRecvRanks, + GetInnerOuterExpo2DynamicSendRecvRanks, +) + + +def _check_ranks(rank_list: List[Any], self_rank: int, size: int) -> [bool, str]: + for rank in rank_list: + if not isinstance(rank, int): + return False, "contain element that is not integer." + if (rank < 0) or (rank >= size): + return False, "contain element that is not between 0 and size-1." + if len(set(rank_list)) != len(rank_list): + return False, "contain duplicated elements." + if self_rank in rank_list: + return False, "contain self rank." + return True, "" + + +def infer_destination_source_ranks( + *, + dst_ranks: Optional[List[int]] = None, + src_ranks: Optional[List[int]] = None, + construct_adjacency_matrix: bool = False, +) -> Union[List[int], np.array]: + """Infer the destination(source) ranks from source(destination) ranks. + + Args: + dst_ranks: A list of destination ranks. If provided the dst_link, a corresponding + src_ranks will be returned. + src_ranks: A list of destination ranks. If provided the src_ranks, a corresponding + dst_link will be returned. + construct_adjacency_matrix: If true, adjacency matrix will be return instead. + Element w_{ij} represents the weights sending from node i to node j. + We use column normalized style, i.e. the sum of receiving weight is 1. + + Raises: + ValueError: If dst_ranks or src_ranks does not contain integer from 0 to size-1. + + Returns: + If construct_adjacency_matrix is false, returns a rank list. + If construct_adjacency_matrix is true, returns a 2-D numpy array. + """ + if dst_ranks is None and src_ranks is None: + raise ValueError("Either dst_ranks or src_ranks need to be provided.") + if dst_ranks is not None and src_ranks is not None: + raise ValueError( + "Only one of two argument dst_ranks or src_ranks should be provided." + ) + + rank_list = dst_ranks or src_ranks + is_valid, error_msg = _check_ranks(rank_list, bf.rank(), bf.size()) + assert is_valid, f"The format of dst_ranks or src_ranks is wrong: {error_msg}" + + degree = len(rank_list) + all_degree_list = bf.allgather(torch.tensor([degree], dtype=torch.int32)).numpy() + all_rank_list = bf.allgather(torch.tensor(rank_list, dtype=torch.int32)).numpy() + adjacency_dict = dict() + displacement = 0 + for i, degree in enumerate(all_degree_list): + adjacency_dict[i] = sorted(all_rank_list[displacement : displacement + degree]) + displacement += degree + + inv_adjacency_dict = collections.defaultdict(list) + for k, adj in adjacency_dict.items(): + for v in adj: + inv_adjacency_dict[v].append(k) + if not construct_adjacency_matrix: + return inv_adjacency_dict.get(bf.rank()) + + # construct_adjacency_matrix + W = np.eye(bf.size()) + for k, adj in adjacency_dict.items(): + W[k, adj] = 1 + if dst_ranks is None: + W = W.T + + return W / W.sum(axis=1) diff --git a/bluefog/torch/utility.py b/bluefog/torch/utility.py index d63b816e..73337207 100644 --- a/bluefog/torch/utility.py +++ b/bluefog/torch/utility.py @@ -212,63 +212,3 @@ def _from_tensor(): for key, p in params: if key in callbacks: callbacks[key]() - - -def _check_ranks(rank_list: List[Any], self_rank: int, size: int) -> [bool, str]: - result = False - for rank in rank_list: - if not isinstance(rank, int): - return False, "contain element that is not integer." - if (rank < 0) or (rank >= size): - return False, "contain element that is not between 0 and size-1." - if len(set(rank_list)) != len(rank_list): - return False, "contain duplicated elements." - if self_rank in rank_list: - return False, "contain self rank." - return True, "" - - -def construct_topology( - *, - dst_list: Optional[List[int]] = None, - src_list: Optional[List[int]] = None, - construct_adjacency_matrix: bool = False -): - """ - - """ - if dst_list is None and src_list is None: - raise ValueError("Either dst_list or src_list need to be provided.") - if dst_list is not None and src_list is not None: - raise ValueError( - "Only one of two argument dst_list or src_list should be provided.") - - rank_list = dst_list or src_list - print(bf.rank() in rank_list) - is_valid, error_msg = _check_ranks(rank_list, bf.rank(), bf.size()) - assert is_valid, f"The format of dst_list or src_list is wrong: {error_msg}" - - degree = len(rank_list) - all_degree_list = bf.allgather(torch.IntTensor([degree])).numpy() - all_rank_list = bf.allgather(torch.IntTensor(rank_list)).numpy() - adjacency_dict = dict() - displacement = 0 - for i, degree in enumerate(all_degree_list): - adjacency_dict[i] = sorted(all_rank_list[displacement:displacement+degree]) - displacement += degree - - inv_adjacency_dict = collections.defaultdict(list) - for k, adj in adjacency_dict.items(): - for v in adj: - inv_adjacency_dict[v].append(k) - if not construct_adjacency_matrix: - return inv_adjacency_dict.get(bf.rank()) - - # construct_adjacency_matrix - W = np.eye(bf.size()) - for k, adj in adjacency_dict.items(): - W[k, adj] = 1 - if dst_list is None: - W = W.T - - return W / W.sum(axis=0) \ No newline at end of file diff --git a/test/torch_basics_test.py b/test/torch_basics_test.py index 4494f3d2..ecb9d7c5 100644 --- a/test/torch_basics_test.py +++ b/test/torch_basics_test.py @@ -28,8 +28,15 @@ from common import mpi_env_rank_and_size import bluefog.torch as bf -from bluefog.common.topology_util import ExponentialGraph, RingGraph, RingGraph -from bluefog.common.topology_util import IsTopologyEquivalent +from bluefog.torch.topology_util import ( + ExponentialGraph, + RingGraph, + StarGraph, + MeshGrid2DGraph, + FullyConnectedGraph, +) +from bluefog.torch.topology_util import IsTopologyEquivalent +from bluefog.torch.topology_util import infer_destination_source_ranks warnings.filterwarnings("ignore", message="numpy.dtype size changed") warnings.filterwarnings("ignore", message="numpy.ufunc size changed") @@ -75,10 +82,12 @@ def test_set_topology_fail_with_win_create(self): if size == 1: expected_topology = nx.from_numpy_array( - np.array([[0.5]]), create_using=nx.DiGraph) + np.array([[0.5]]), create_using=nx.DiGraph + ) elif size == 2: expected_topology = nx.from_numpy_array( - np.array([[0, 0.2], [0.2, 0]]), create_using=nx.DiGraph) + np.array([[0, 0.2], [0.2, 0]]), create_using=nx.DiGraph + ) else: expected_topology = RingGraph(size) @@ -96,10 +105,16 @@ def test_set_and_load_topology(self): bf.init() size = bf.size() if size == 4: - expected_topology = nx.DiGraph(np.array( - [[1/3., 1/3., 1/3., 0.], [0., 1/3., 1/3., 1/3.], - [1/3., 0., 1/3., 1/3.], [1/3., 1/3., 0., 1/3.]] - )) + expected_topology = nx.DiGraph( + np.array( + [ + [1 / 3.0, 1 / 3.0, 1 / 3.0, 0.0], + [0.0, 1 / 3.0, 1 / 3.0, 1 / 3.0], + [1 / 3.0, 0.0, 1 / 3.0, 1 / 3.0], + [1 / 3.0, 1 / 3.0, 0.0, 1 / 3.0], + ] + ) + ) elif size == 1: expected_topology = nx.DiGraph(np.array([[1.0]])) else: @@ -113,15 +128,13 @@ def test_in_out_neighbors_expo2(self): rank = bf.rank() size = bf.size() assert bf.set_topology(ExponentialGraph(size)) - in_neighobrs = bf.in_neighbor_ranks() + in_neighbors = bf.in_neighbor_ranks() out_neighbors = bf.out_neighbor_ranks() degree = int(np.ceil(np.log2(size))) - expected_in_neighbors = sorted([(rank - 2**i) % - size for i in range(degree)]) - expected_out_neighbors = sorted([(rank + 2**i) % - size for i in range(degree)]) - assert sorted(in_neighobrs) == expected_in_neighbors + expected_in_neighbors = sorted([(rank - 2 ** i) % size for i in range(degree)]) + expected_out_neighbors = sorted([(rank + 2 ** i) % size for i in range(degree)]) + assert sorted(in_neighbors) == expected_in_neighbors assert sorted(out_neighbors) == expected_out_neighbors def test_in_out_neighbors_biring(self): @@ -129,21 +142,50 @@ def test_in_out_neighbors_biring(self): rank = bf.rank() size = bf.size() assert bf.set_topology(RingGraph(size)) - in_neighobrs = bf.in_neighbor_ranks() + in_neighbors = bf.in_neighbor_ranks() out_neighbors = bf.out_neighbor_ranks() - expected_in_neighbors = list(set( - map(lambda x: x % size, [rank - 1, rank + 1]))) - expected_out_neighbors = list(set( - map(lambda x: x % size, [rank - 1, rank + 1]))) + expected_in_neighbors = list(set(map(lambda x: x % size, [rank - 1, rank + 1]))) + expected_out_neighbors = list( + set(map(lambda x: x % size, [rank - 1, rank + 1])) + ) if size <= 1: expected_in_neighbors = [] expected_out_neighbors = [] - assert sorted(in_neighobrs) == expected_in_neighbors + assert sorted(in_neighbors) == expected_in_neighbors assert sorted(out_neighbors) == expected_out_neighbors +@pytest.mark.parametrize( + "topo_func", + [ExponentialGraph, RingGraph, StarGraph, MeshGrid2DGraph, FullyConnectedGraph], +) +def test_infer_destination_source_ranks(topo_func): + bf.init() + size = bf.size() + bf.set_topology(topo_func(size)) + topo = bf.load_topology() + in_neighbors = bf.in_neighbor_ranks() + out_neighbors = bf.out_neighbor_ranks() + + src_ranks = infer_destination_source_ranks(dst_ranks=out_neighbors) + assert sorted(src_ranks) == in_neighbors + dst_ranks = infer_destination_source_ranks(src_ranks=in_neighbors) + assert sorted(dst_ranks) == out_neighbors + + W = infer_destination_source_ranks( + dst_ranks=out_neighbors, construct_adjacency_matrix=True + ) + expected_W = (nx.to_numpy_array(topo) > 0).astype(float) + expected_W /= expected_W.sum(axis=0) + np.testing.assert_allclose(W, expected_W) + W = infer_destination_source_ranks( + src_ranks=in_neighbors, construct_adjacency_matrix=True + ) + np.testing.assert_allclose(W, expected_W) + + if __name__ == "__main__": unittest.main() From 57f5cae1ec8624a83c03145f101da166f58aba69 Mon Sep 17 00:00:00 2001 From: ybc Date: Tue, 2 Mar 2021 20:35:04 -0800 Subject: [PATCH 3/6] Address comments --- bluefog/torch/topology_util.py | 34 ++++++---------------------------- examples/construct_topo.py | 26 ++++++++++++++++++++++++++ test/torch_basics_test.py | 21 ++++++++++----------- 3 files changed, 42 insertions(+), 39 deletions(-) create mode 100644 examples/construct_topo.py diff --git a/bluefog/torch/topology_util.py b/bluefog/torch/topology_util.py index dbafc610..860e7666 100644 --- a/bluefog/torch/topology_util.py +++ b/bluefog/torch/topology_util.py @@ -5,30 +5,6 @@ import torch import bluefog.torch as bf -from bluefog.common.topology_util import ( - GetRecvWeights, - GetSendWeights, - IsRegularGraph, - IsTopologyEquivalent, -) - -from bluefog.common.topology_util import ( - ExponentialTwoGraph, - ExponentialGraph, - FullyConnectedGraph, - MeshGrid2DGraph, - RingGraph, - StarGraph, - SymmetricExponentialGraph, -) - -from bluefog.common.topology_util import ( - GetDynamicOnePeerSendRecvRanks, - GetExp2DynamicSendRecvMachineRanks, - GetInnerOuterRingDynamicSendRecvRanks, - GetInnerOuterExpo2DynamicSendRecvRanks, -) - def _check_ranks(rank_list: List[Any], self_rank: int, size: int) -> [bool, str]: for rank in rank_list: @@ -43,7 +19,7 @@ def _check_ranks(rank_list: List[Any], self_rank: int, size: int) -> [bool, str] return True, "" -def infer_destination_source_ranks( +def InferDestinationSourceRanks( *, dst_ranks: Optional[List[int]] = None, src_ranks: Optional[List[int]] = None, @@ -57,7 +33,7 @@ def infer_destination_source_ranks( src_ranks: A list of destination ranks. If provided the src_ranks, a corresponding dst_link will be returned. construct_adjacency_matrix: If true, adjacency matrix will be return instead. - Element w_{ij} represents the weights sending from node i to node j. + Element w_{ij} represents the weights sending from node i to node j. We use column normalized style, i.e. the sum of receiving weight is 1. Raises: @@ -91,8 +67,10 @@ def infer_destination_source_ranks( for k, adj in adjacency_dict.items(): for v in adj: inv_adjacency_dict[v].append(k) + return_list = inv_adjacency_dict.get(bf.rank()) + if not construct_adjacency_matrix: - return inv_adjacency_dict.get(bf.rank()) + return return_list # construct_adjacency_matrix W = np.eye(bf.size()) @@ -101,4 +79,4 @@ def infer_destination_source_ranks( if dst_ranks is None: W = W.T - return W / W.sum(axis=1) + return return_list, W / W.sum(axis=1) diff --git a/examples/construct_topo.py b/examples/construct_topo.py new file mode 100644 index 00000000..a50f202b --- /dev/null +++ b/examples/construct_topo.py @@ -0,0 +1,26 @@ +import bluefog.torch as bf +# from bluefog.common import topology_util +import networkx as nx +bf.init() +# dst_list = [i for i in range(bf.size()) if i != bf.rank()] + +# dst_list = [(bf.rank()+1)%bf.size(), (bf.rank()+3)%bf.size()] + +if bf.rank() == 0: + dst_list = [1, 2] +elif bf.rank() == 1: + dst_list = [0] +elif bf.rank() == 3: + dst_list = [0] +else: + dst_list = [0] + +bf.set_topology(bf.topology_util.ExponentialTwoGraph(size=4)) +print(f"{bf.rank()}: {dst_list}") +# print(bf.rank() in dst_list) +W = bf.infer_destination_source_ranks( + src_list=dst_list, construct_adjacency_matrix=True) +G = nx.from_numpy_array(W, create_using=nx.DiGraph) +if bf.rank() == 0: + print(W) + print(bf.GetRecvWeights(G, bf.rank())) diff --git a/test/torch_basics_test.py b/test/torch_basics_test.py index ecb9d7c5..0417a770 100644 --- a/test/torch_basics_test.py +++ b/test/torch_basics_test.py @@ -28,15 +28,14 @@ from common import mpi_env_rank_and_size import bluefog.torch as bf -from bluefog.torch.topology_util import ( +from bluefog.torch import ( ExponentialGraph, RingGraph, StarGraph, MeshGrid2DGraph, FullyConnectedGraph, ) -from bluefog.torch.topology_util import IsTopologyEquivalent -from bluefog.torch.topology_util import infer_destination_source_ranks +from bluefog.torch import IsTopologyEquivalent, infer_destination_source_ranks warnings.filterwarnings("ignore", message="numpy.dtype size changed") warnings.filterwarnings("ignore", message="numpy.ufunc size changed") @@ -170,20 +169,20 @@ def test_infer_destination_source_ranks(topo_func): in_neighbors = bf.in_neighbor_ranks() out_neighbors = bf.out_neighbor_ranks() - src_ranks = infer_destination_source_ranks(dst_ranks=out_neighbors) - assert sorted(src_ranks) == in_neighbors - dst_ranks = infer_destination_source_ranks(src_ranks=in_neighbors) - assert sorted(dst_ranks) == out_neighbors + # Make the W into average rule. + expected_W = (nx.to_numpy_array(topo) > 0).astype(float) + expected_W /= expected_W.sum(axis=0) - W = infer_destination_source_ranks( + src_ranks, W = infer_destination_source_ranks( dst_ranks=out_neighbors, construct_adjacency_matrix=True ) - expected_W = (nx.to_numpy_array(topo) > 0).astype(float) - expected_W /= expected_W.sum(axis=0) + assert sorted(src_ranks) == in_neighbors np.testing.assert_allclose(W, expected_W) - W = infer_destination_source_ranks( + + dst_ranks, W = infer_destination_source_ranks( src_ranks=in_neighbors, construct_adjacency_matrix=True ) + assert sorted(dst_ranks) == out_neighbors np.testing.assert_allclose(W, expected_W) From 411d4dedebecbd83c8f29fb1a0f43056c1f633da Mon Sep 17 00:00:00 2001 From: ybc Date: Tue, 2 Mar 2021 21:02:27 -0800 Subject: [PATCH 4/6] Split infer_topo into two functions --- bluefog/torch/__init__.py | 17 ++++++--- bluefog/torch/topology_util.py | 70 +++++++++++++++++++++++----------- test/torch_basics_test.py | 37 ++++++++++++++---- 3 files changed, 88 insertions(+), 36 deletions(-) diff --git a/bluefog/torch/__init__.py b/bluefog/torch/__init__.py index cf06b041..eb8a1d08 100644 --- a/bluefog/torch/__init__.py +++ b/bluefog/torch/__init__.py @@ -26,10 +26,10 @@ DistributedWinPutOptimizer, DistributedAllreduceOptimizer, DistributedNeighborAllreduceOptimizer, - DistributedHierarchicalNeighborAllreduceOptimizer + DistributedHierarchicalNeighborAllreduceOptimizer, ) -check_extension('bluefog.torch', __file__, 'mpi_lib') +check_extension("bluefog.torch", __file__, "mpi_lib") from bluefog.torch.mpi_ops import init, shutdown from bluefog.torch.mpi_ops import size, local_size, rank, local_rank @@ -71,7 +71,11 @@ from bluefog.torch.mpi_ops import timeline_start_activity, timeline_end_activity from bluefog.torch.mpi_ops import timeline_context -from bluefog.torch.utility import broadcast_optimizer_state, broadcast_parameters, allreduce_parameters +from bluefog.torch.utility import ( + broadcast_optimizer_state, + broadcast_parameters, + allreduce_parameters, +) from bluefog.common.topology_util import ( GetRecvWeights, @@ -94,7 +98,10 @@ GetDynamicOnePeerSendRecvRanks, GetExp2DynamicSendRecvMachineRanks, GetInnerOuterRingDynamicSendRecvRanks, - GetInnerOuterExpo2DynamicSendRecvRanks + GetInnerOuterExpo2DynamicSendRecvRanks, ) -from bluefog.torch.topology_util import infer_destination_source_ranks +from bluefog.torch.topology_util import ( + InferSourceFromDestinationRanks, + InferDestinationFromSourceRanks, +) diff --git a/bluefog/torch/topology_util.py b/bluefog/torch/topology_util.py index 860e7666..c8bff26a 100644 --- a/bluefog/torch/topology_util.py +++ b/bluefog/torch/topology_util.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, Tuple, Union import collections import numpy as np @@ -19,19 +19,41 @@ def _check_ranks(rank_list: List[Any], self_rank: int, size: int) -> [bool, str] return True, "" -def InferDestinationSourceRanks( - *, - dst_ranks: Optional[List[int]] = None, - src_ranks: Optional[List[int]] = None, - construct_adjacency_matrix: bool = False, +def InferSourceFromDestinationRanks( + dst_ranks: List[int], construct_adjacency_matrix: bool = False, +) -> Union[List[int], Tuple[List[int], np.array]]: + """Infer the source ranks from destination ranks. This is collective communication call. + + Args: + dst_ranks: A list of destination ranks. + construct_adjacency_matrix: If true, adjacency matrix will be return instead. + Element w_{ij} represents the weights sending from node i to node j. + We use column normalized style, i.e. the sum of receiving weight is 1. + + Raises: + ValueError: If dst_ranks or src_ranks does not contain integer from 0 to size-1. + + Returns: + If construct_adjacency_matrix is false, returns the source ranks list. + If construct_adjacency_matrix is true, returns the the source ranks list + and a 2-D numpy array. + """ + is_valid, error_msg = _check_ranks(dst_ranks, bf.rank(), bf.size()) + assert is_valid, f"The format of dst_ranks is wrong: {error_msg}" + return _infer_topo( + dst_ranks, + transpose=False, + construct_adjacency_matrix=construct_adjacency_matrix, + ) + + +def InferDestinationFromSourceRanks( + src_ranks: List[int], construct_adjacency_matrix: bool = False, ) -> Union[List[int], np.array]: - """Infer the destination(source) ranks from source(destination) ranks. + """Infer the destination ranks from source ranks. This is collective communication call. Args: - dst_ranks: A list of destination ranks. If provided the dst_link, a corresponding - src_ranks will be returned. - src_ranks: A list of destination ranks. If provided the src_ranks, a corresponding - dst_link will be returned. + src_ranks: A list of destination ranks. construct_adjacency_matrix: If true, adjacency matrix will be return instead. Element w_{ij} represents the weights sending from node i to node j. We use column normalized style, i.e. the sum of receiving weight is 1. @@ -40,20 +62,22 @@ def InferDestinationSourceRanks( ValueError: If dst_ranks or src_ranks does not contain integer from 0 to size-1. Returns: - If construct_adjacency_matrix is false, returns a rank list. - If construct_adjacency_matrix is true, returns a 2-D numpy array. + If construct_adjacency_matrix is false, returns the destination ranks list. + If construct_adjacency_matrix is true, returns the the sodestinationrce ranks + list and a 2-D numpy array. """ - if dst_ranks is None and src_ranks is None: - raise ValueError("Either dst_ranks or src_ranks need to be provided.") - if dst_ranks is not None and src_ranks is not None: - raise ValueError( - "Only one of two argument dst_ranks or src_ranks should be provided." - ) + is_valid, error_msg = _check_ranks(src_ranks, bf.rank(), bf.size()) + assert is_valid, f"The format of src_ranks is wrong: {error_msg}" + return _infer_topo( + src_ranks, + transpose=True, + construct_adjacency_matrix=construct_adjacency_matrix, + ) - rank_list = dst_ranks or src_ranks - is_valid, error_msg = _check_ranks(rank_list, bf.rank(), bf.size()) - assert is_valid, f"The format of dst_ranks or src_ranks is wrong: {error_msg}" +def _infer_topo( + rank_list: List[int], transpose: bool, construct_adjacency_matrix: bool +): degree = len(rank_list) all_degree_list = bf.allgather(torch.tensor([degree], dtype=torch.int32)).numpy() all_rank_list = bf.allgather(torch.tensor(rank_list, dtype=torch.int32)).numpy() @@ -76,7 +100,7 @@ def InferDestinationSourceRanks( W = np.eye(bf.size()) for k, adj in adjacency_dict.items(): W[k, adj] = 1 - if dst_ranks is None: + if transpose: W = W.T return return_list, W / W.sum(axis=1) diff --git a/test/torch_basics_test.py b/test/torch_basics_test.py index 0417a770..77c6904e 100644 --- a/test/torch_basics_test.py +++ b/test/torch_basics_test.py @@ -35,7 +35,11 @@ MeshGrid2DGraph, FullyConnectedGraph, ) -from bluefog.torch import IsTopologyEquivalent, infer_destination_source_ranks +from bluefog.torch import ( + IsTopologyEquivalent, + InferDestinationFromSourceRanks, + InferSourceFromDestinationRanks, +) warnings.filterwarnings("ignore", message="numpy.dtype size changed") warnings.filterwarnings("ignore", message="numpy.ufunc size changed") @@ -161,7 +165,7 @@ def test_in_out_neighbors_biring(self): "topo_func", [ExponentialGraph, RingGraph, StarGraph, MeshGrid2DGraph, FullyConnectedGraph], ) -def test_infer_destination_source_ranks(topo_func): +def test_infer_destination_from_source_ranks(topo_func): bf.init() size = bf.size() bf.set_topology(topo_func(size)) @@ -173,16 +177,33 @@ def test_infer_destination_source_ranks(topo_func): expected_W = (nx.to_numpy_array(topo) > 0).astype(float) expected_W /= expected_W.sum(axis=0) - src_ranks, W = infer_destination_source_ranks( - dst_ranks=out_neighbors, construct_adjacency_matrix=True + src_ranks, W = InferDestinationFromSourceRanks( + src_ranks=in_neighbors, construct_adjacency_matrix=True ) - assert sorted(src_ranks) == in_neighbors + assert sorted(src_ranks) == out_neighbors np.testing.assert_allclose(W, expected_W) - dst_ranks, W = infer_destination_source_ranks( - src_ranks=in_neighbors, construct_adjacency_matrix=True + +@pytest.mark.parametrize( + "topo_func", + [ExponentialGraph, RingGraph, StarGraph, MeshGrid2DGraph, FullyConnectedGraph], +) +def test_infer_source_from_destination_ranks(topo_func): + bf.init() + size = bf.size() + bf.set_topology(topo_func(size)) + topo = bf.load_topology() + in_neighbors = bf.in_neighbor_ranks() + out_neighbors = bf.out_neighbor_ranks() + + # Make the W into average rule. + expected_W = (nx.to_numpy_array(topo) > 0).astype(float) + expected_W /= expected_W.sum(axis=0) + + dst_ranks, W = InferSourceFromDestinationRanks( + dst_ranks=out_neighbors, construct_adjacency_matrix=True ) - assert sorted(dst_ranks) == out_neighbors + assert sorted(dst_ranks) == in_neighbors np.testing.assert_allclose(W, expected_W) From e0a8a6fb70490d718f8723497fbf1e5cb2727b14 Mon Sep 17 00:00:00 2001 From: ybc Date: Thu, 4 Mar 2021 21:52:56 -0800 Subject: [PATCH 5/6] Delete construct_topo.py --- examples/construct_topo.py | 26 -------------------------- 1 file changed, 26 deletions(-) delete mode 100644 examples/construct_topo.py diff --git a/examples/construct_topo.py b/examples/construct_topo.py deleted file mode 100644 index a50f202b..00000000 --- a/examples/construct_topo.py +++ /dev/null @@ -1,26 +0,0 @@ -import bluefog.torch as bf -# from bluefog.common import topology_util -import networkx as nx -bf.init() -# dst_list = [i for i in range(bf.size()) if i != bf.rank()] - -# dst_list = [(bf.rank()+1)%bf.size(), (bf.rank()+3)%bf.size()] - -if bf.rank() == 0: - dst_list = [1, 2] -elif bf.rank() == 1: - dst_list = [0] -elif bf.rank() == 3: - dst_list = [0] -else: - dst_list = [0] - -bf.set_topology(bf.topology_util.ExponentialTwoGraph(size=4)) -print(f"{bf.rank()}: {dst_list}") -# print(bf.rank() in dst_list) -W = bf.infer_destination_source_ranks( - src_list=dst_list, construct_adjacency_matrix=True) -G = nx.from_numpy_array(W, create_using=nx.DiGraph) -if bf.rank() == 0: - print(W) - print(bf.GetRecvWeights(G, bf.rank())) From 90a42884d99c6d07f108d3e63db151ec1c1b8640 Mon Sep 17 00:00:00 2001 From: ybc Date: Sun, 28 Mar 2021 18:16:20 -0700 Subject: [PATCH 6/6] Fix the none case in _infer_topo func --- bluefog/torch/topology_util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bluefog/torch/topology_util.py b/bluefog/torch/topology_util.py index c8bff26a..b0fe38f2 100644 --- a/bluefog/torch/topology_util.py +++ b/bluefog/torch/topology_util.py @@ -92,6 +92,8 @@ def _infer_topo( for v in adj: inv_adjacency_dict[v].append(k) return_list = inv_adjacency_dict.get(bf.rank()) + if return_list is None: + return_list = [] if not construct_adjacency_matrix: return return_list