diff --git a/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp b/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp
index 03e7c6529e7..0845a1b5b02 100644
--- a/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp
+++ b/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp
@@ -26,6 +26,9 @@ limitations under the License.
 #include "oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h"
 #include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"
 #include "oneflow/core/job/sbp_parallel.h"
+#include "oneflow/core/graph/nccl_send_recv_boxing_task_node.h"
+#include "oneflow/core/job/nd_sbp_util.h"
+#include "oneflow/core/graph/task_stream_id.h"
 
 namespace oneflow {
 
@@ -117,6 +120,27 @@ std::shared_ptr<SubTskGphBuilder> Make1DSubTskGphBuilder() {
   return std::make_shared<ChainSubTskGphBuilder>(builders);
 }
 
+void MergeParallelConf(const ParallelDesc& parallel_desc_0, const ParallelDesc& parallel_desc_1,
+                       ParallelConf* parallel_conf) {
+  CHECK_EQ(parallel_desc_0.device_tag(), parallel_desc_1.device_tag());
+  std::set<std::pair<int64_t, int64_t>> machine_device_ids;
+  for (int64_t machine_id : parallel_desc_0.sorted_machine_ids()) {
+    for (int64_t device_id : parallel_desc_0.sorted_dev_phy_ids(machine_id)) {
+      machine_device_ids.insert(std::make_pair(machine_id, device_id));
+    }
+  }
+  for (int64_t machine_id : parallel_desc_1.sorted_machine_ids()) {
+    for (int64_t device_id : parallel_desc_1.sorted_dev_phy_ids(machine_id)) {
+      machine_device_ids.insert(std::make_pair(machine_id, device_id));
+    }
+  }
+  parallel_conf->set_device_tag(parallel_desc_0.device_tag());
+  for (const auto& pair : machine_device_ids) {
+    parallel_conf->add_device_name("@" + std::to_string(pair.first) + ":"
+                                   + std::to_string(pair.second));
+  }
+}
+
 }  // namespace
 
 void InOutParallelDimReduce(const ParallelDesc& in_parallel_desc,
@@ -171,6 +195,66 @@ class FlatSubTskGphBuilder final : public HierarchicalSubTskGphBuilder {
   std::shared_ptr<SubTskGphBuilder> sub_tsk_gph_builder_;
 };
 
+class NDNcclSendRecvBoxingSubTskGphBuilder final : public HierarchicalSubTskGphBuilder {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(NDNcclSendRecvBoxingSubTskGphBuilder);
+  NDNcclSendRecvBoxingSubTskGphBuilder() {}
+  ~NDNcclSendRecvBoxingSubTskGphBuilder() override = default;
+
+  Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
+                                      const std::vector<TaskNode*>& sorted_in_tasks,
+                                      std::vector<TaskNode*>* sorted_out_tasks,
+                                      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks,
+                                      const ParallelDesc& in_parallel_desc,
+                                      const ParallelDesc& out_parallel_desc,
+                                      const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
+                                      const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,
+                                      const Shape& time_shape) const override {
+    if (in_parallel_desc.device_type() == DeviceType::kCUDA
+        && out_parallel_desc.device_type() == DeviceType::kCUDA
+        && !NdSbpHasPartialParallel(out_nd_sbp)) {
+#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700
+      ParallelConf merged_parallel_conf;
+      MergeParallelConf(in_parallel_desc.parallel_conf(), out_parallel_desc.parallel_conf(),
+                        &merged_parallel_conf);
+      ParallelDesc merged_parallel_desc(merged_parallel_conf);
+      TaskNode* first_in_node = sorted_in_tasks.front();
+      sorted_ctrl_tasks->resize(out_parallel_desc.parallel_num());
+      FOR_RANGE(int64_t, id, 0, merged_parallel_desc.parallel_num()) {
+        NcclSendRecvBoxingTaskNode* node =
+            ctx->task_graph()->NewNode<NcclSendRecvBoxingTaskNode>();
+        const int64_t machine_id = JUST(merged_parallel_desc.MachineId4ParallelId(id));
+        int64_t device_index = JUST(merged_parallel_desc.DeviceId4ParallelId(id));
+        int64_t thrd_id = EncodeStreamIdToInt64(GenerateNamedTaskStreamId(
+            machine_id, 
merged_parallel_desc.device_type(), device_index, "NCCL_SEND_RECV_BOXING")); + bool has_input = in_parallel_desc.Containing(machine_id, device_index); + bool has_output = out_parallel_desc.Containing(machine_id, device_index); + node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), + logical_blob_desc.data_type(), in_nd_sbp, out_nd_sbp, in_parallel_desc, + out_parallel_desc, id, merged_parallel_desc, has_input, has_output); + if (has_input) { + int64_t in_id = + JUST(in_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index)); + ctx->task_graph()->ConnectWithLbi(sorted_in_tasks.at(in_id), node, lbi); + } else { + // TODO: find nearest + std::string regst_desc_name; + first_in_node->BuildCtrlRegstDesc(node, ®st_desc_name); + TaskEdge* edge = ctx->task_graph()->NewEdge(); + Connect(first_in_node, edge, node); + first_in_node->BindEdgeWithProducedRegst(edge, regst_desc_name); + } + if (has_output) { sorted_out_tasks->push_back(node); } + } + return BuildSubTskGphBuilderStatus("NDNcclSendRecvBoxingSubTskGphBuilder", ""); +#else + return Error::BoxingNotSupportedError(); +#endif + } else { + return Error::BoxingNotSupportedError(); + } + } +}; + class IntraGroupSubTskGphBuilder final : public HierarchicalSubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(IntraGroupSubTskGphBuilder); @@ -350,21 +434,22 @@ class Dim0NdSbpMismatchedSubTskGphBuilder final : public HierarchicalSubTskGphBu if (in_parallel_desc.hierarchy()->NumAxes() == 2 && (*in_parallel_desc.hierarchy() == *out_parallel_desc.hierarchy()) && in_nd_sbp.sbp_parallel(0) != out_nd_sbp.sbp_parallel(0) - && in_nd_sbp.sbp_parallel(1) == out_nd_sbp.sbp_parallel(1)) { - if (!(NdSbpAllSameSplitParallel(in_nd_sbp) || NdSbpAllSameSplitParallel(out_nd_sbp))) { - return inter_group_sub_tsk_gph_builder_->Build( - ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, - out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape); - } else { - return Error::BoxingNotSupportedError(); - } + && in_nd_sbp.sbp_parallel(1) == out_nd_sbp.sbp_parallel(1) + && !(NdSbpAllSameSplitParallel(in_nd_sbp) || NdSbpAllSameSplitParallel(out_nd_sbp))) { + return inter_group_sub_tsk_gph_builder_->Build( + ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, + out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape); } else { - return Error::BoxingNotSupportedError(); + return nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build( + ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, + out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape); } } private: std::unique_ptr inter_group_sub_tsk_gph_builder_; + std::unique_ptr + nd_nccl_send_recv_boxing_sub_tsk_gph_builder_; }; class Same2DHierarchySubTskGphBuilder final : public HierarchicalSubTskGphBuilder { @@ -391,12 +476,10 @@ class Same2DHierarchySubTskGphBuilder final : public HierarchicalSubTskGphBuilde return intra_group_sub_tsk_gph_builder_->Build( ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape); - } else if (in_nd_sbp.sbp_parallel(1) == out_nd_sbp.sbp_parallel(1)) { + } else { return dim0_nd_sbp_mismatched_sub_tsk_gph_builder_->Build( ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape); - } else { - return Error::BoxingNotSupportedError(); } } else { return 
Error::BoxingNotSupportedError(); @@ -464,6 +547,8 @@ struct DispatchHierarchicalSubTskGphBuilder::Impl { std::unique_ptr same_2d_hierarchy_sub_tsk_gph_builder_; std::unique_ptr expand_to_same_2d_hierarchy_sub_tsk_gph_builder_; + std::unique_ptr + nd_nccl_send_recv_boxing_sub_tsk_gph_builder_; }; DispatchHierarchicalSubTskGphBuilder::Impl::Impl() { @@ -471,6 +556,7 @@ DispatchHierarchicalSubTskGphBuilder::Impl::Impl() { same_2d_hierarchy_sub_tsk_gph_builder_.reset(new Same2DHierarchySubTskGphBuilder()); expand_to_same_2d_hierarchy_sub_tsk_gph_builder_.reset( new ExpandToSame2DHierarchySubTskGphBuilder()); + nd_nccl_send_recv_boxing_sub_tsk_gph_builder_.reset(new NDNcclSendRecvBoxingSubTskGphBuilder()); } DispatchHierarchicalSubTskGphBuilder::DispatchHierarchicalSubTskGphBuilder() { @@ -495,6 +581,14 @@ Maybe DispatchHierarchicalSubTskGphBuilder::Build( &reduced_out_nd_sbp); const auto& in_hierarchy = reduced_in_parallel_desc.hierarchy(); const auto& out_hierarchy = reduced_out_parallel_desc.hierarchy(); + if ((in_hierarchy->NumAxes() > 2 || out_hierarchy->NumAxes() > 2) + && reduced_in_parallel_desc.device_type() == DeviceType::kCUDA + && reduced_out_parallel_desc.device_type() == DeviceType::kCUDA) { + return impl_->nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build( + ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc, + reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp, + time_shape); + } if (in_hierarchy->NumAxes() <= 2 && out_hierarchy->NumAxes() <= 2) { if (in_hierarchy->NumAxes() == 1 && out_hierarchy->NumAxes() == 1) { return impl_->flat_sub_tsk_gph_builder_->Build( @@ -513,6 +607,12 @@ Maybe DispatchHierarchicalSubTskGphBuilder::Build( ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc, reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp, time_shape); + } else if (reduced_in_parallel_desc.device_type() == DeviceType::kCUDA + && reduced_out_parallel_desc.device_type() == DeviceType::kCUDA) { + return impl_->nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build( + ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc, + reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp, + time_shape); } else { return Error::BoxingNotSupportedError(); } diff --git a/oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp b/oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp new file mode 100644 index 00000000000..95438c6d2b2 --- /dev/null +++ b/oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp @@ -0,0 +1,92 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+#include "oneflow/core/framework/to_string.h"
+#include "oneflow/core/graph/nccl_send_recv_boxing_task_node.h"
+
+namespace oneflow {
+
+void NcclSendRecvBoxingTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi,
+                                      const Shape& logical_shape, const DataType& data_type,
+                                      const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp,
+                                      const ParallelDesc& src_parallel_desc,
+                                      const ParallelDesc& dst_parallel_desc,
+                                      const int64_t parallel_id, const ParallelDesc& parallel_desc,
+                                      const bool has_input, const bool has_output) {
+  set_machine_id(machine_id);
+  set_thrd_id(thrd_id);
+  set_lbi(lbi);
+  logical_shape_ = logical_shape;
+  src_nd_sbp_ = src_nd_sbp;
+  dst_nd_sbp_ = dst_nd_sbp;
+  src_parallel_conf_ = src_parallel_desc.parallel_conf();
+  dst_parallel_conf_ = dst_parallel_desc.parallel_conf();
+  parallel_conf_ = parallel_desc.parallel_conf();
+  parallel_ctx_.set_parallel_id(parallel_id);
+  parallel_ctx_.set_parallel_num(parallel_desc.parallel_num());
+  has_input_ = has_input;
+  has_output_ = has_output;
+  data_type_ = data_type;
+}
+
+void NcclSendRecvBoxingTaskNode::ProduceAllRegstsAndBindEdges() {
+  if (has_output_) {
+    std::shared_ptr<RegstDesc> out_regst = ProduceRegst("out", true, 1, 1);
+    this->ForEachOutDataEdge([&](TaskEdge* out_edge) { out_edge->AddRegst("out", out_regst); });
+  }
+  ProduceRegst("tmp", true);
+}
+
+void NcclSendRecvBoxingTaskNode::ConsumeAllRegsts() {
+  this->ForEachInDataEdge(
+      [&](TaskEdge* in_edge) { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); });
+}
+
+void NcclSendRecvBoxingTaskNode::BuildExecGphAndRegst() {
+  ExecNode* node = mut_exec_gph().NewNode();
+  OperatorConf op_conf;
+  op_conf.set_name("System-Nccl-Send-Recv-Boxing-" + NewUniqueId());
+  op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type())));
+  auto* nccl_send_recv_boxing_conf = op_conf.mutable_nccl_send_recv_boxing_conf();
+  *nccl_send_recv_boxing_conf->mutable_lbi() = lbi();
+  logical_shape_.ToProto(nccl_send_recv_boxing_conf->mutable_logical_shape());
+  nccl_send_recv_boxing_conf->set_data_type(data_type_);
+  *nccl_send_recv_boxing_conf->mutable_src_nd_sbp() = src_nd_sbp_;
+  *nccl_send_recv_boxing_conf->mutable_dst_nd_sbp() = dst_nd_sbp_;
+  *nccl_send_recv_boxing_conf->mutable_parallel_conf() = parallel_conf_;
+  *nccl_send_recv_boxing_conf->mutable_src_parallel_conf() = src_parallel_conf_;
+  *nccl_send_recv_boxing_conf->mutable_dst_parallel_conf() = dst_parallel_conf_;
+  nccl_send_recv_boxing_conf->set_has_input(has_input_);
+  nccl_send_recv_boxing_conf->set_has_output(has_output_);
+  std::shared_ptr<Operator> sole_op = CHECK_JUST(ConstructOp(op_conf));
+  node->mut_op() = sole_op;
+  if (has_input_) { node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in")); }
+  if (has_output_) {
+    std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
+    out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));
+    node->BindBnWithRegst(sole_op->SoleObn(), out_regst);
+  }
+  node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst("tmp"));
+  node->InferBlobDescs(parallel_ctx());
+}
+
+void NcclSendRecvBoxingTaskNode::InferProducedDataRegstTimeShape() {
+  auto out_regst = GetProducedRegst("out");
+  if (out_regst != nullptr) { out_regst->mut_data_regst_time_shape()->reset(new Shape({1, 1})); }
+  auto tmp_regst = GetProducedRegst("tmp");
+  tmp_regst->mut_data_regst_time_shape()->reset(new Shape({1, 1}));
+}
+
+}  // namespace oneflow
diff --git a/oneflow/core/graph/nccl_send_recv_boxing_task_node.h b/oneflow/core/graph/nccl_send_recv_boxing_task_node.h
new file mode 
100644 index 00000000000..fee688222ca --- /dev/null +++ b/oneflow/core/graph/nccl_send_recv_boxing_task_node.h @@ -0,0 +1,57 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_GRAPH_NCCL_SEND_RECV_BOXING_TASK_NODE_H_ +#define ONEFLOW_CORE_GRAPH_NCCL_SEND_RECV_BOXING_TASK_NODE_H_ + +#include "oneflow/core/graph/transport_task_node.h" + +namespace oneflow { + +class NcclSendRecvBoxingTaskNode : public TransportTaskNode { + public: + OF_DISALLOW_COPY_AND_MOVE(NcclSendRecvBoxingTaskNode); + NcclSendRecvBoxingTaskNode() = default; + ~NcclSendRecvBoxingTaskNode() override = default; + + void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi, + const Shape& logical_shape, const DataType& data_type, const NdSbp& src_nd_sbp, + const NdSbp& dst_nd_sbp, const ParallelDesc& src_parallel_desc, + const ParallelDesc& dst_parallel_desc, const int64_t parallel_id, + const ParallelDesc& parallel_desc, const bool has_input, const bool has_output); + TaskType GetTaskType() const override { return TaskType::kNcclSendRecvBoxing; } + const ParallelContext* parallel_ctx() const override { return ¶llel_ctx_; } + + private: + void BuildExecGphAndRegst() override; + void ProduceAllRegstsAndBindEdges() override; + void ConsumeAllRegsts() final; + void InferProducedDataRegstTimeShape() final; + + Shape logical_shape_; + DataType data_type_; + NdSbp src_nd_sbp_; + NdSbp dst_nd_sbp_; + ParallelConf src_parallel_conf_; + ParallelConf dst_parallel_conf_; + ParallelConf parallel_conf_; + ParallelContext parallel_ctx_; + bool has_input_; + bool has_output_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_GRAPH_NCCL_SEND_RECV_BOXING_TASK_NODE_H_ diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 5fd69c40274..040e113ad14 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -721,6 +721,12 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) { const ParallelDesc& src_parallel_desc = src_op_node->parallel_desc(); const ParallelDesc& dst_parallel_desc = dst_op_node->parallel_desc(); const BlobDesc& blob_desc = src_op_node->LogicalBlobDesc4Lbi(lbi); + VLOG(3) << "src op: " << src_op_node->op().op_name() + << " dst op: " << dst_op_node->op().op_name() + << " src_parallel_conf: " << src_parallel_desc.parallel_conf().DebugString() + << " dst parallel conf: " << dst_parallel_desc.parallel_conf().DebugString() + << " src_nd_sbp " << src_nd_sbp.DebugString() << " dst nd_sbp " + << dst_nd_sbp.DebugString(); auto status = CHECK_JUST(hierarchical_sub_tsk_gph_builder_->Build( sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, &sorted_ctrl_tasks, src_parallel_desc, dst_parallel_desc, lbi, blob_desc, src_nd_sbp, dst_nd_sbp, diff --git a/oneflow/core/job/task.proto b/oneflow/core/job/task.proto index e4df1c4a0db..2fb82cc1ab9 100644 --- a/oneflow/core/job/task.proto +++ b/oneflow/core/job/task.proto @@ -38,6 +38,7 @@ enum TaskType { kSspVariableProxy = 63; kBoxingZeros = 64; 
kCriticalSectionWaitTick = 65; + kNcclSendRecvBoxing = 66; }; message RegstDescIdSet { diff --git a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp new file mode 100644 index 00000000000..c573f9bf0ad --- /dev/null +++ b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp @@ -0,0 +1,258 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/kernel/kernel.h" +#include "oneflow/core/device/nccl_util.h" +#include "oneflow/core/job/eager_nccl_comm_manager.h" +#include "oneflow/core/register/tensor_slice_copier.h" +#include "oneflow/core/ep/include/primitive/memset.h" +#include "oneflow/core/ep/include/primitive/add.h" +#include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h" + +#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 + +namespace oneflow { + +class NcclSendRecvBoxingKernel final : public Kernel { + public: + OF_DISALLOW_COPY_AND_MOVE(NcclSendRecvBoxingKernel); + NcclSendRecvBoxingKernel() = default; + ~NcclSendRecvBoxingKernel() override = default; + + const std::vector>& in_tensor_slice_copier_vec() const { + return in_tensor_slice_copier_vec_; + } + const std::vector>& out_tensor_slice_copier_vec() const { + return out_tensor_slice_copier_vec_; + } + const std::vector& send_elem_cnts() const { return send_elem_cnts_; } + const std::vector& recv_elem_cnts() const { return recv_elem_cnts_; } + const bool has_input() const { return has_input_; } + const bool has_output() const { return has_output_; } + ncclComm_t comm() const { return GetOrCreate().comm; } + + private: + struct Comm { + Comm(ncclComm_t comm) : comm(comm) {} + ncclComm_t comm; + }; + + void Init() const { + ParallelDesc parallel_desc(parallel_conf_); + std::set> device_set; + for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) { + int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); + int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); + device_set.emplace(std::make_pair(machine_id, device_id)); + } + EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Global::Get()); + ncclComm_t comm; + if (has_independent_stream_) { + comm = comm_mgr->GetCommForDeviceAndStreamName(device_set, stream_name_); + } else { + comm = comm_mgr->GetCommForDevice(device_set); + } + comm_.reset(new Comm(comm)); + } + + const Comm& GetOrCreate() const { + if (!comm_) { Init(); } + return *comm_; + } + + void VirtualKernelInit(KernelContext* ctx) override; + void ForwardDataContent(KernelContext* ctx) const override; + + bool has_independent_stream_; + std::string stream_name_; + ParallelConf parallel_conf_; + mutable std::unique_ptr comm_; + bool src_nd_sbp_no_partial_parallel_; + std::vector> in_tensor_slice_copier_vec_; + std::vector> out_tensor_slice_copier_vec_; + std::vector send_elem_cnts_; + std::vector recv_elem_cnts_; + bool has_input_; + bool has_output_; +}; + +void 
NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const { + Blob* buf = ctx->BnInOp2Blob("buf"); + ncclComm_t comm = this->comm(); + cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); + const std::vector& send_elem_cnts = this->send_elem_cnts(); + const std::vector& recv_elem_cnts = this->recv_elem_cnts(); + const int64_t parallel_num = this->kernel_conf().parallel_ctx().parallel_num(); + const DataType data_type = buf->data_type(); + std::vector send_in_ptr; + std::vector recv_out_ptr; + char* buf_ptr = buf->mut_dptr(); + int64_t offset = 0; + if (this->has_input()) { + for (int64_t i = 0; i < parallel_num; ++i) { + void* send_ptr = reinterpret_cast(buf_ptr + offset); + send_in_ptr.push_back(send_ptr); + offset += send_elem_cnts.at(i) * GetSizeOfDataType(data_type); + } + } + if (this->has_output()) { + for (int64_t i = 0; i < parallel_num; ++i) { + void* recv_ptr = reinterpret_cast(buf_ptr + offset); + recv_out_ptr.push_back(recv_ptr); + offset += recv_elem_cnts.at(i) * GetSizeOfDataType(data_type); + } + } + if (this->has_input()) { + const Blob* in = ctx->BnInOp2Blob("in"); + const std::vector>& in_tensor_slice_copier_vec = + this->in_tensor_slice_copier_vec(); + for (int64_t i = 0; i < parallel_num; ++i) { + if (in_tensor_slice_copier_vec.at(i)) { + in_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), send_in_ptr.at(i), in->dptr()); + } + } + } + const int64_t parallel_id = this->kernel_conf().parallel_ctx().parallel_id(); + OF_NCCL_CHECK(ncclGroupStart()); + for (int64_t i = 0; i < parallel_num; ++i) { + if (this->has_input() && send_elem_cnts.at(i) != 0) { + OF_NCCL_CHECK(ncclSend(send_in_ptr.at(i), send_elem_cnts.at(i), GetNcclDataType(data_type), i, + comm, cuda_stream)); + } + if (this->has_output() && recv_elem_cnts.at(i) != 0) { + OF_NCCL_CHECK(ncclRecv(recv_out_ptr.at(i), recv_elem_cnts.at(i), GetNcclDataType(data_type), + i, comm, cuda_stream)); + } + } + OF_NCCL_CHECK(ncclGroupEnd()); + if (!this->has_output()) { return; } + Blob* out = ctx->BnInOp2Blob("out"); + const std::vector>& out_tensor_slice_copier_vec = + this->out_tensor_slice_copier_vec(); + + if (src_nd_sbp_no_partial_parallel_) { + for (int64_t i = 0; i < parallel_num; ++i) { + if (out_tensor_slice_copier_vec.at(i)) { + out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out->mut_dptr(), recv_out_ptr.at(i)); + } + } + } else { + std::unique_ptr primitive = + ep::primitive::NewPrimitive(ctx->stream()->device_type(), + out->data_type()); + CHECK(primitive); + std::unique_ptr memset_primitive = + ep::primitive::NewPrimitive(ctx->stream()->device_type()); + CHECK(memset_primitive); + bool is_first_slice = true; + for (int64_t i = 0; i < parallel_num; ++i) { + if (out_tensor_slice_copier_vec.at(i)) { + if (is_first_slice) { + is_first_slice = false; + if (recv_elem_cnts.at(i) != out->shape().elem_cnt()) { + // if not same shape, memset out + memset_primitive->Launch(ctx->stream(), out->mut_dptr(), 0, + out->shape().elem_cnt() * GetSizeOfDataType(data_type)); + } + out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out->mut_dptr(), + recv_out_ptr.at(i)); + } else { + if (recv_elem_cnts.at(i) == out->shape().elem_cnt()) { + primitive->Launch(ctx->stream(), out->dptr(), recv_out_ptr.at(i), out->mut_dptr(), + out->shape().elem_cnt()); + } else { + void* out_buf = reinterpret_cast(buf_ptr + offset); + memset_primitive->Launch(ctx->stream(), out_buf, 0, + out->shape().elem_cnt() * GetSizeOfDataType(data_type)); + out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out_buf, 
+                                                    recv_out_ptr.at(i));
+            primitive->Launch(ctx->stream(), out->dptr(), out_buf, out->mut_dptr(),
+                              out->shape().elem_cnt());
+          }
+        }
+      }
+    }
+  }
+}
+
+void NcclSendRecvBoxingKernel::VirtualKernelInit(KernelContext* ctx) {
+  const NcclSendRecvBoxingOpConf& conf = this->op_conf().nccl_send_recv_boxing_conf();
+  has_independent_stream_ = this->op_conf().has_stream_name_hint();
+  if (has_independent_stream_) { stream_name_ = this->op_conf().stream_name_hint(); }
+  parallel_conf_ = conf.parallel_conf();
+  const int64_t parallel_id = this->kernel_conf().parallel_ctx().parallel_id();
+  ParallelDesc parallel_desc(parallel_conf_);
+  ParallelDesc src_parallel_desc(conf.src_parallel_conf());
+  ParallelDesc dst_parallel_desc(conf.dst_parallel_conf());
+  const NdSbp& src_nd_sbp = conf.src_nd_sbp();
+  const NdSbp& dst_nd_sbp = conf.dst_nd_sbp();
+  has_input_ = conf.has_input();
+  has_output_ = conf.has_output();
+  src_nd_sbp_no_partial_parallel_ = !NdSbpHasPartialParallel(src_nd_sbp);
+  const DataType data_type = this->kernel_conf().data_type();
+  const DeviceType device_type = parallel_desc.device_type();
+  const Shape& logical_shape = Shape(conf.logical_shape());
+  const int64_t parallel_num = parallel_desc.parallel_num();
+
+  std::vector<TensorSliceView> src_send_intersections;
+  std::vector<TensorSliceView> dst_recv_intersections;
+  GetRankSendRecvIntersection(parallel_id, parallel_desc, src_parallel_desc, dst_parallel_desc,
+                              src_nd_sbp, dst_nd_sbp, logical_shape, &src_send_intersections,
+                              &dst_recv_intersections);
+  // if parallel_id exists in src parallel desc, this rank has send
+  int64_t src_parallel_id = GetMappedParallelId(parallel_id, parallel_desc, src_parallel_desc);
+  if (src_parallel_id != -1) {
+    CHECK_EQ(src_send_intersections.size(), parallel_num);
+    send_elem_cnts_.resize(parallel_num);
+    in_tensor_slice_copier_vec_.resize(parallel_num);
+    const TensorSliceView& cur_rank_in_slice = GetTensorSliceView4ParallelId(
+        *src_parallel_desc.hierarchy(), src_nd_sbp, logical_shape, src_parallel_id);
+    for (int64_t i = 0; i < parallel_num; ++i) {
+      const TensorSliceView& intersection = src_send_intersections.at(i);
+      if (!intersection.IsEmpty()) {
+        send_elem_cnts_.at(i) = intersection.shape().elem_cnt();
+        in_tensor_slice_copier_vec_.at(i).reset(
+            new TensorSliceCopier(intersection, cur_rank_in_slice, data_type, device_type));
+      }
+    }
+  } else {
+    CHECK_EQ(src_send_intersections.size(), 0);
+  }
+
+  // if parallel_id exists in dst parallel desc, this rank has recv
+  int64_t dst_parallel_id = GetMappedParallelId(parallel_id, parallel_desc, dst_parallel_desc);
+  if (dst_parallel_id != -1) {
+    CHECK_EQ(dst_recv_intersections.size(), parallel_num);
+    recv_elem_cnts_.resize(parallel_num);
+    out_tensor_slice_copier_vec_.resize(parallel_num);
+    const TensorSliceView& cur_rank_out_slice = GetTensorSliceView4ParallelId(
+        *dst_parallel_desc.hierarchy(), dst_nd_sbp, logical_shape, dst_parallel_id);
+    for (int64_t i = 0; i < parallel_num; ++i) {
+      const TensorSliceView& intersection = dst_recv_intersections.at(i);
+      if (!intersection.IsEmpty()) {
+        recv_elem_cnts_.at(i) = intersection.shape().elem_cnt();
+        out_tensor_slice_copier_vec_.at(i).reset(
+            new TensorSliceCopier(cur_rank_out_slice, intersection, data_type, device_type));
+      }
+    }
+  } else {
+    CHECK_EQ(dst_recv_intersections.size(), 0);
+  }
+}
+
+REGISTER_KERNEL(OperatorConf::kNcclSendRecvBoxingConf, NcclSendRecvBoxingKernel);
+
+}  // namespace oneflow
+
+#endif  // WITH_CUDA && NCCL_VERSION_CODE > 2700
diff --git a/oneflow/core/lazy/actor/naive_actor.cpp b/oneflow/core/lazy/actor/naive_actor.cpp
index 
ac557618b74..59abdb3437b 100644 --- a/oneflow/core/lazy/actor/naive_actor.cpp +++ b/oneflow/core/lazy/actor/naive_actor.cpp @@ -34,6 +34,7 @@ REGISTER_ACTOR(TaskType::kSliceBoxing, NaiveActor); REGISTER_ACTOR(TaskType::kBoxingIdentity, NaiveActor); REGISTER_ACTOR(TaskType::kCollectiveBoxingPack, NaiveActor); REGISTER_ACTOR(TaskType::kCollectiveBoxingUnpack, NaiveActor); +REGISTER_ACTOR(TaskType::kNcclSendRecvBoxing, NaiveActor); REGISTER_ACTOR(TaskType::kDecodeH2D, NaiveActor); REGISTER_ACTOR(TaskType::kCriticalSectionWaitTick, NaiveActor); #ifdef WITH_CUDA diff --git a/oneflow/core/operator/nccl_send_recv_boxing_op.cpp b/oneflow/core/operator/nccl_send_recv_boxing_op.cpp new file mode 100644 index 00000000000..a2d8d3d02ec --- /dev/null +++ b/oneflow/core/operator/nccl_send_recv_boxing_op.cpp @@ -0,0 +1,133 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/operator/operator.h" +#include "oneflow/core/common/protobuf.h" +#include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h" + +namespace oneflow { + +class NcclSendRecvBoxingOp : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(NcclSendRecvBoxingOp); + NcclSendRecvBoxingOp() = default; + ~NcclSendRecvBoxingOp() override = default; + + Maybe InitFromOpConf() override; + Maybe InferInternalBlobDescs( + const std::function& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, const JobDesc* job_desc) const override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override { + UNIMPLEMENTED_THEN_RETURN(); + } + Maybe InferOutBlobDescs( + const std::function& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + + private: + LogicalBlobId lbi4ibn(const std::string& input_bn) const override; + LogicalBlobId lbi4obn(const std::string& output_bn) const override; +}; + +Maybe NcclSendRecvBoxingOp::InitFromOpConf() { + const NcclSendRecvBoxingOpConf& conf = this->op_conf().nccl_send_recv_boxing_conf(); + if (conf.has_input()) { EnrollInputBn("in", false); } + if (conf.has_output()) { EnrollOutputBn("out", false); } + EnrollTmpBn("buf"); + return Maybe::Ok(); +} + +Maybe NcclSendRecvBoxingOp::InferInternalBlobDescs( + const std::function& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, const JobDesc* job_desc) const { + BlobDesc* buf = GetBlobDesc4BnInOp("buf"); + const NcclSendRecvBoxingOpConf& conf = this->op_conf().nccl_send_recv_boxing_conf(); + const NdSbp& src_nd_sbp = conf.src_nd_sbp(); + const NdSbp& dst_nd_sbp = conf.dst_nd_sbp(); + ParallelDesc parallel_desc(conf.parallel_conf()); + ParallelDesc in_parallel_desc(conf.src_parallel_conf()); + ParallelDesc out_parallel_desc(conf.dst_parallel_conf()); + const int64_t parallel_num = parallel_desc.parallel_num(); + const int64_t parallel_id = parallel_ctx->parallel_id(); + const Shape& logical_shape = Shape(conf.logical_shape()); + std::vector src_send_intersections; + std::vector dst_recv_intersections; + 
GetRankSendRecvIntersection(parallel_id, parallel_desc, in_parallel_desc, out_parallel_desc, + src_nd_sbp, dst_nd_sbp, logical_shape, &src_send_intersections, + &dst_recv_intersections); + int64_t buf_count = 0; + if (conf.has_input()) { + const BlobDesc* in = GetBlobDesc4BnInOp("in"); + buf->set_data_type(in->data_type()); + CHECK_EQ(src_send_intersections.size(), parallel_num); + for (int64_t i = 0; i < parallel_num; ++i) { + const TensorSliceView& intersection = src_send_intersections.at(i); + if (!intersection.IsEmpty()) { buf_count += intersection.shape().elem_cnt(); } + } + } + if (conf.has_output()) { + const BlobDesc* out = GetBlobDesc4BnInOp("out"); + buf->set_data_type(out->data_type()); + for (int64_t i = 0; i < parallel_num; ++i) { + const TensorSliceView& intersection = dst_recv_intersections.at(i); + if (!intersection.IsEmpty()) { buf_count += intersection.shape().elem_cnt(); } + } + if (NdSbpHasPartialParallel(src_nd_sbp)) { + // Note: when src_nd_sbp has partial_sum, need a out_size buffer to copy and add to out. + buf_count += out->shape().elem_cnt(); + } + } + buf->mut_shape() = Shape({buf_count}); + return Maybe::Ok(); +} + +LogicalBlobId NcclSendRecvBoxingOp::lbi4ibn(const std::string& input_bn) const { + return this->op_conf().nccl_send_recv_boxing_conf().lbi(); +} + +LogicalBlobId NcclSendRecvBoxingOp::lbi4obn(const std::string& output_bn) const { + return this->op_conf().nccl_send_recv_boxing_conf().lbi(); +} + +Maybe NcclSendRecvBoxingOp::InferOutBlobDescs( + const std::function& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + const NcclSendRecvBoxingOpConf& conf = this->op_conf().nccl_send_recv_boxing_conf(); + const Shape& logical_shape = Shape(conf.logical_shape()); + if (conf.has_input()) { + const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in"); + const NdSbp& src_nd_sbp = conf.src_nd_sbp(); + const ParallelDesc& src_parallel_desc = ParallelDesc(conf.src_parallel_conf()); + std::shared_ptr in_shape = JUST(GetPhysicalShape( + logical_shape, src_nd_sbp, src_parallel_desc, parallel_ctx->parallel_id())); + CHECK_EQ_OR_RETURN(*in_shape, in_blob_desc->shape()); + } + if (conf.has_output()) { + BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); + const NdSbp& dst_nd_sbp = conf.dst_nd_sbp(); + const ParallelDesc& dst_parallel_desc = ParallelDesc(conf.dst_parallel_conf()); + std::shared_ptr out_shape = JUST(GetPhysicalShape( + logical_shape, dst_nd_sbp, dst_parallel_desc, parallel_ctx->parallel_id())); + out_blob_desc->mut_shape() = *out_shape; + out_blob_desc->set_data_type(conf.data_type()); + } + return Maybe::Ok(); +} + +REGISTER_OP(OperatorConf::kNcclSendRecvBoxingConf, NcclSendRecvBoxingOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/nccl_send_recv_boxing_op_util.cpp b/oneflow/core/operator/nccl_send_recv_boxing_op_util.cpp new file mode 100644 index 00000000000..a0be3320256 --- /dev/null +++ b/oneflow/core/operator/nccl_send_recv_boxing_op_util.cpp @@ -0,0 +1,170 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/common/nd_index_offset_helper.h" +#include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h" + +namespace oneflow { + +namespace { +// Go through all the ranks while transfer between two nd sbps with no PartialSum under the same +// placement. +// NOTE: We need to make sure no partial sums in the sbps of the producer and consumer. +void DfsTraverseRanks4NdSbp( + int32_t depth, std::vector& in_parallel_ids, + const std::vector& out_parallel_ids, const Shape& in_parallel_hierarchy, + const NdIndexOffsetHelper& in_hierarchy_index_helper, + const NdSbp& in_nd_sbp, const std::function& visit) { + if (depth >= in_parallel_hierarchy.NumAxes()) { + visit(in_hierarchy_index_helper.NdIndexToOffset(in_parallel_ids.data(), + in_parallel_hierarchy.NumAxes())); + return; + } + if (in_nd_sbp.sbp_parallel(depth).has_broadcast_parallel()) { + // If Broadcast in the sbp of the producer, only visit those ranks with the same id as the + // current rank along the depth-dimension. + in_parallel_ids[depth] = out_parallel_ids[depth]; + DfsTraverseRanks4NdSbp(depth + 1, in_parallel_ids, out_parallel_ids, in_parallel_hierarchy, + in_hierarchy_index_helper, in_nd_sbp, visit); + } else { + // If Split or PartialSum, go through all the ranks along the depth-dimension. + for (int64_t i = 0; i < in_parallel_hierarchy.dim_vec().at(depth); i++) { + in_parallel_ids[depth] = i; + DfsTraverseRanks4NdSbp(depth + 1, in_parallel_ids, out_parallel_ids, in_parallel_hierarchy, + in_hierarchy_index_helper, in_nd_sbp, visit); + } + } +} + +bool NdSbpNoPartialParallel(const NdSbp& nd_sbp) { + CHECK_GT(nd_sbp.sbp_parallel_size(), 0); + FOR_RANGE(int64_t, i, 0, nd_sbp.sbp_parallel_size()) { + if (nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) { return false; } + } + return true; +} + +} // namespace + +int64_t GetMappedParallelId(const int64_t from_parallel_id, const ParallelDesc& from_parallel_desc, + const ParallelDesc& to_parallel_desc) { + const int64_t machine_id = CHECK_JUST(from_parallel_desc.MachineId4ParallelId(from_parallel_id)); + const int64_t device_index = CHECK_JUST(from_parallel_desc.DeviceId4ParallelId(from_parallel_id)); + if (to_parallel_desc.Containing(machine_id, device_index)) { + return CHECK_JUST(to_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index)); + } else { + return -1; + } +} + +void GetRankSendRecvIntersection(int64_t parallel_id, const ParallelDesc& parallel_desc, + const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp, + const NdSbp& out_nd_sbp, const Shape& logical_shape, + std::vector* send_intersections, + std::vector* recv_intersections) { + const int64_t parallel_num = parallel_desc.parallel_num(); + CHECK_LT(parallel_id, parallel_num); + + const std::vector& in_slices = + GetTensorSliceView(*in_parallel_desc.hierarchy(), in_nd_sbp, logical_shape); + const std::vector& out_slices = + GetTensorSliceView(*out_parallel_desc.hierarchy(), out_nd_sbp, logical_shape); + + const auto& in_parallel_hierarchy = in_parallel_desc.hierarchy(); + int32_t in_hierarchy_dimension = in_parallel_hierarchy->NumAxes(); + const NdIndexOffsetHelper in_hierarchy_index_helper( + in_parallel_hierarchy->dim_vec().data(), in_hierarchy_dimension); + + const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); + const int64_t device_index = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); + 
const int64_t in_parallel_num = in_parallel_desc.parallel_num(); + const int64_t out_parallel_num = out_parallel_desc.parallel_num(); + // cur rank recv from + // cur rank has output + if (out_parallel_desc.Containing(machine_id, device_index)) { + recv_intersections->resize(parallel_num); + int64_t out_id = + CHECK_JUST(out_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index)); + const TensorSliceView& cur_rank_out_slice = out_slices.at(out_id); + const auto& add_to_recv_intersections = [&](int32_t send_id) { + const TensorSliceView& in_slice = in_slices.at(send_id); + const TensorSliceView& intersection = cur_rank_out_slice.Intersect(in_slice); + if (intersection.IsEmpty()) { return; } + const int64_t merged_id = GetMappedParallelId(send_id, in_parallel_desc, parallel_desc); + recv_intersections->at(merged_id) = intersection; + }; + int64_t corresponding_in_id = 0; + // For example [[0, 1], [2, 3]] -> [[1, 3], [5, 6]] + if (in_parallel_desc.Containing(machine_id, device_index)) { + // 1 and 3 are in [[0, 1], [2, 3]], use the same id in the producer parallel description + // The id of 1 is (0, 1), the id of 3 is (1, 1) + corresponding_in_id = + CHECK_JUST(in_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index)); + } else { + // 5 and 7 are not in [[0, 1], [2, 3]] + // Then the id does not matter + corresponding_in_id = out_id % in_parallel_num; + } + std::vector in_parallel_ids(in_hierarchy_dimension); + // The corresponding parallel id of a consumer rank in the producer parallel description + std::vector out_parallel_ids(in_hierarchy_dimension); + in_hierarchy_index_helper.OffsetToNdIndex(corresponding_in_id, out_parallel_ids.data(), + in_hierarchy_dimension); + DfsTraverseRanks4NdSbp(0, in_parallel_ids, out_parallel_ids, *in_parallel_hierarchy, + in_hierarchy_index_helper, in_nd_sbp, add_to_recv_intersections); + } + + // cur rank send to + if (in_parallel_desc.Containing(machine_id, device_index)) { + send_intersections->resize(parallel_num); + int64_t in_id = + CHECK_JUST(in_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index)); + const TensorSliceView& cur_rank_in_slice = in_slices.at(in_id); + for (int64_t recv_i = 0; recv_i < out_parallel_num; ++recv_i) { + const auto& add_to_send_intersections = [&](int32_t send_id) { + if (send_id != in_id) { return; } + const TensorSliceView& out_slice = out_slices.at(recv_i); + const TensorSliceView& intersection = out_slice.Intersect(cur_rank_in_slice); + if (intersection.IsEmpty()) { return; } + const int64_t merged_id = GetMappedParallelId(recv_i, out_parallel_desc, parallel_desc); + send_intersections->at(merged_id) = intersection; + }; + int64_t out_device_id = CHECK_JUST(out_parallel_desc.DeviceId4ParallelId(recv_i)); + int64_t out_machine_id = CHECK_JUST(out_parallel_desc.MachineId4ParallelId(recv_i)); + int64_t corresponding_in_id = 0; + // For example [[0, 1], [2, 3]] -> [[1, 3], [5, 6]] + if (in_parallel_desc.Containing(out_machine_id, out_device_id)) { + // 1 and 3 are in [[0, 1], [2, 3]], use the same id in the producer parallel description + // The id of 1 is (0, 1), the id of 3 is (1, 1) + corresponding_in_id = + CHECK_JUST(in_parallel_desc.ParallelId4MachineDeviceId(out_machine_id, out_device_id)); + } else { + // 5 and 7 are not in [[0, 1], [2, 3]] + // Then the id does not matter + corresponding_in_id = recv_i % in_parallel_num; + } + std::vector in_parallel_ids(in_hierarchy_dimension); + // The corresponding parallel id of a consumer rank in the producer parallel description + 
std::vector out_parallel_ids(in_hierarchy_dimension); + in_hierarchy_index_helper.OffsetToNdIndex(corresponding_in_id, out_parallel_ids.data(), + in_hierarchy_dimension); + DfsTraverseRanks4NdSbp(0, in_parallel_ids, out_parallel_ids, *in_parallel_hierarchy, + in_hierarchy_index_helper, in_nd_sbp, add_to_send_intersections); + } + } +} + +} // namespace oneflow diff --git a/oneflow/core/operator/nccl_send_recv_boxing_op_util.h b/oneflow/core/operator/nccl_send_recv_boxing_op_util.h new file mode 100644 index 00000000000..f491a50e91b --- /dev/null +++ b/oneflow/core/operator/nccl_send_recv_boxing_op_util.h @@ -0,0 +1,31 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/register/tensor_slice_view.h" +#include "oneflow/core/job/nd_sbp_util.h" + +namespace oneflow { + +int64_t GetMappedParallelId(const int64_t from_parallel_id, const ParallelDesc& from_parallel_desc, + const ParallelDesc& to_parallel_desc); + +void GetRankSendRecvIntersection(int64_t parallel_id, const ParallelDesc& parallel_desc, + const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp, + const NdSbp& out_nd_sbp, const Shape& logical_shape, + std::vector* send_intersections, + std::vector* recv_intersections); + +} // namespace oneflow diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto index 4589ae3507e..cb6cc5d80a3 100644 --- a/oneflow/core/operator/op_conf.proto +++ b/oneflow/core/operator/op_conf.proto @@ -13,6 +13,7 @@ import "oneflow/core/job/sbp_parallel.proto"; import "oneflow/core/graph/boxing/collective_boxing.proto"; import "oneflow/core/job/initializer_conf.proto"; import "oneflow/core/job/regularizer_conf.proto"; +import "oneflow/core/job/placement.proto"; import "oneflow/core/job/learning_rate_schedule_conf.proto"; import "oneflow/core/operator/interface_blob_conf.proto"; import "oneflow/core/register/blob_desc.proto"; @@ -401,6 +402,19 @@ message BoxingZerosOpConf { required DataType data_type = 3; } +message NcclSendRecvBoxingOpConf { + required LogicalBlobId lbi = 1; + required NdSbp src_nd_sbp = 2; + required NdSbp dst_nd_sbp = 3; + required ParallelConf parallel_conf = 4; + required ParallelConf src_parallel_conf = 5; + required ParallelConf dst_parallel_conf = 6; + required ShapeProto logical_shape = 7; + required DataType data_type = 8; + required bool has_input = 9; + required bool has_output = 10; +} + message OperatorConf { required string name = 1; optional string device_tag = 4 [default = "invalid_device"]; @@ -446,6 +460,7 @@ message OperatorConf { CollectiveBoxingPackOpConf collective_boxing_pack_conf = 174; CollectiveBoxingUnpackOpConf collective_boxing_unpack_conf = 175; BoxingZerosOpConf boxing_zeros_conf = 176; + NcclSendRecvBoxingOpConf nccl_send_recv_boxing_conf = 177; UserOpConf user_conf = 199; // domain op diff --git a/python/oneflow/test/modules/test_nccl_send_recv_boxing.py b/python/oneflow/test/modules/test_nccl_send_recv_boxing.py new file 
mode 100644
index 00000000000..20c8d09f4ed
--- /dev/null
+++ b/python/oneflow/test/modules/test_nccl_send_recv_boxing.py
@@ -0,0 +1,103 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import unittest
+from collections import OrderedDict
+import oneflow
+import numpy as np
+import oneflow as flow
+import oneflow.unittest
+from oneflow.test_utils.test_util import GenArgList
+
+import time
+import os
+
+os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "1"
+
+
+def _test_nccl_send_recv_boxing(
+    test_case, src_nd_sbp, dst_nd_sbp, src_ranks, dst_ranks
+):
+    # can not process partial_sum (P) in dst
+    if flow.sbp.partial_sum() in dst_nd_sbp:
+        return
+    # skip src == dst
+    if src_nd_sbp == dst_nd_sbp:
+        return
+    # in this case, use intra group boxing
+    if src_nd_sbp[0] == dst_nd_sbp[0]:
+        return
+    # in this case, use inter group boxing
+    if (
+        src_nd_sbp[1] == dst_nd_sbp[1]
+        and src_nd_sbp[0] != src_nd_sbp[1]
+        and dst_nd_sbp[0] != dst_nd_sbp[1]
+    ):
+        return
+    # in this case, use 1d boxing
+    if src_nd_sbp[0] == src_nd_sbp[1] and dst_nd_sbp[0] == dst_nd_sbp[1]:
+        return
+    src_placement = flow.placement("cuda", ranks=src_ranks)
+    dst_placement = flow.placement("cuda", ranks=dst_ranks)
+
+    class TestGraph(flow.nn.Graph):
+        def __init__(self):
+            super().__init__()
+
+        def build(self, x):
+            y = x.to_global(sbp=dst_nd_sbp, placement=dst_placement)
+            return y
+
+    x = flow.tensor(
+        np.arange(12 * 16 * 16).reshape(12, 16, 16),
+        sbp=src_nd_sbp,
+        placement=src_placement,
+    )
+    graph = TestGraph()
+    y = graph(x)
+    test_case.assertTrue(np.array_equal(y.numpy(), x.numpy()))
+
+
+def gen_nd_sbp():
+    sbp_list = [
+        flow.sbp.partial_sum(),
+        flow.sbp.broadcast(),
+        flow.sbp.split(0),
+        flow.sbp.split(1),
+        flow.sbp.split(2),
+    ]
+    nd_sbp_list = []
+    for sbp0 in sbp_list:
+        for sbp1 in sbp_list:
+            nd_sbp_list.append([sbp0, sbp1])
+    return nd_sbp_list
+
+
+@flow.unittest.skip_unless_1n4d()
+@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
+class TestNcclSendRecvBoxing(flow.unittest.TestCase):
+    def test_nccl_send_recv_boxing(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["src_nd_sbp"] = gen_nd_sbp()
+        arg_dict["dst_nd_sbp"] = gen_nd_sbp()
+        arg_dict["src_ranks"] = [[[0, 1], [2, 3]], [[0, 1]]]
+        arg_dict["dst_ranks"] = [[[0, 1], [2, 3]], [[2, 3]]]
+        for arg in GenArgList(arg_dict):
+            _test_nccl_send_recv_boxing(test_case, *arg)
+
+
+if __name__ == "__main__":
+    unittest.main()
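
For readers who want to see the new path end to end, here is a minimal usage sketch, separate from the patch itself. It assumes one node with 4 CUDA devices and purely illustrative shapes, ranks, and sbp values; it shows the kind of 2-D re-layout (both sbp dimensions change, no partial_sum on the destination) that DispatchHierarchicalSubTskGphBuilder now routes to NDNcclSendRecvBoxingSubTskGphBuilder, i.e. to a single nccl_send_recv_boxing op rather than a chain of middle nodes. The env var below mirrors the test above and is presumably only needed to force the direct path.

# hypothetical sketch, not part of the diff; assumes 4 GPUs, launched e.g. with
#   python3 -m oneflow.distributed.launch --nproc_per_node 4 relayout_sketch.py
import os
os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "1"

import numpy as np
import oneflow as flow

placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]])
src_sbp = [flow.sbp.split(0), flow.sbp.broadcast()]   # (S(0), B)
dst_sbp = [flow.sbp.broadcast(), flow.sbp.split(1)]   # (B, S(1)): dim 0 and dim 1 both change

class RelayoutGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()

    def build(self, x):
        # This boxing is what the new builder compiles into one nccl_send_recv_boxing op.
        return x.to_global(sbp=dst_sbp, placement=placement)

x = flow.tensor(
    np.arange(12 * 16 * 16).reshape(12, 16, 16),
    sbp=src_sbp,
    placement=placement,
)
y = RelayoutGraph()(x)
assert np.array_equal(y.numpy(), x.numpy())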