diff --git a/paddle/fluid/operators/collective/partial_recv_op.cu.cc b/paddle/fluid/operators/collective/partial_recv_op.cu.cc
index 0c33ca7c25c326..2a6aea1c7a13af 100644
--- a/paddle/fluid/operators/collective/partial_recv_op.cu.cc
+++ b/paddle/fluid/operators/collective/partial_recv_op.cu.cc
@@ -18,15 +18,21 @@ limitations under the License. */
 #include "paddle/fluid/distributed/collective/process_group.h"
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
+#include "paddle/phi/core/distributed/nccl_comm_context.h"
+#include "paddle/phi/core/flags.h"
+PHI_DECLARE_bool(dynamic_static_unified_comm);
 #endif
 
+#include "paddle/fluid/distributed/collective/utils.h"
+#include "paddle/phi/core/distributed/comm_context_manager.h"
+
 namespace paddle {
 namespace operators {
 
 template <typename T, typename DeviceContext>
 class PartialRecvOpCUDAKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
+  void Compute(const framework::ExecutionContext& ctx) const override {
 #if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
     NCCL_VERSION_CODE >= 2703
     auto out = ctx.Output<phi::DenseTensor>("Out");
@@ -74,35 +80,86 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel<T> {
     auto map = distributed::ProcessGroupMapFromGid::getInstance();
     if (map->has(rid)) {
       // Use ProcessGroup
-      distributed::ProcessGroup *pg = map->get(rid);
+      distributed::ProcessGroup* pg = map->get(rid);
       auto task = pg->Recv(out, peer, offset, recv_numel, /*sync_op*/ true);
       task->Wait();
     } else {
       gpuStream_t stream = nullptr;
-      auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
+      platform::NCCLComm* comm = nullptr;
+      phi::distributed::NCCLCommContext* comm_ctx = nullptr;
+
+      int nranks = 0;
+      int rank = 0;
+
+      const auto& comm_context_manager =
+          phi::distributed::CommContextManager::GetInstance();
+
+      if (FLAGS_dynamic_static_unified_comm) {
+        // Use New Communication Library
+        PADDLE_ENFORCE_EQ(
+            comm_context_manager.Has(std::to_string(rid)),
+            true,
+            platform::errors::InvalidArgument(
+                "You choose to use new communication library by "
+                "setting environment "
+                "variable FLAGS_dynamic_static_unified_comm True. "
+                "But ring_id(%d) is "
+                "not found in comm_context_manager.",
+                std::to_string(rid)));
+        comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
+            comm_context_manager.Get(std::to_string(rid)));
+        PADDLE_ENFORCE_NE(
+            comm_ctx,
+            nullptr,
+            platform::errors::Unavailable(
+                "NCCLCommContext is nullptr, collective op should "
+                "has ring_id attr."));
+
+        stream = comm_ctx->GetStream();
+        nranks = comm_ctx->GetSize();
+        rank = comm_ctx->GetRank();
+
+        VLOG(3) << "new comm_context_manager has ring_id " << rid;
+      } else {
+        comm = platform::NCCLCommContext::Instance().Get(rid, place);
+
+        stream = comm->stream();
+        nranks = comm->nranks();
+        rank = comm->rank();
+
+        VLOG(3) << "old NCCLCommContext has ring_id" << rid;
+      }
+
       if (ctx.Attr<bool>("use_calc_stream")) {
         // should ExecutionContext for calc stream.
         stream = ctx.cuda_device_context().stream();
-      } else {
-        stream = comm->stream();
       }
+
       PADDLE_ENFORCE_LT(peer,
-                        comm->nranks(),
+                        nranks,
                         platform::errors::InvalidArgument(
                             "The value of peer (%d) you set must "
-                            "be less than comm->nranks (%d).",
+                            "be less than nranks (%d).",
                             peer,
-                            comm->nranks()));
+                            nranks));
+
       ncclDataType_t dtype = platform::ToNCCLDataType(type);
-      PADDLE_ENFORCE_GPU_SUCCESS(
-          platform::dynload::ncclRecv(out->data<T>() + offset,
-                                      recv_numel,
-                                      dtype,
-                                      peer,
-                                      comm->comm(),
-                                      stream));
-      VLOG(3) << "rank " << comm->rank() << " recv " << recv_numel
-              << " from offset[" << offset << "] from " << peer;
+
+      if (comm_ctx) {
+        auto recv_buf = distributed::GetPartialTensor(*out, offset, recv_numel);
+
+        comm_ctx->Recv(&recv_buf, recv_numel, peer, stream);
+      } else {
+        PADDLE_ENFORCE_GPU_SUCCESS(
+            platform::dynload::ncclRecv(out->data<T>() + offset,
+                                        recv_numel,
+                                        dtype,
+                                        peer,
+                                        comm->comm(),
+                                        stream));
+      }
+      VLOG(3) << "rank " << rank << " recv " << recv_numel << " from offset["
+              << offset << "] from " << peer;
     }
 #else
     PADDLE_THROW(platform::errors::Unavailable(