Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add 1d send recv in nccl logical #8355

Merged
merged 9 commits into from
Jun 16, 2022
16 changes: 14 additions & 2 deletions oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,18 @@ bool TryBuildNcclBy1DHierarchy(OperatorConf* ret, const SbpParallel& src_sbp,
.Build()
.op_conf();
return true;
} else if (!dst_sbp.has_partial_sum_parallel()) {
*ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-(Send)2(Recv)-"
+ NewUniqueId())
.Op("_nccl_logical_send_recv")
.Input("in", lbn)
.Output("out")
.Attr<std::vector<std::string>>("src_nd_sbp", {SbpToString(src_sbp)})
.Attr<std::vector<std::string>>("dst_nd_sbp", {SbpToString(dst_sbp)})
.ScopeSymbolId(scope_symbol_id)
.Build()
.op_conf();
return true;
}
return false;
}
Expand Down Expand Up @@ -517,7 +529,7 @@ void InsertNcclLogicalOpsAsCloseAsPossibleToSrcNode(
}

if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
VLOG(3) << " insert nccl op: " << nccl_op.name() << " from [" << src_op_name
VLOG(2) << " insert nccl op: " << nccl_op.name() << " from [" << src_op_name
<< ", order=" << src_order << ", sbp=" << NdSbpToString(src_node->NdSbp4Lbi(lbi))
<< "] to [" << dst_op_name << ", order=" << node2subgraph_order.at(dst_node)
<< ", sbp=" << NdSbpToString(dst_node->NdSbp4Lbi(lbi)) << "] and before ["
Expand Down Expand Up @@ -583,7 +595,7 @@ void InsertNcclLogicalOpsAsCloseAsPossibleToDstNode(
}

if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
VLOG(3) << " insert nccl op: " << nccl_op.name() << " from [" << src_op_name
VLOG(2) << " insert nccl op: " << nccl_op.name() << " from [" << src_op_name
<< ", order=" << node2subgraph_order.at(src_node) << "] to [" << dst_op_name
<< ", order=" << dst_order << "] and after [" << pre_op_name
<< ", order=" << dst_order - 1 << "]\n";
Expand Down
90 changes: 81 additions & 9 deletions python/oneflow/test/graph/test_nccl_logical_send_recv.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "1"


def _test_nccl_logical_send_recv(test_case, src_nd_sbp, dst_nd_sbp):
def _test_nccl_logical_send_recv_2d(test_case, src_nd_sbp, dst_nd_sbp):
# can not process p in dst
if flow.sbp.partial_sum() in dst_nd_sbp:
return
Expand Down Expand Up @@ -62,15 +62,15 @@ def _test_nccl_logical_send_recv(test_case, src_nd_sbp, dst_nd_sbp):
# check graph boxing
flow.boxing.nccl.enable_use_compute_stream(True)

class TestNcclLogicalSendRecvGraph(flow.nn.Graph):
class TestNcclLogicalSendRecv2DGraph(flow.nn.Graph):
def __init__(self):
super().__init__()

def build(self, x):
y = x.to_global(sbp=dst_nd_sbp, placement=placement)
return y

graph = TestNcclLogicalSendRecvGraph()
graph = TestNcclLogicalSendRecv2DGraph()
# graph.debug()
y = graph(x)
out_np = y.numpy()
Expand All @@ -84,7 +84,7 @@ def build(self, x):
test_case.assertTrue(np.array_equal(out_np, in_np))


def gen_nd_sbp():
def gen_2d_sbp():
sbp_list = [
flow.sbp.partial_sum(),
flow.sbp.broadcast(),
Expand All @@ -101,13 +101,85 @@ def gen_nd_sbp():

@flow.unittest.skip_unless_1n4d()
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
class TestNcclLogicalSendRecv(flow.unittest.TestCase):
def test_nccl_logical_send_recv(test_case):
class TestNcclLogicalSendRecv2D(flow.unittest.TestCase):
def test_nccl_logical_send_recv_2d(test_case):
arg_dict = OrderedDict()
arg_dict["src_nd_sbp"] = gen_nd_sbp()
arg_dict["dst_nd_sbp"] = gen_nd_sbp()
arg_dict["src_nd_sbp"] = gen_2d_sbp()
arg_dict["dst_nd_sbp"] = gen_2d_sbp()
for arg in GenArgList(arg_dict):
_test_nccl_logical_send_recv(test_case, *arg)
_test_nccl_logical_send_recv_2d(test_case, *arg)


def _test_nccl_logical_send_recv_1d(test_case, src_nd_sbp, dst_nd_sbp):
    """Check that re-distributing a global tensor from src to dst SBP keeps its values.

    A 2x2x2 tensor is created on CUDA ranks 0 and 1 with ``src_nd_sbp`` and
    converted to ``dst_nd_sbp`` twice: once eagerly and once inside an
    ``nn.Graph``. In both cases the resulting numpy value must equal the
    numpy value of the input tensor.

    Args:
        test_case: the unittest case used for assertions.
        src_nd_sbp: SBP signature of the input tensor.
        dst_nd_sbp: SBP signature requested for the output tensor.

    The combination is silently skipped when ``dst_nd_sbp`` contains
    partial_sum (not supported as a destination) or when source and
    destination signatures are equal.
    """
    # Partial-sum destinations cannot be processed; identical src/dst is a no-op.
    if flow.sbp.partial_sum() in dst_nd_sbp or src_nd_sbp == dst_nd_sbp:
        return

    # Build the input tensor.
    device_placement = flow.placement("cuda", ranks=[0, 1])
    source_values = np.arange(8).reshape(2, 2, 2)
    x = flow.tensor(source_values, sbp=src_nd_sbp, placement=device_placement)

    # Eager boxing: the converted tensor must hold the same values as the input.
    eager_result = x.to_global(sbp=dst_nd_sbp, placement=device_placement)
    eager_matches = np.array_equal(eager_result.numpy(), x.numpy())
    test_case.assertTrue(eager_matches)

    # Graph boxing: run the same conversion inside a static graph.
    flow.boxing.nccl.enable_use_compute_stream(True)

    class TestNcclLogicalSendRecv1DGraph(flow.nn.Graph):
        def __init__(self):
            super().__init__()

        def build(self, x):
            return x.to_global(sbp=dst_nd_sbp, placement=device_placement)

    boxing_graph = TestNcclLogicalSendRecv1DGraph()
    # boxing_graph.debug(0)
    graph_result = boxing_graph(x)
    actual = graph_result.numpy()
    expected = x.numpy()
    # When debugging a mismatch, printing src_nd_sbp, dst_nd_sbp, the graph,
    # and the expected/actual arrays on rank 0 is a useful starting point.
    test_case.assertTrue(np.array_equal(actual, expected))


def gen_1d_sbp():
    """Return every 1-D SBP signature used by the 1-D send/recv test.

    Each entry is a single-element list wrapping one of: partial_sum,
    broadcast, split(0), split(1), split(2), in that order.
    """
    candidates = (
        flow.sbp.partial_sum(),
        flow.sbp.broadcast(),
        flow.sbp.split(0),
        flow.sbp.split(1),
        flow.sbp.split(2),
    )
    return [[sbp] for sbp in candidates]


# NOTE(review): the skip message reads "only test cpu cases", yet the test is
# skipped when ONEFLOW_TEST_CPU_ONLY is set — confirm the intended wording.
@flow.unittest.skip_unless_1n2d()
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
class TestNcclLogicalSendRecv1D(flow.unittest.TestCase):
    """Runs the 1-D send/recv check over every (src, dst) SBP pair."""

    def test_nccl_logical_send_recv_1d(test_case):
        # Source and destination both range over the full 1-D SBP list.
        sbp_grid = OrderedDict(
            [("src_nd_sbp", gen_1d_sbp()), ("dst_nd_sbp", gen_1d_sbp())]
        )
        for combo in GenArgList(sbp_grid):
            _test_nccl_logical_send_recv_1d(test_case, *combo)


if __name__ == "__main__":
Expand Down