From e70810be8a5d2b2263b18b1d5df17b335243c086 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 13 Jun 2023 12:53:03 -0700 Subject: [PATCH] Refactor device communicator to make allreduce more flexible (#9295) --- src/collective/communicator-inl.cuh | 81 ++++++++++++++++++ src/collective/device_communicator.cuh | 29 ++----- .../device_communicator_adapter.cuh | 38 +++------ src/collective/nccl_device_communicator.cuh | 84 ++++++++++++++----- src/common/quantile.cu | 13 ++- src/metric/auc.cu | 5 +- src/tree/fit_stump.cu | 6 +- src/tree/updater_gpu_hist.cu | 28 +++---- .../test_nccl_device_communicator.cu | 1 + tests/cpp/common/test_quantile.cu | 9 +- tests/cpp/plugin/test_federated_adapter.cu | 2 +- 11 files changed, 190 insertions(+), 106 deletions(-) create mode 100644 src/collective/communicator-inl.cuh diff --git a/src/collective/communicator-inl.cuh b/src/collective/communicator-inl.cuh new file mode 100644 index 000000000000..0c5fcf910e98 --- /dev/null +++ b/src/collective/communicator-inl.cuh @@ -0,0 +1,81 @@ +/** + * Copyright 2023 by XGBoost contributors + */ +#pragma once +#include +#include + +#include "communicator.h" +#include "device_communicator.cuh" + +namespace xgboost { +namespace collective { + +/** + * @brief Reduce values from all processes and distribute the result back to all processes. + * @param device ID of the device. + * @param send_receive_buffer Buffer storing the data. + * @param count Number of elements in the buffer. + */ +template +inline void AllReduce(int device, std::int8_t *send_receive_buffer, size_t count) { + Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt8, op); +} + +template +inline void AllReduce(int device, std::uint8_t *send_receive_buffer, size_t count) { + Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt8, op); +} + +template +inline void AllReduce(int device, std::int32_t *send_receive_buffer, size_t count) { + Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt32, op); +} + +template +inline void AllReduce(int device, std::uint32_t *send_receive_buffer, size_t count) { + Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt32, op); +} + +template +inline void AllReduce(int device, std::int64_t *send_receive_buffer, size_t count) { + Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt64, op); +} + +template +inline void AllReduce(int device, std::uint64_t *send_receive_buffer, size_t count) { + Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt64, op); +} + +template +inline void AllReduce(int device, float *send_receive_buffer, size_t count) { + Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kFloat, op); +} + +template +inline void AllReduce(int device, double *send_receive_buffer, size_t count) { + Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op); +} + +/** + * @brief Gather variable-length values from all processes. + * @param device ID of the device. + * @param send_buffer Buffer storing the input data. + * @param length_bytes Length in bytes of the input data. + * @param segments Size of each segment. + * @param receive_buffer Buffer storing the output data. + */ +inline void AllGatherV(int device, void const *send_buffer, size_t length_bytes, + std::vector *segments, + dh::caching_device_vector *receive_buffer) { + Communicator::GetDevice(device)->AllGatherV(send_buffer, length_bytes, segments, receive_buffer); +} + +/** + * @brief Synchronize device operations. + * @param device ID of the device. + */ +inline void Synchronize(int device) { Communicator::GetDevice(device)->Synchronize(); } + +} // namespace collective +} // namespace xgboost diff --git a/src/collective/device_communicator.cuh b/src/collective/device_communicator.cuh index 32d69e1b52c1..a598918637f1 100644 --- a/src/collective/device_communicator.cuh +++ b/src/collective/device_communicator.cuh @@ -17,32 +17,15 @@ class DeviceCommunicator { virtual ~DeviceCommunicator() = default; /** - * @brief Sum values from all processes and distribute the result back to all processes. + * @brief Combines values from all processes and distributes the result back to all processes. + * * @param send_receive_buffer Buffer storing the data. * @param count Number of elements in the buffer. + * @param data_type Data type stored in the buffer. + * @param op The operation to perform. */ - virtual void AllReduceSum(float *send_receive_buffer, size_t count) = 0; - - /** - * @brief Sum values from all processes and distribute the result back to all processes. - * @param send_receive_buffer Buffer storing the data. - * @param count Number of elements in the buffer. - */ - virtual void AllReduceSum(double *send_receive_buffer, size_t count) = 0; - - /** - * @brief Sum values from all processes and distribute the result back to all processes. - * @param send_receive_buffer Buffer storing the data. - * @param count Number of elements in the buffer. - */ - virtual void AllReduceSum(int64_t *send_receive_buffer, size_t count) = 0; - - /** - * @brief Sum values from all processes and distribute the result back to all processes. - * @param send_receive_buffer Buffer storing the data. - * @param count Number of elements in the buffer. - */ - virtual void AllReduceSum(uint64_t *send_receive_buffer, size_t count) = 0; + virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, + Operation op) = 0; /** * @brief Gather variable-length values from all processes. diff --git a/src/collective/device_communicator_adapter.cuh b/src/collective/device_communicator_adapter.cuh index ae3b3f581d72..06637c5b4768 100644 --- a/src/collective/device_communicator_adapter.cuh +++ b/src/collective/device_communicator_adapter.cuh @@ -23,20 +23,18 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator { ~DeviceCommunicatorAdapter() override = default; - void AllReduceSum(float *send_receive_buffer, size_t count) override { - DoAllReduceSum(send_receive_buffer, count); - } - - void AllReduceSum(double *send_receive_buffer, size_t count) override { - DoAllReduceSum(send_receive_buffer, count); - } - - void AllReduceSum(int64_t *send_receive_buffer, size_t count) override { - DoAllReduceSum(send_receive_buffer, count); - } + void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, + Operation op) override { + if (communicator_->GetWorldSize() == 1) { + return; + } - void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override { - DoAllReduceSum(send_receive_buffer, count); + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + auto size = count * GetTypeSize(data_type); + host_buffer_.reserve(size); + dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault)); + communicator_->AllReduce(host_buffer_.data(), count, data_type, op); + dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault)); } void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector *segments, @@ -77,20 +75,6 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator { } private: - template - void DoAllReduceSum(T *send_receive_buffer, size_t count) { - if (communicator_->GetWorldSize() == 1) { - return; - } - - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - auto size = count * sizeof(T); - host_buffer_.reserve(size); - dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault)); - communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum); - dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault)); - } - int const device_ordinal_; Communicator *communicator_; /// Host buffer used to call communicator functions. diff --git a/src/collective/nccl_device_communicator.cuh b/src/collective/nccl_device_communicator.cuh index e14a2e446ed4..4e58fc5bac87 100644 --- a/src/collective/nccl_device_communicator.cuh +++ b/src/collective/nccl_device_communicator.cuh @@ -72,20 +72,18 @@ class NcclDeviceCommunicator : public DeviceCommunicator { } } - void AllReduceSum(float *send_receive_buffer, size_t count) override { - DoAllReduceSum(send_receive_buffer, count); - } - - void AllReduceSum(double *send_receive_buffer, size_t count) override { - DoAllReduceSum(send_receive_buffer, count); - } - - void AllReduceSum(int64_t *send_receive_buffer, size_t count) override { - DoAllReduceSum(send_receive_buffer, count); - } + void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, + Operation op) override { + if (communicator_->GetWorldSize() == 1) { + return; + } - void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override { - DoAllReduceSum(send_receive_buffer, count); + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, + GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_, + cuda_stream_)); + allreduce_bytes_ += count * GetTypeSize(data_type); + allreduce_calls_ += 1; } void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector *segments, @@ -162,17 +160,59 @@ class NcclDeviceCommunicator : public DeviceCommunicator { return id; } - template - void DoAllReduceSum(T *send_receive_buffer, size_t count) { - if (communicator_->GetWorldSize() == 1) { - return; + static ncclDataType_t GetNcclDataType(DataType const &data_type) { + ncclDataType_t result; + switch (data_type) { + case DataType::kInt8: + result = ncclInt8; + break; + case DataType::kUInt8: + result = ncclUint8; + break; + case DataType::kInt32: + result = ncclInt32; + break; + case DataType::kUInt32: + result = ncclUint32; + break; + case DataType::kInt64: + result = ncclInt64; + break; + case DataType::kUInt64: + result = ncclUint64; + break; + case DataType::kFloat: + result = ncclFloat; + break; + case DataType::kDouble: + result = ncclDouble; + break; + default: + LOG(FATAL) << "Unknown data type."; } + return result; + } - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum, - nccl_comm_, cuda_stream_)); - allreduce_bytes_ += count * sizeof(T); - allreduce_calls_ += 1; + static ncclRedOp_t GetNcclRedOp(Operation const &op) { + ncclRedOp_t result; + switch (op) { + case Operation::kMax: + result = ncclMax; + break; + case Operation::kMin: + result = ncclMin; + break; + case Operation::kSum: + result = ncclSum; + break; + case Operation::kBitwiseAND: + case Operation::kBitwiseOR: + case Operation::kBitwiseXOR: + LOG(FATAL) << "Not implemented yet."; + default: + LOG(FATAL) << "Unknown reduce operation."; + } + return result; } int const device_ordinal_; diff --git a/src/common/quantile.cu b/src/common/quantile.cu index cabdc603b97e..5c81ec2ea06f 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -12,8 +12,7 @@ #include #include -#include "../collective/communicator.h" -#include "../collective/device_communicator.cuh" +#include "../collective/communicator-inl.cuh" #include "categorical.h" #include "common.h" #include "device_helpers.cuh" @@ -510,7 +509,6 @@ void SketchContainer::AllReduce() { } timer_.Start(__func__); - auto* communicator = collective::Communicator::GetDevice(device_); // Reduce the overhead on syncing. size_t global_sum_rows = num_rows_; collective::Allreduce(&global_sum_rows, 1); @@ -531,14 +529,15 @@ void SketchContainer::AllReduce() { auto offset = rank * d_columns_ptr.size(); thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(), gathered_ptrs.begin() + offset); - communicator->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.size()); + collective::AllReduce(device_, gathered_ptrs.data().get(), + gathered_ptrs.size()); // Get the data from all workers. std::vector recv_lengths; dh::caching_device_vector recvbuf; - communicator->AllGatherV(this->Current().data().get(), dh::ToSpan(this->Current()).size_bytes(), - &recv_lengths, &recvbuf); - communicator->Synchronize(); + collective::AllGatherV(device_, this->Current().data().get(), + dh::ToSpan(this->Current()).size_bytes(), &recv_lengths, &recvbuf); + collective::Synchronize(device_); // Segment the received data. auto s_recvbuf = dh::ToSpan(recvbuf); diff --git a/src/metric/auc.cu b/src/metric/auc.cu index fdbf0501ac6b..6e3032e4297a 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -11,7 +11,7 @@ #include #include -#include "../collective/device_communicator.cuh" +#include "../collective/communicator-inl.cuh" #include "../common/algorithm.cuh" // SegmentedArgSort #include "../common/optional_weight.h" // OptionalWeights #include "../common/threading_utils.cuh" // UnravelTrapeziodIdx,SegmentedTrapezoidThreads @@ -205,8 +205,7 @@ double ScaleClasses(common::Span results, common::Span local_are if (collective::IsDistributed()) { int32_t device = dh::CurrentDevice(); CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device); - auto* communicator = collective::Communicator::GetDevice(device); - communicator->AllReduceSum(results.data(), results.size()); + collective::AllReduce(device, results.data(), results.size()); } auto reduce_in = dh::MakeTransformIterator( thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { diff --git a/src/tree/fit_stump.cu b/src/tree/fit_stump.cu index 58a1fae82987..33f92014e0ff 100644 --- a/src/tree/fit_stump.cu +++ b/src/tree/fit_stump.cu @@ -11,7 +11,7 @@ #include // std::size_t -#include "../collective/device_communicator.cuh" // DeviceCommunicator +#include "../collective/communicator-inl.cuh" #include "../common/device_helpers.cuh" // dh::MakeTransformIterator #include "fit_stump.h" #include "xgboost/base.h" // GradientPairPrecise, GradientPair, XGBOOST_DEVICE @@ -49,8 +49,8 @@ void FitStump(Context const* ctx, linalg::TensorView gpai thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it, thrust::make_discard_iterator(), dh::tbegin(d_sum.Values())); - collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(ctx->gpu_id); - communicator->AllReduceSum(reinterpret_cast(d_sum.Values().data()), d_sum.Size() * 2); + collective::AllReduce( + ctx->gpu_id, reinterpret_cast(d_sum.Values().data()), d_sum.Size() * 2); thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets, [=] XGBOOST_DEVICE(std::size_t i) mutable { diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index d1c1c829098d..5e5d2b5cb97c 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -12,7 +12,7 @@ #include #include -#include "../collective/device_communicator.cuh" +#include "../collective/communicator-inl.cuh" #include "../common/bitfield.h" #include "../common/categorical.h" #include "../common/cuda_context.cuh" // CUDAContext @@ -546,12 +546,13 @@ struct GPUHistMakerDevice { } // num histograms is the number of contiguous histograms in memory to reduce over - void AllReduceHist(int nidx, collective::DeviceCommunicator* communicator, int num_histograms) { + void AllReduceHist(int nidx, int num_histograms) { monitor.Start("AllReduce"); auto d_node_hist = hist.GetNodeHistogram(nidx).data(); using ReduceT = typename std::remove_pointer::type::ValueT; - communicator->AllReduceSum(reinterpret_cast(d_node_hist), - page->Cuts().TotalBins() * 2 * num_histograms); + collective::AllReduce( + ctx_->gpu_id, reinterpret_cast(d_node_hist), + page->Cuts().TotalBins() * 2 * num_histograms); monitor.Stop("AllReduce"); } @@ -559,8 +560,7 @@ struct GPUHistMakerDevice { /** * \brief Build GPU local histograms for the left and right child of some parent node */ - void BuildHistLeftRight(std::vector const& candidates, - collective::DeviceCommunicator* communicator, const RegTree& tree) { + void BuildHistLeftRight(std::vector const& candidates, const RegTree& tree) { if (candidates.empty()) return; // Some nodes we will manually compute histograms // others we will do by subtraction @@ -591,7 +591,7 @@ struct GPUHistMakerDevice { // Reduce all in one go // This gives much better latency in a distributed setting // when processing a large batch - this->AllReduceHist(hist_nidx.at(0), communicator, hist_nidx.size()); + this->AllReduceHist(hist_nidx.at(0), hist_nidx.size()); for (size_t i = 0; i < subtraction_nidx.size(); i++) { auto build_hist_nidx = hist_nidx.at(i); @@ -601,7 +601,7 @@ struct GPUHistMakerDevice { if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) { // Calculate other histogram manually this->BuildHist(subtraction_trick_nidx); - this->AllReduceHist(subtraction_trick_nidx, communicator, 1); + this->AllReduceHist(subtraction_trick_nidx, 1); } } } @@ -659,7 +659,7 @@ struct GPUHistMakerDevice { parent.RightChild()); } - GPUExpandEntry InitRoot(RegTree* p_tree, collective::DeviceCommunicator* communicator) { + GPUExpandEntry InitRoot(RegTree* p_tree) { constexpr bst_node_t kRootNIdx = 0; dh::XGBCachingDeviceAllocator alloc; auto quantiser = *this->quantiser; @@ -676,7 +676,7 @@ struct GPUHistMakerDevice { hist.AllocateHistograms({kRootNIdx}); this->BuildHist(kRootNIdx); - this->AllReduceHist(kRootNIdx, communicator, 1); + this->AllReduceHist(kRootNIdx, 1); // Remember root stats auto root_sum = quantiser.ToFloatingPoint(root_sum_quantised); @@ -692,7 +692,6 @@ struct GPUHistMakerDevice { void UpdateTree(HostDeviceVector* gpair_all, DMatrix* p_fmat, ObjInfo const* task, RegTree* p_tree, - collective::DeviceCommunicator* communicator, HostDeviceVector* p_out_position) { auto& tree = *p_tree; // Process maximum 32 nodes at a time @@ -703,7 +702,7 @@ struct GPUHistMakerDevice { monitor.Stop("Reset"); monitor.Start("InitRoot"); - driver.Push({ this->InitRoot(p_tree, communicator) }); + driver.Push({this->InitRoot(p_tree)}); monitor.Stop("InitRoot"); // The set of leaves that can be expanded asynchronously @@ -730,7 +729,7 @@ struct GPUHistMakerDevice { monitor.Stop("UpdatePosition"); monitor.Start("BuildHist"); - this->BuildHistLeftRight(filtered_expand_set, communicator, tree); + this->BuildHistLeftRight(filtered_expand_set, tree); monitor.Stop("BuildHist"); monitor.Start("EvaluateSplits"); @@ -851,8 +850,7 @@ class GPUHistMaker : public TreeUpdater { monitor_.Stop("InitData"); gpair->SetDevice(ctx_->gpu_id); - auto* communicator = collective::Communicator::GetDevice(ctx_->gpu_id); - maker->UpdateTree(gpair, p_fmat, task_, p_tree, communicator, p_out_position); + maker->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position); } bool UpdatePredictionCache(const DMatrix* data, diff --git a/tests/cpp/collective/test_nccl_device_communicator.cu b/tests/cpp/collective/test_nccl_device_communicator.cu index 8ce877aef98c..6d3203522dae 100644 --- a/tests/cpp/collective/test_nccl_device_communicator.cu +++ b/tests/cpp/collective/test_nccl_device_communicator.cu @@ -8,6 +8,7 @@ #include // for string #include "../../../src/collective/nccl_device_communicator.cuh" +#include "../../../src/collective/communicator-inl.cuh" namespace xgboost { namespace collective { diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index 3a8e6d046267..935d88ab6352 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -1,7 +1,7 @@ #include #include "test_quantile.h" #include "../helpers.h" -#include "../../../src/collective/device_communicator.cuh" +#include "../../../src/collective/communicator-inl.cuh" #include "../../../src/common/hist_util.cuh" #include "../../../src/common/quantile.cuh" @@ -464,10 +464,9 @@ void TestSameOnAllWorkers(std::int32_t n_gpus) { thrust::copy(thrust::device, local_data.data(), local_data.data() + local_data.size(), all_workers.begin() + local_data.size() * rank); - collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(device); - - communicator->AllReduceSum(all_workers.data().get(), all_workers.size()); - communicator->Synchronize(); + collective::AllReduce(device, all_workers.data().get(), + all_workers.size()); + collective::Synchronize(device); auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float); std::vector h_base_line(base_line.size()); diff --git a/tests/cpp/plugin/test_federated_adapter.cu b/tests/cpp/plugin/test_federated_adapter.cu index a5e901f26f47..3fb793fa7160 100644 --- a/tests/cpp/plugin/test_federated_adapter.cu +++ b/tests/cpp/plugin/test_federated_adapter.cu @@ -36,7 +36,7 @@ TEST_F(FederatedAdapterTest, DeviceAllReduceSum) { int count = 3; thrust::device_vector buffer(count, 0); thrust::sequence(buffer.begin(), buffer.end()); - adapter.AllReduceSum(buffer.data().get(), count); + adapter.AllReduce(buffer.data().get(), count, DataType::kDouble, Operation::kSum); thrust::host_vector host_buffer = buffer; EXPECT_EQ(host_buffer.size(), count); for (auto i = 0; i < count; i++) {