From 82ce438b176320eabcd16b7ea0a4d8131952d8d7 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 2 Jun 2022 16:40:59 +0800 Subject: [PATCH 1/3] add 1d send recv in nccl logical --- .../insert_nccl_logical_op_pass.cpp | 16 +++- .../test/graph/test_nccl_logical_send_recv.py | 86 +++++++++++++++++-- 2 files changed, 91 insertions(+), 11 deletions(-) diff --git a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp index b918f8ea6b8..9b4d73fd131 100644 --- a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp +++ b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp @@ -218,6 +218,18 @@ bool TryBuildNcclBy1DHierarchy(OperatorConf* ret, const SbpParallel& src_sbp, .Build() .op_conf(); return true; + } else if (!dst_sbp.has_partial_sum_parallel()) { + *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-(Send)2(Recv)-" + + NewUniqueId()) + .Op("_nccl_logical_send_recv") + .Input("in", lbn) + .Output("out") + .Attr>("src_nd_sbp", {SbpToString(src_sbp)}) + .Attr>("dst_nd_sbp", {SbpToString(dst_sbp)}) + .ScopeSymbolId(scope_symbol_id) + .Build() + .op_conf(); + return true; } return false; } @@ -503,7 +515,7 @@ void InsertNcclLogicalOpsAsCloseAsPossibleToSrcNode( } if (Global::Get()->enable_debug_mode()) { - VLOG(3) << " insert nccl op: " << nccl_op.name() << " from [" << src_op_name + VLOG(2) << " insert nccl op: " << nccl_op.name() << " from [" << src_op_name << ", order=" << src_order << ", sbp=" << NdSbpToString(src_node->NdSbp4Lbi(lbi)) << "] to [" << dst_op_name << ", order=" << node2subgraph_order.at(dst_node) << ", sbp=" << NdSbpToString(dst_node->NdSbp4Lbi(lbi)) << "] and before [" @@ -569,7 +581,7 @@ void InsertNcclLogicalOpsAsCloseAsPossibleToDstNode( } if (Global::Get()->enable_debug_mode()) { - VLOG(3) << " insert nccl op: " << nccl_op.name() << " from [" << src_op_name + VLOG(2) << " insert nccl op: " << nccl_op.name() << " from [" << src_op_name << ", order=" << node2subgraph_order.at(src_node) << "] to [" << dst_op_name << ", order=" << dst_order << "] and after [" << pre_op_name << ", order=" << dst_order - 1 << "]\n"; diff --git a/python/oneflow/test/graph/test_nccl_logical_send_recv.py b/python/oneflow/test/graph/test_nccl_logical_send_recv.py index d685d771229..525257799b1 100644 --- a/python/oneflow/test/graph/test_nccl_logical_send_recv.py +++ b/python/oneflow/test/graph/test_nccl_logical_send_recv.py @@ -29,7 +29,7 @@ os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "1" -def _test_nccl_logical_send_recv(test_case, src_nd_sbp, dst_nd_sbp): +def _test_nccl_logical_send_recv_2d(test_case, src_nd_sbp, dst_nd_sbp): # can not process p in dst if flow.sbp.partial_sum() in dst_nd_sbp: return @@ -66,7 +66,7 @@ def _test_nccl_logical_send_recv(test_case, src_nd_sbp, dst_nd_sbp): # check graph boxing flow.boxing.nccl.enable_use_compute_stream(True) - class TestNcclLogicalSendRecvGraph(flow.nn.Graph): + class TestNcclLogicalSendRecv2DGraph(flow.nn.Graph): def __init__(self): super().__init__() @@ -74,7 +74,7 @@ def build(self, x): y = x.to_global(sbp=dst_nd_sbp, placement=placement) return y - graph = TestNcclLogicalSendRecvGraph() + graph = TestNcclLogicalSendRecv2DGraph() # graph.debug() y = graph(x) out_np = y.numpy() @@ -88,7 +88,7 @@ def build(self, x): test_case.assertTrue(np.array_equal(out_np, in_np)) -def gen_nd_sbp(): +def gen_2d_sbp(): sbp_list = [ flow.sbp.partial_sum(), flow.sbp.broadcast(), @@ -105,14 +105,82 @@ def gen_nd_sbp(): @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") -class TestNcclLogicalSendRecv(flow.unittest.TestCase): - def test_nccl_logical_send_recv(test_case): +class TestNcclLogicalSendRecv2D(flow.unittest.TestCase): + def test_nccl_logical_send_recv_2d(test_case): arg_dict = OrderedDict() - arg_dict["src_nd_sbp"] = gen_nd_sbp() - arg_dict["dst_nd_sbp"] = gen_nd_sbp() + arg_dict["src_nd_sbp"] = gen_2d_sbp() + arg_dict["dst_nd_sbp"] = gen_2d_sbp() for arg in GenArgList(arg_dict): - _test_nccl_logical_send_recv(test_case, *arg) + _test_nccl_logical_send_recv_2d(test_case, *arg) +def _test_nccl_logical_send_recv_1d(test_case, src_nd_sbp, dst_nd_sbp): + # can not process p in dst + if flow.sbp.partial_sum() in dst_nd_sbp: + return + + # skip src == dst + if src_nd_sbp == dst_nd_sbp: + return + + # input + placement = flow.placement("cuda", ranks=[0, 1]) + local_np = np.arange(2 * 2 * 2).reshape(2, 2, 2) + x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement) + + # check eager boxing + eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement) + test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy())) + + # check graph boxing + flow.boxing.nccl.enable_use_compute_stream(True) + + class TestNcclLogicalSendRecv1DGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + + def build(self, x): + y = x.to_global(sbp=dst_nd_sbp, placement=placement) + return y + + graph = TestNcclLogicalSendRecv1DGraph() + #graph.debug(0) + y = graph(x) + out_np = y.numpy() + in_np = x.numpy() + #if flow.env.get_rank() == 0: + # print("src sbp ", src_nd_sbp, ", dst sbp ", dst_nd_sbp) + # print(graph) + # equal = np.array_equal(out_np, in_np) + # if not equal: + # print("in ", in_np) + # print("out ", out_np) + # print("====================") + test_case.assertTrue(np.array_equal(out_np, in_np)) + + +def gen_1d_sbp(): + sbp_list = [ + flow.sbp.partial_sum(), + flow.sbp.broadcast(), + flow.sbp.split(0), + flow.sbp.split(1), + flow.sbp.split(2), + ] + nd_sbp_list = [] + for sbp0 in sbp_list: + nd_sbp_list.append([sbp0,]) + return nd_sbp_list + + +@flow.unittest.skip_unless_1n2d() +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +class TestNcclLogicalSendRecv1D(flow.unittest.TestCase): + def test_nccl_logical_send_recv_1d(test_case): + arg_dict = OrderedDict() + arg_dict["src_nd_sbp"] = gen_1d_sbp() + arg_dict["dst_nd_sbp"] = gen_1d_sbp() + for arg in GenArgList(arg_dict): + _test_nccl_logical_send_recv_1d(test_case, *arg) if __name__ == "__main__": unittest.main() From 7d27be0d71a2253493c23a3ece2f671473bf9703 Mon Sep 17 00:00:00 2001 From: Xiaoyu Xu Date: Mon, 6 Jun 2022 15:31:44 +0800 Subject: [PATCH 2/3] Update insert_nccl_logical_op_pass.cpp --- oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp index f09d6a38b4c..3bcb04d567b 100644 --- a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp +++ b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp @@ -245,7 +245,6 @@ bool TryBuildNcclBy1DHierarchy(OperatorConf* ret, const SbpParallel& src_sbp, .op_conf(); return true; } - return false; } From 3342bb283cfe2ef12c9e225abd3f508938398325 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Thu, 16 Jun 2022 05:30:57 +0000 Subject: [PATCH 3/3] auto format by CI --- .../oneflow/test/graph/test_nccl_logical_send_recv.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/oneflow/test/graph/test_nccl_logical_send_recv.py b/python/oneflow/test/graph/test_nccl_logical_send_recv.py index be9d3a1abeb..9b6b90750d8 100644 --- a/python/oneflow/test/graph/test_nccl_logical_send_recv.py +++ b/python/oneflow/test/graph/test_nccl_logical_send_recv.py @@ -109,6 +109,7 @@ def test_nccl_logical_send_recv_2d(test_case): for arg in GenArgList(arg_dict): _test_nccl_logical_send_recv_2d(test_case, *arg) + def _test_nccl_logical_send_recv_1d(test_case, src_nd_sbp, dst_nd_sbp): # can not process p in dst if flow.sbp.partial_sum() in dst_nd_sbp: @@ -139,11 +140,11 @@ def build(self, x): return y graph = TestNcclLogicalSendRecv1DGraph() - #graph.debug(0) + # graph.debug(0) y = graph(x) out_np = y.numpy() in_np = x.numpy() - #if flow.env.get_rank() == 0: + # if flow.env.get_rank() == 0: # print("src sbp ", src_nd_sbp, ", dst sbp ", dst_nd_sbp) # print(graph) # equal = np.array_equal(out_np, in_np) @@ -164,7 +165,9 @@ def gen_1d_sbp(): ] nd_sbp_list = [] for sbp0 in sbp_list: - nd_sbp_list.append([sbp0,]) + nd_sbp_list.append( + [sbp0,] + ) return nd_sbp_list @@ -178,5 +181,6 @@ def test_nccl_logical_send_recv_1d(test_case): for arg in GenArgList(arg_dict): _test_nccl_logical_send_recv_1d(test_case, *arg) + if __name__ == "__main__": unittest.main()