From a2d2d0c0cf63a3a2fd58a382fae8ee3e3ea16b04 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Wed, 20 Sep 2023 19:03:45 +0800 Subject: [PATCH 1/4] [NewComm] No.2 compatible upgrade for partial_recv op --- .../collective/partial_recv_op.cu.cc | 92 ++++++++++++++----- 1 file changed, 69 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/operators/collective/partial_recv_op.cu.cc b/paddle/fluid/operators/collective/partial_recv_op.cu.cc index 0c33ca7c25c326..7a870c04c39191 100644 --- a/paddle/fluid/operators/collective/partial_recv_op.cu.cc +++ b/paddle/fluid/operators/collective/partial_recv_op.cu.cc @@ -18,15 +18,20 @@ limitations under the License. */ #include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/flags.h" +PHI_DECLARE_bool(dynamic_static_unified_comm); #endif +#include "paddle/phi/core/distributed/comm_context_manager.h" + namespace paddle { namespace operators { template class PartialRecvOpCUDAKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext &ctx) const override { + void Compute(const framework::ExecutionContext& ctx) const override { #if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \ NCCL_VERSION_CODE >= 2703 auto out = ctx.Output("Out"); @@ -74,41 +79,82 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel { auto map = distributed::ProcessGroupMapFromGid::getInstance(); if (map->has(rid)) { // Use ProcessGroup - distributed::ProcessGroup *pg = map->get(rid); + distributed::ProcessGroup* pg = map->get(rid); auto task = pg->Recv(out, peer, offset, recv_numel, /*sync_op*/ true); task->Wait(); } else { gpuStream_t stream = nullptr; - auto comm = platform::NCCLCommContext::Instance().Get(rid, place); + platform::NCCLComm* comm = nullptr; + 
phi::distributed::NCCLCommContext* comm_ctx = nullptr; + + int nranks = 0; + int rank = 0; + + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + + if (FLAGS_dynamic_static_unified_comm) { + // Use New Communication Libaray + PADDLE_ENFORCE_EQ( + comm_context_manager.Has(std::to_string(rid)), + true, + platform::errors::InvalidArgument( + "You choose to use new communication library by " + "setting environment " + "variable FLAGS_dynamic_static_unified_comm True. " + "But ring_id(%d) is " + "not found in comm_context_manager.", + std::to_string(rid))); + comm_ctx = static_cast( + comm_context_manager.Get(std::to_string(rid))); + PADDLE_ENFORCE_NE( + comm_ctx, + nullptr, + platform::errors::Unavailable( + "NCCLCommContext is nullptr, collective op should " + "has ring_id attr.")); + + stream = comm_ctx->GetStream(); + nranks = comm_ctx->GetSize(); + rank = comm_ctx->GetRank(); + + VLOG(3) << "new comm_context_manager has ring_id " << rid; + } else { + comm = platform::NCCLCommContext::Instance().Get(rid, place); + + stream = comm->stream(); + nranks = comm->nranks(); + rank = comm->rank(); + + VLOG(3) << "old NCCLCommContext has ring_id" << rid; + } if (ctx.Attr("use_calc_stream")) { // should ExecutionContext for calc stream. 
stream = ctx.cuda_device_context().stream(); - } else { - stream = comm->stream(); + PADDLE_ENFORCE_LT(peer, + nranks, + platform::errors::InvalidArgument( + "The value of peer (%d) you set must " + "be less than nranks (%d).", + peer, + nranks)); + ncclDataType_t dtype = platform::ToNCCLDataType(type); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::ncclRecv(out->data() + offset, + recv_numel, + dtype, + peer, + comm->comm(), + stream)); + VLOG(3) << "rank " << rank << " recv " << recv_numel << " from offset[" + << offset << "] from " << peer; } - PADDLE_ENFORCE_LT(peer, - comm->nranks(), - platform::errors::InvalidArgument( - "The value of peer (%d) you set must " - "be less than comm->nranks (%d).", - peer, - comm->nranks())); - ncclDataType_t dtype = platform::ToNCCLDataType(type); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::ncclRecv(out->data() + offset, - recv_numel, - dtype, - peer, - comm->comm(), - stream)); - VLOG(3) << "rank " << comm->rank() << " recv " << recv_numel - << " from offset[" << offset << "] from " << peer; - } #else PADDLE_THROW(platform::errors::Unavailable( "PaddlePaddle should be compiled with NCCL and " "NCCL version >= 2.7.3 is needed.")); #endif + } } }; From 98d44441b236e5e35154ef514ec3474478551697 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Wed, 20 Sep 2023 22:03:24 +0800 Subject: [PATCH 2/4] fix --- .../collective/partial_recv_op.cu.cc | 32 ++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/collective/partial_recv_op.cu.cc b/paddle/fluid/operators/collective/partial_recv_op.cu.cc index 7a870c04c39191..584eed393a7ecc 100644 --- a/paddle/fluid/operators/collective/partial_recv_op.cu.cc +++ b/paddle/fluid/operators/collective/partial_recv_op.cu.cc @@ -128,17 +128,27 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel { VLOG(3) << "old NCCLCommContext has ring_id" << rid; } + if (ctx.Attr("use_calc_stream")) { // should ExecutionContext for 
calc stream. stream = ctx.cuda_device_context().stream(); - PADDLE_ENFORCE_LT(peer, - nranks, - platform::errors::InvalidArgument( - "The value of peer (%d) you set must " - "be less than nranks (%d).", - peer, - nranks)); - ncclDataType_t dtype = platform::ToNCCLDataType(type); + } + + PADDLE_ENFORCE_LT(peer, + nranks, + platform::errors::InvalidArgument( + "The value of peer (%d) you set must " + "be less than nranks (%d).", + peer, + nranks)); + + ncclDataType_t dtype = platform::ToNCCLDataType(type); + + if (comm_ctx) { + auto recv_buf = distributed::GetPartialTensor(*out, offset, recv_numel); + + comm_ctx->Recv(&recv_buf, recv_numel, peer, stream); + } else { PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::ncclRecv(out->data() + offset, recv_numel, @@ -146,15 +156,15 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel { peer, comm->comm(), stream)); - VLOG(3) << "rank " << rank << " recv " << recv_numel << " from offset[" - << offset << "] from " << peer; } + VLOG(3) << "rank " << rank << " recv " << recv_numel << " from offset[" + << offset << "] from " << peer; + } #else PADDLE_THROW(platform::errors::Unavailable( "PaddlePaddle should be compiled with NCCL and " "NCCL version >= 2.7.3 is needed.")); #endif - } } }; From a6d0f0459acf1b137943849e7041d1a6c64284b7 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Wed, 20 Sep 2023 22:41:59 +0800 Subject: [PATCH 3/4] add header --- paddle/fluid/operators/collective/partial_recv_op.cu.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/operators/collective/partial_recv_op.cu.cc b/paddle/fluid/operators/collective/partial_recv_op.cu.cc index 584eed393a7ecc..baeee38dd70f25 100644 --- a/paddle/fluid/operators/collective/partial_recv_op.cu.cc +++ b/paddle/fluid/operators/collective/partial_recv_op.cu.cc @@ -23,6 +23,7 @@ limitations under the License. 
*/ PHI_DECLARE_bool(dynamic_static_unified_comm); #endif +#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" namespace paddle { From 74093c871a25a3c0d74e1897e1b11e33c636342f Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Wed, 20 Sep 2023 22:51:53 +0800 Subject: [PATCH 4/4] fix typo --- paddle/fluid/operators/collective/partial_recv_op.cu.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/collective/partial_recv_op.cu.cc b/paddle/fluid/operators/collective/partial_recv_op.cu.cc index baeee38dd70f25..2a6aea1c7a13af 100644 --- a/paddle/fluid/operators/collective/partial_recv_op.cu.cc +++ b/paddle/fluid/operators/collective/partial_recv_op.cu.cc @@ -95,7 +95,7 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel { phi::distributed::CommContextManager::GetInstance(); if (FLAGS_dynamic_static_unified_comm) { - // Use New Communication Libaray + // Use New Communication Library PADDLE_ENFORCE_EQ( comm_context_manager.Has(std::to_string(rid)), true,