Add an optional CUDA stream synchronization to NcclDeviceCommunicator.

NcclDeviceCommunicator gains a `needs_sync` constructor argument. When it is
set, BitwiseAllReduce calls cudaStreamSynchronize right after ncclAllGather and
before the local reduction. Communicator::GetDevice passes `true` only for
CommunicatorType::kInMemory; kRabit and the default case pass `false`, and
kFederated keeps using DeviceCommunicatorAdapter.

NOTE(review): line breaks and indentation were reconstructed from the hunk line
counts; apply with whitespace tolerance if context matching fails.

diff --git a/src/collective/communicator.cu b/src/collective/communicator.cu
index 8cdb7f2fd3bd..915a3becab9e 100644
--- a/src/collective/communicator.cu
+++ b/src/collective/communicator.cu
@@ -29,10 +29,18 @@ DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
     old_device_ordinal = device_ordinal;
     old_world_size = communicator_->GetWorldSize();
 #ifdef XGBOOST_USE_NCCL
-    if (type_ != CommunicatorType::kFederated) {
-      device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal));
-    } else {
-      device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
+    switch (type_) {
+      case CommunicatorType::kRabit:
+        device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
+        break;
+      case CommunicatorType::kFederated:
+        device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
+        break;
+      case CommunicatorType::kInMemory:
+        device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true));
+        break;
+      default:
+        device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
     }
 #else
     device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
diff --git a/src/collective/nccl_device_communicator.cu b/src/collective/nccl_device_communicator.cu
index 7f56860757cc..470700d2d36e 100644
--- a/src/collective/nccl_device_communicator.cu
+++ b/src/collective/nccl_device_communicator.cu
@@ -7,8 +7,11 @@
 namespace xgboost {
 namespace collective {
 
-NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal)
-    : device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
+NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync)
+    : device_ordinal_{device_ordinal},
+      needs_sync_{needs_sync},
+      world_size_{GetWorldSize()},
+      rank_{GetRank()} {
   if (device_ordinal_ < 0) {
     LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
   }
@@ -140,6 +143,9 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
   // First gather data from all the workers.
   dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
                               nccl_comm_, cuda_stream_));
+  if (needs_sync_) {
+    dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
+  }
 
   // Then reduce locally.
   auto *out_buffer = static_cast<char *>(send_receive_buffer);
diff --git a/src/collective/nccl_device_communicator.cuh b/src/collective/nccl_device_communicator.cuh
index 925603d21252..bb3fce45c0ff 100644
--- a/src/collective/nccl_device_communicator.cuh
+++ b/src/collective/nccl_device_communicator.cuh
@@ -12,7 +12,20 @@ namespace collective {
 
 class NcclDeviceCommunicator : public DeviceCommunicator {
  public:
-  explicit NcclDeviceCommunicator(int device_ordinal);
+  /**
+   * @brief Construct a new NCCL communicator.
+   * @param device_ordinal The GPU device id.
+   * @param needs_sync Whether extra CUDA stream synchronization is needed.
+   *
+   * In multi-GPU tests when multiple NCCL communicators are created in the same process, sometimes
+   * a deadlock happens because NCCL kernels are blocking. The extra CUDA stream synchronization
+   * makes sure that the NCCL kernels are caught up, thus avoiding the deadlock.
+   *
+   * The Rabit communicator runs with one process per GPU, so the additional synchronization is not
+   * needed. The in-memory communicator is used in tests with multiple threads, each thread
+   * representing a rank/worker, so the additional synchronization is needed to avoid deadlocks.
+   */
+  explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync);
   ~NcclDeviceCommunicator() override;
   void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                  Operation op) override;
@@ -60,6 +73,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
                         Operation op);
 
   int const device_ordinal_;
+  bool const needs_sync_;
   int const world_size_;
   int const rank_;
   ncclComm_t nccl_comm_{};
diff --git a/tests/cpp/collective/test_nccl_device_communicator.cu b/tests/cpp/collective/test_nccl_device_communicator.cu
index 81dd3d46db0d..cd9cd26de184 100644
--- a/tests/cpp/collective/test_nccl_device_communicator.cu
+++ b/tests/cpp/collective/test_nccl_device_communicator.cu
@@ -16,7 +16,7 @@ namespace xgboost {
 namespace collective {
 
 TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
-  auto construct = []() { NcclDeviceCommunicator comm{-1}; };
+  auto construct = []() { NcclDeviceCommunicator comm{-1, false}; };
   EXPECT_THROW(construct(), dmlc::Error);
 }
 