nd boxing use nccl send/recv #7936

Closed · wants to merge 14 commits
124 changes: 112 additions & 12 deletions oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp
@@ -26,6 +26,9 @@ limitations under the License.
#include "oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/graph/nccl_send_recv_boxing_task_node.h"
#include "oneflow/core/job/nd_sbp_util.h"
#include "oneflow/core/graph/task_stream_id.h"

namespace oneflow {

@@ -117,6 +120,27 @@ std::shared_ptr<ChainSubTskGphBuilder> Make1DSubTskGphBuilder() {
return std::make_shared<ChainSubTskGphBuilder>(builders);
}

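// Merge the device sets of two placements into one ParallelConf, e.g. machine
// 0 devices {0, 1} merged with machine 0 devices {1, 2} yields "@0:0", "@0:1",
// "@0:2"; the std::set keeps the (machine, device) pairs sorted and unique.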
void MergeParallelConf(const ParallelDesc& parallel_desc_0, const ParallelDesc& parallel_desc_1,
ParallelConf* parallel_conf) {
CHECK_EQ(parallel_desc_0.device_tag(), parallel_desc_1.device_tag());
std::set<std::pair<int64_t, int64_t>> machine_device_ids;
for (int64_t machine_id : parallel_desc_0.sorted_machine_ids()) {
for (int64_t device_id : parallel_desc_0.sorted_dev_phy_ids(machine_id)) {
machine_device_ids.insert(std::make_pair(machine_id, device_id));
}
}
for (int64_t machine_id : parallel_desc_1.sorted_machine_ids()) {
for (int64_t device_id : parallel_desc_1.sorted_dev_phy_ids(machine_id)) {
machine_device_ids.insert(std::make_pair(machine_id, device_id));
}
}
parallel_conf->set_device_tag(parallel_desc_0.device_tag());
for (const auto& pair : machine_device_ids) {
parallel_conf->add_device_name("@" + std::to_string(pair.first) + ":"
+ std::to_string(pair.second));
}
}

} // namespace

void InOutParallelDimReduce(const ParallelDesc& in_parallel_desc,
@@ -171,6 +195,66 @@ class FlatSubTskGphBuilder final : public HierarchicalSubTskGphBuilder {
std::shared_ptr<SubTskGphBuilder> sub_tsk_gph_builder_;
};

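// Generic nd-sbp boxing builder: lowers an arbitrary NdSbp-to-NdSbp transfer
// to one nccl-send/recv-based task per rank. Requires CUDA on both sides and
// no partial-sum in the output nd_sbp.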
class NDNcclSendRecvBoxingSubTskGphBuilder final : public HierarchicalSubTskGphBuilder {
public:
OF_DISALLOW_COPY_AND_MOVE(NDNcclSendRecvBoxingSubTskGphBuilder);
NDNcclSendRecvBoxingSubTskGphBuilder() {}
~NDNcclSendRecvBoxingSubTskGphBuilder() override = default;

Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
const std::vector<TaskNode*>& sorted_in_tasks,
std::vector<TaskNode*>* sorted_out_tasks,
std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks,
const ParallelDesc& in_parallel_desc,
const ParallelDesc& out_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,
const Shape& time_shape) const override {
if (in_parallel_desc.device_type() == DeviceType::kCUDA
&& out_parallel_desc.device_type() == DeviceType::kCUDA
&& !NdSbpHasPartialParallel(out_nd_sbp)) {
#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700
ParallelConf merged_parallel_conf;
MergeParallelConf(in_parallel_desc.parallel_conf(), out_parallel_desc.parallel_conf(),
&merged_parallel_conf);
ParallelDesc merged_parallel_desc(merged_parallel_conf);
TaskNode* first_in_node = sorted_in_tasks.front();
sorted_ctrl_tasks->resize(out_parallel_desc.parallel_num());
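// One NcclSendRecvBoxingTaskNode per rank of the merged placement: every rank
// runs the op, whether it consumes input, produces output, or both.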
FOR_RANGE(int64_t, id, 0, merged_parallel_desc.parallel_num()) {
NcclSendRecvBoxingTaskNode* node = ctx->task_graph()->NewNode<NcclSendRecvBoxingTaskNode>();
const int64_t machine_id = JUST(merged_parallel_desc.MachineId4ParallelId(id));
int64_t device_index = JUST(merged_parallel_desc.DeviceId4ParallelId(id));
int64_t thrd_id = EncodeStreamIdToInt64(GenerateNamedTaskStreamId(
machine_id, merged_parallel_desc.device_type(), device_index, "NCCL_SEND_RECV_BOXING"));
bool has_input = in_parallel_desc.Containing(machine_id, device_index);
bool has_output = out_parallel_desc.Containing(machine_id, device_index);
node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(),
logical_blob_desc.data_type(), in_nd_sbp, out_nd_sbp, in_parallel_desc,
out_parallel_desc, id, merged_parallel_desc, has_input, has_output);
if (has_input) {
int64_t in_id =
JUST(in_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index));
ctx->task_graph()->ConnectWithLbi(sorted_in_tasks.at(in_id), node, lbi);
} else {
// TODO: find nearest
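// This rank has no local input; a ctrl regst edge from the first input task
// keeps the node ordered after the producer side.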
std::string regst_desc_name;
first_in_node->BuildCtrlRegstDesc(node, &regst_desc_name);
TaskEdge* edge = ctx->task_graph()->NewEdge();
Connect<TaskNode>(first_in_node, edge, node);
first_in_node->BindEdgeWithProducedRegst(edge, regst_desc_name);
}
if (has_output) { sorted_out_tasks->push_back(node); }
}
return BuildSubTskGphBuilderStatus("NDNcclSendRecvBoxingSubTskGphBuilder", "");
#else
return Error::BoxingNotSupportedError();
#endif
} else {
return Error::BoxingNotSupportedError();
}
}
};

class IntraGroupSubTskGphBuilder final : public HierarchicalSubTskGphBuilder {
public:
OF_DISALLOW_COPY_AND_MOVE(IntraGroupSubTskGphBuilder);
Expand Down Expand Up @@ -350,21 +434,22 @@ class Dim0NdSbpMismatchedSubTskGphBuilder final : public HierarchicalSubTskGphBu
if (in_parallel_desc.hierarchy()->NumAxes() == 2
&& (*in_parallel_desc.hierarchy() == *out_parallel_desc.hierarchy())
&& in_nd_sbp.sbp_parallel(0) != out_nd_sbp.sbp_parallel(0)
- && in_nd_sbp.sbp_parallel(1) == out_nd_sbp.sbp_parallel(1)) {
- if (!(NdSbpAllSameSplitParallel(in_nd_sbp) || NdSbpAllSameSplitParallel(out_nd_sbp))) {
- return inter_group_sub_tsk_gph_builder_->Build(
- ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc,
- out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape);
- } else {
- return Error::BoxingNotSupportedError();
- }
+ && in_nd_sbp.sbp_parallel(1) == out_nd_sbp.sbp_parallel(1)
+ && !(NdSbpAllSameSplitParallel(in_nd_sbp) || NdSbpAllSameSplitParallel(out_nd_sbp))) {
+ return inter_group_sub_tsk_gph_builder_->Build(
+ ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc,
+ out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape);
} else {
- return Error::BoxingNotSupportedError();
+ return nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build(
+ ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc,
+ out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape);
}
}

private:
std::unique_ptr<InterGroupSubTskGphBuilder> inter_group_sub_tsk_gph_builder_;
std::unique_ptr<NDNcclSendRecvBoxingSubTskGphBuilder>
nd_nccl_send_recv_boxing_sub_tsk_gph_builder_;
};

class Same2DHierarchySubTskGphBuilder final : public HierarchicalSubTskGphBuilder {
@@ -391,12 +476,10 @@ class Same2DHierarchySubTskGphBuilder final : public HierarchicalSubTskGphBuilder {
return intra_group_sub_tsk_gph_builder_->Build(
ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc,
out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape);
- } else if (in_nd_sbp.sbp_parallel(1) == out_nd_sbp.sbp_parallel(1)) {
+ } else {
return dim0_nd_sbp_mismatched_sub_tsk_gph_builder_->Build(
ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc,
out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape);
- } else {
- return Error::BoxingNotSupportedError();
}
} else {
return Error::BoxingNotSupportedError();
@@ -464,13 +547,16 @@ struct DispatchHierarchicalSubTskGphBuilder::Impl {
std::unique_ptr<Same2DHierarchySubTskGphBuilder> same_2d_hierarchy_sub_tsk_gph_builder_;
std::unique_ptr<ExpandToSame2DHierarchySubTskGphBuilder>
expand_to_same_2d_hierarchy_sub_tsk_gph_builder_;
std::unique_ptr<NDNcclSendRecvBoxingSubTskGphBuilder>
nd_nccl_send_recv_boxing_sub_tsk_gph_builder_;
};

DispatchHierarchicalSubTskGphBuilder::Impl::Impl() {
flat_sub_tsk_gph_builder_.reset(new FlatSubTskGphBuilder());
same_2d_hierarchy_sub_tsk_gph_builder_.reset(new Same2DHierarchySubTskGphBuilder());
expand_to_same_2d_hierarchy_sub_tsk_gph_builder_.reset(
new ExpandToSame2DHierarchySubTskGphBuilder());
nd_nccl_send_recv_boxing_sub_tsk_gph_builder_.reset(new NDNcclSendRecvBoxingSubTskGphBuilder());
}

DispatchHierarchicalSubTskGphBuilder::DispatchHierarchicalSubTskGphBuilder() {
@@ -495,6 +581,14 @@ Maybe<SubTskGphBuilderStatus> DispatchHierarchicalSubTskGphBuilder::Build(
&reduced_out_nd_sbp);
const auto& in_hierarchy = reduced_in_parallel_desc.hierarchy();
const auto& out_hierarchy = reduced_out_parallel_desc.hierarchy();
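// Placements with more than two hierarchy axes have no specialized builder;
// send them straight to nccl send/recv boxing when both sides are on CUDA.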
if ((in_hierarchy->NumAxes() > 2 || out_hierarchy->NumAxes() > 2)
&& reduced_in_parallel_desc.device_type() == DeviceType::kCUDA
&& reduced_out_parallel_desc.device_type() == DeviceType::kCUDA) {
return impl_->nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build(
ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc,
reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp,
time_shape);
}
if (in_hierarchy->NumAxes() <= 2 && out_hierarchy->NumAxes() <= 2) {
if (in_hierarchy->NumAxes() == 1 && out_hierarchy->NumAxes() == 1) {
return impl_->flat_sub_tsk_gph_builder_->Build(
Expand All @@ -513,6 +607,12 @@ Maybe<SubTskGphBuilderStatus> DispatchHierarchicalSubTskGphBuilder::Build(
ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc,
reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp,
time_shape);
} else if (reduced_in_parallel_desc.device_type() == DeviceType::kCUDA
&& reduced_out_parallel_desc.device_type() == DeviceType::kCUDA) {
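// Remaining 1D/2D combinations (e.g. mismatched 2D hierarchies) also fall
// back to nccl send/recv boxing on CUDA devices.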
return impl_->nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build(
ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc,
reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp,
time_shape);
} else {
return Error::BoxingNotSupportedError();
}
92 changes: 92 additions & 0 deletions oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp
@@ -0,0 +1,92 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/nccl_send_recv_boxing_task_node.h"

namespace oneflow {

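// Init records what the boxing op needs at graph-build time: the transported
// blob (lbi), its logical shape and data type, source/destination nd_sbp and
// placements, and this node's rank within the merged placement.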
void NcclSendRecvBoxingTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi,
const Shape& logical_shape, const DataType& data_type,
const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const int64_t parallel_id, const ParallelDesc& parallel_desc,
const bool has_input, const bool has_output) {
set_machine_id(machine_id);
set_thrd_id(thrd_id);
set_lbi(lbi);
logical_shape_ = logical_shape;
src_nd_sbp_ = src_nd_sbp;
dst_nd_sbp_ = dst_nd_sbp;
src_parallel_conf_ = src_parallel_desc.parallel_conf();
dst_parallel_conf_ = dst_parallel_desc.parallel_conf();
parallel_conf_ = parallel_desc.parallel_conf();
parallel_ctx_.set_parallel_id(parallel_id);
parallel_ctx_.set_parallel_num(parallel_desc.parallel_num());
has_input_ = has_input;
has_output_ = has_output;
data_type_ = data_type;
}

void NcclSendRecvBoxingTaskNode::ProduceAllRegstsAndBindEdges() {
if (has_output_) {
std::shared_ptr<RegstDesc> out_regst = ProduceRegst("out", true, 1, 1);
this->ForEachOutDataEdge([&](TaskEdge* out_edge) { out_edge->AddRegst("out", out_regst); });
}
ProduceRegst("tmp", true);
}

void NcclSendRecvBoxingTaskNode::ConsumeAllRegsts() {
this->ForEachInDataEdge(
[&](TaskEdge* in_edge) { ConsumeRegst("in", in_edge->GetSoleRegst()); });
}

void NcclSendRecvBoxingTaskNode::BuildExecGphAndRegst() {
ExecNode* node = mut_exec_gph().NewNode();
OperatorConf op_conf;
op_conf.set_name("System-Nccl-Send-Recv-Boxing-" + NewUniqueId());
op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type())));
auto* nccl_send_recv_boxing_conf = op_conf.mutable_nccl_send_recv_boxing_conf();
*nccl_send_recv_boxing_conf->mutable_lbi() = lbi();
logical_shape_.ToProto(nccl_send_recv_boxing_conf->mutable_logical_shape());
nccl_send_recv_boxing_conf->set_data_type(data_type_);
*nccl_send_recv_boxing_conf->mutable_src_nd_sbp() = src_nd_sbp_;
*nccl_send_recv_boxing_conf->mutable_dst_nd_sbp() = dst_nd_sbp_;
*nccl_send_recv_boxing_conf->mutable_parallel_conf() = parallel_conf_;
*nccl_send_recv_boxing_conf->mutable_src_parallel_conf() = src_parallel_conf_;
*nccl_send_recv_boxing_conf->mutable_dst_parallel_conf() = dst_parallel_conf_;
nccl_send_recv_boxing_conf->set_has_input(has_input_);
nccl_send_recv_boxing_conf->set_has_output(has_output_);
std::shared_ptr<Operator> sole_op = CHECK_JUST(ConstructOp(op_conf));
node->mut_op() = sole_op;
if (has_input_) { node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in")); }
if (has_output_) {
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));
node->BindBnWithRegst(sole_op->SoleObn(), out_regst);
}
node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst("tmp"));
node->InferBlobDescs(parallel_ctx());
}

void NcclSendRecvBoxingTaskNode::InferProducedDataRegstTimeShape() {
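// The op runs once per iteration, so produced regsts get the unit time shape
// {1, 1}; "out" may be absent on ranks that produce no output.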
auto out_regst = GetProducedRegst("out");
if (out_regst != nullptr) { out_regst->mut_data_regst_time_shape()->reset(new Shape({1, 1})); }
auto tmp_regst = GetProducedRegst("tmp");
tmp_regst->mut_data_regst_time_shape()->reset(new Shape({1, 1}));
}

} // namespace oneflow
57 changes: 57 additions & 0 deletions oneflow/core/graph/nccl_send_recv_boxing_task_node.h
@@ -0,0 +1,57 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_NCCL_SEND_RECV_BOXING_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_NCCL_SEND_RECV_BOXING_TASK_NODE_H_

#include "oneflow/core/graph/transport_task_node.h"

namespace oneflow {

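// Transport task node wrapping one rank's share of a nccl send/recv boxing:
// has_input_/has_output_ record whether this rank appears in the producer
// and/or consumer placement.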
class NcclSendRecvBoxingTaskNode : public TransportTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(NcclSendRecvBoxingTaskNode);
NcclSendRecvBoxingTaskNode() = default;
~NcclSendRecvBoxingTaskNode() override = default;

void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi,
const Shape& logical_shape, const DataType& data_type, const NdSbp& src_nd_sbp,
const NdSbp& dst_nd_sbp, const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc, const int64_t parallel_id,
const ParallelDesc& parallel_desc, const bool has_input, const bool has_output);
TaskType GetTaskType() const override { return TaskType::kNcclSendRecvBoxing; }
const ParallelContext* parallel_ctx() const override { return &parallel_ctx_; }

private:
void BuildExecGphAndRegst() override;
void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() final;
void InferProducedDataRegstTimeShape() final;

Shape logical_shape_;
DataType data_type_;
NdSbp src_nd_sbp_;
NdSbp dst_nd_sbp_;
ParallelConf src_parallel_conf_;
ParallelConf dst_parallel_conf_;
ParallelConf parallel_conf_;
ParallelContext parallel_ctx_;
bool has_input_;
bool has_output_;
};

} // namespace oneflow

#endif // ONEFLOW_CORE_GRAPH_NCCL_SEND_RECV_BOXING_TASK_NODE_H_
6 changes: 6 additions & 0 deletions oneflow/core/graph/task_graph.cpp
@@ -721,6 +721,12 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {
const ParallelDesc& src_parallel_desc = src_op_node->parallel_desc();
const ParallelDesc& dst_parallel_desc = dst_op_node->parallel_desc();
const BlobDesc& blob_desc = src_op_node->LogicalBlobDesc4Lbi(lbi);
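// Trace each boxing decision at verbosity 3: producer/consumer op names,
// their placements, and nd_sbp signatures.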
VLOG(3) << "src op: " << src_op_node->op().op_name()
<< " dst op: " << dst_op_node->op().op_name()
<< " src_parallel_conf: " << src_parallel_desc.parallel_conf().DebugString()
<< " dst parallel conf: " << dst_parallel_desc.parallel_conf().DebugString()
<< " src_nd_sbp " << src_nd_sbp.DebugString() << " dst nd_sbp "
<< dst_nd_sbp.DebugString();
auto status = CHECK_JUST(hierarchical_sub_tsk_gph_builder_->Build(
sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, &sorted_ctrl_tasks, src_parallel_desc,
dst_parallel_desc, lbi, blob_desc, src_nd_sbp, dst_nd_sbp,
1 change: 1 addition & 0 deletions oneflow/core/job/task.proto
@@ -38,6 +38,7 @@ enum TaskType {
kSspVariableProxy = 63;
kBoxingZeros = 64;
kCriticalSectionWaitTick = 65;
kNcclSendRecvBoxing = 66;
};

message RegstDescIdSet {