diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 99241249ffcb..69cdd09a3e64 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -104,6 +104,7 @@ OBJECTS= \ $(PKGROOT)/src/collective/allreduce.o \ $(PKGROOT)/src/collective/broadcast.o \ $(PKGROOT)/src/collective/comm.o \ + $(PKGROOT)/src/collective/comm_group.o \ $(PKGROOT)/src/collective/coll.o \ $(PKGROOT)/src/collective/communicator-inl.o \ $(PKGROOT)/src/collective/tracker.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index fc2cd3b9f2a3..b34d8c64908b 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -104,6 +104,7 @@ OBJECTS= \ $(PKGROOT)/src/collective/allreduce.o \ $(PKGROOT)/src/collective/broadcast.o \ $(PKGROOT)/src/collective/comm.o \ + $(PKGROOT)/src/collective/comm_group.o \ $(PKGROOT)/src/collective/coll.o \ $(PKGROOT)/src/collective/communicator-inl.o \ $(PKGROOT)/src/collective/tracker.o \ diff --git a/src/collective/aggregator.h b/src/collective/aggregator.h index 8a5b31c36546..bc652f2e8cde 100644 --- a/src/collective/aggregator.h +++ b/src/collective/aggregator.h @@ -165,7 +165,7 @@ template T GlobalRatio(Context const* ctx, MetaInfo const& info, T dividend, T divisor) { std::array results{dividend, divisor}; auto rc = GlobalSum(ctx, info, linalg::MakeVec(results.data(), results.size())); - collective::SafeColl(rc); + SafeColl(rc); std::tie(dividend, divisor) = std::tuple_cat(results); if (divisor <= 0) { return std::numeric_limits::quiet_NaN(); diff --git a/src/collective/allgather.cc b/src/collective/allgather.cc index 446db73b560a..5d1ec664e656 100644 --- a/src/collective/allgather.cc +++ b/src/collective/allgather.cc @@ -33,6 +33,7 @@ Result RingAllgather(Comm const& comm, common::Span data, std::size bool is_last_segment = send_rank == (world - 1); auto send_nbytes = is_last_segment ? (data.size_bytes() - send_off) : segment_size; auto send_seg = data.subspan(send_off, send_nbytes); + CHECK_NE(send_seg.size(), 0); return next_ch->SendAll(send_seg.data(), send_seg.size_bytes()); } << [&] { auto recv_rank = (rank + world - r - 1 + worker_off) % world; @@ -40,9 +41,10 @@ Result RingAllgather(Comm const& comm, common::Span data, std::size bool is_last_segment = recv_rank == (world - 1); auto recv_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : segment_size; auto recv_seg = data.subspan(recv_off, recv_nbytes); + CHECK_NE(recv_seg.size(), 0); return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes()); } << [&] { - return prev_ch->Block(); + return comm.Block(); }; if (!rc.OK()) { return rc; @@ -106,4 +108,47 @@ namespace detail { return comm.Block(); } } // namespace detail + +[[nodiscard]] std::vector> VectorAllgatherV( + Context const* ctx, CommGroup const& comm, std::vector> const& input) { + auto n_inputs = input.size(); + std::vector sizes(n_inputs); + std::transform(input.cbegin(), input.cend(), sizes.begin(), + [](auto const& vec) { return vec.size(); }); + + std::vector recv_segments(comm.World() + 1, 0); + + HostDeviceVector recv; + auto rc = + AllgatherV(ctx, comm, linalg::MakeVec(sizes.data(), sizes.size()), &recv_segments, &recv); + SafeColl(rc); + + auto global_sizes = common::RestoreType(recv.ConstHostSpan()); + std::vector offset(global_sizes.size() + 1); + offset[0] = 0; + for (std::size_t i = 1; i < offset.size(); i++) { + offset[i] = offset[i - 1] + global_sizes[i - 1]; + } + + std::vector collected; + for (auto const& vec : input) { + collected.insert(collected.end(), vec.cbegin(), vec.cend()); + } + rc = AllgatherV(ctx, comm, linalg::MakeVec(collected.data(), collected.size()), &recv_segments, + &recv); + SafeColl(rc); + auto out = common::RestoreType(recv.ConstHostSpan()); + + std::vector> result; + for (std::size_t i = 1; i < offset.size(); ++i) { + std::vector local(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]); + result.emplace_back(std::move(local)); + } + return result; +} + +[[nodiscard]] std::vector> VectorAllgatherV( + Context const* ctx, std::vector> const& input) { + return VectorAllgatherV(ctx, *GlobalCommGroup(), input); +} } // namespace xgboost::collective diff --git a/src/collective/allgather.h b/src/collective/allgather.h index 8de9f1984f6f..ca44c3916cc3 100644 --- a/src/collective/allgather.h +++ b/src/collective/allgather.h @@ -102,4 +102,115 @@ template return detail::RingAllgatherV(comm, sizes, s_segments, erased_result); } + +template +[[nodiscard]] Result Allgather(Context const* ctx, CommGroup const& comm, + linalg::VectorView data) { + if (!comm.IsDistributed()) { + return Success(); + } + CHECK(data.Contiguous()); + auto erased = common::EraseType(data.Values()); + + auto const& cctx = comm.Ctx(ctx, data.Device()); + auto backend = comm.Backend(data.Device()); + return backend->Allgather(cctx, erased); +} + +/** + * @brief Gather all data from all workers. + * + * @param data The input and output buffer, needs to be pre-allocated by the caller. + */ +template +[[nodiscard]] Result Allgather(Context const* ctx, linalg::VectorView data) { + auto const& cg = *GlobalCommGroup(); + if (data.Size() % cg.World() != 0) { + return Fail("The total number of elements should be multiple of the number of workers."); + } + return Allgather(ctx, cg, data); +} + +template +[[nodiscard]] Result AllgatherV(Context const* ctx, CommGroup const& comm, + linalg::VectorView data, + std::vector* recv_segments, + HostDeviceVector* recv) { + if (!comm.IsDistributed()) { + return Success(); + } + std::vector sizes(comm.World(), 0); + sizes[comm.Rank()] = data.Values().size_bytes(); + auto erased_sizes = common::EraseType(common::Span{sizes.data(), sizes.size()}); + auto rc = comm.Backend(DeviceOrd::CPU()) + ->Allgather(comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes); + if (!rc.OK()) { + return rc; + } + + recv_segments->resize(sizes.size() + 1); + detail::AllgatherVOffset(sizes, common::Span{recv_segments->data(), recv_segments->size()}); + auto total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0LL); + recv->SetDevice(data.Device()); + recv->Resize(total_bytes); + + auto s_segments = common::Span{recv_segments->data(), recv_segments->size()}; + + auto backend = comm.Backend(data.Device()); + auto erased = common::EraseType(data.Values()); + + return backend->AllgatherV( + comm.Ctx(ctx, data.Device()), erased, common::Span{sizes.data(), sizes.size()}, s_segments, + data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(), AllgatherVAlgo::kBcast); +} + +/** + * @brief Allgather with variable length data. + * + * @param data The input data. + * @param recv_segments segment size for each worker. [0, 2, 5] means [0, 2) elements are + * from the first worker, [2, 5) elements are from the second one. + * @param recv The buffer storing the result. + */ +template +[[nodiscard]] Result AllgatherV(Context const* ctx, linalg::VectorView data, + std::vector* recv_segments, + HostDeviceVector* recv) { + return AllgatherV(ctx, *GlobalCommGroup(), data, recv_segments, recv); +} + +[[nodiscard]] std::vector> VectorAllgatherV( + Context const* ctx, CommGroup const& comm, std::vector> const& input); + +/** + * @brief Gathers variable-length data from all processes and distributes it to all processes. + * + * @param inputs All the inputs from the local worker. The number of inputs can vary + * across different workers. Along with which, the size of each vector in + * the input can also vary. + * + * @return The AllgatherV result, containing vectors from all workers. + */ +[[nodiscard]] std::vector> VectorAllgatherV( + Context const* ctx, std::vector> const& input); + +/** + * @brief Gathers variable-length strings from all processes and distributes them to all processes. + * @param input Variable-length list of variable-length strings. + */ +[[nodiscard]] inline Result AllgatherStrings(std::vector const& input, + std::vector* p_result) { + std::vector> inputs(input.size()); + for (std::size_t i = 0; i < input.size(); ++i) { + inputs[i] = {input[i].cbegin(), input[i].cend()}; + } + Context ctx; + auto out = VectorAllgatherV(&ctx, *GlobalCommGroup(), inputs); + auto& result = *p_result; + result.resize(out.size()); + for (std::size_t i = 0; i < out.size(); ++i) { + result[i] = {out[i].cbegin(), out[i].cend()}; + } + return Success(); +} } // namespace xgboost::collective diff --git a/src/collective/allreduce.cc b/src/collective/allreduce.cc index d9cf8b8283cc..55c5c8854d7f 100644 --- a/src/collective/allreduce.cc +++ b/src/collective/allreduce.cc @@ -68,39 +68,35 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span data, auto s_buf = common::Span{buffer.data(), buffer.size()}; for (std::int32_t r = 0; r < world - 1; ++r) { - // send to ring next - auto send_rank = (rank + world - r) % world; - auto send_off = send_rank * n_bytes_in_seg; + common::Span seg, recv_seg; + auto rc = Success() << [&] { + // send to ring next + auto send_rank = (rank + world - r) % world; + auto send_off = send_rank * n_bytes_in_seg; - bool is_last_segment = send_rank == (world - 1); + bool is_last_segment = send_rank == (world - 1); - auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg; - auto send_seg = data.subspan(send_off, seg_nbytes); + auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg; + CHECK_EQ(seg_nbytes % sizeof(T), 0); - auto rc = next_ch->SendAll(send_seg); - if (!rc.OK()) { - return rc; - } - - // receive from ring prev - auto recv_rank = (rank + world - r - 1) % world; - auto recv_off = recv_rank * n_bytes_in_seg; + auto send_seg = data.subspan(send_off, seg_nbytes); + return next_ch->SendAll(send_seg); + } << [&] { + // receive from ring prev + auto recv_rank = (rank + world - r - 1) % world; + auto recv_off = recv_rank * n_bytes_in_seg; - is_last_segment = recv_rank == (world - 1); + bool is_last_segment = recv_rank == (world - 1); - seg_nbytes = is_last_segment ? data.size_bytes() - recv_off : n_bytes_in_seg; - CHECK_EQ(seg_nbytes % sizeof(T), 0); - auto recv_seg = data.subspan(recv_off, seg_nbytes); - auto seg = s_buf.subspan(0, recv_seg.size()); + auto seg_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : n_bytes_in_seg; + CHECK_EQ(seg_nbytes % sizeof(T), 0); - rc = std::move(rc) << [&] { + recv_seg = data.subspan(recv_off, seg_nbytes); + seg = s_buf.subspan(0, recv_seg.size()); return prev_ch->RecvAll(seg); } << [&] { return comm.Block(); }; - if (!rc.OK()) { - return rc; - } // accumulate to recv_seg CHECK_EQ(seg.size(), recv_seg.size()); diff --git a/src/collective/allreduce.h b/src/collective/allreduce.h index 0c94d11cc35d..3e88cca112cb 100644 --- a/src/collective/allreduce.h +++ b/src/collective/allreduce.h @@ -1,15 +1,18 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for int8_t #include // for function #include // for is_invocable_v, enable_if_t +#include // for vector #include "../common/type.h" // for EraseType, RestoreType -#include "../data/array_interface.h" // for ArrayInterfaceHandler +#include "../data/array_interface.h" // for ToDType, ArrayInterfaceHandler #include "comm.h" // for Comm, RestoreType +#include "comm_group.h" // for GlobalCommGroup #include "xgboost/collective/result.h" // for Result +#include "xgboost/context.h" // for Context #include "xgboost/span.h" // for Span namespace xgboost::collective { @@ -27,8 +30,7 @@ std::enable_if_t, common::Span> auto erased = common::EraseType(data); auto type = ToDType::kType; - auto erased_fn = [type, redop](common::Span lhs, - common::Span out) { + auto erased_fn = [redop](common::Span lhs, common::Span out) { CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction."; auto lhs_t = common::RestoreType(lhs); auto rhs_t = common::RestoreType(out); @@ -37,4 +39,40 @@ std::enable_if_t, common::Span> return cpu_impl::RingAllreduce(comm, erased, erased_fn, type); } + +template +[[nodiscard]] Result Allreduce(Context const* ctx, CommGroup const& comm, + linalg::TensorView data, Op op) { + if (!comm.IsDistributed()) { + return Success(); + } + CHECK(data.Contiguous()); + auto erased = common::EraseType(data.Values()); + auto type = ToDType::kType; + + auto backend = comm.Backend(data.Device()); + return backend->Allreduce(comm.Ctx(ctx, data.Device()), erased, type, op); +} + +template +[[nodiscard]] Result Allreduce(Context const* ctx, linalg::TensorView data, Op op) { + return Allreduce(ctx, *GlobalCommGroup(), data, op); +} + +/** + * @brief Specialization for std::vector. + */ +template +[[nodiscard]] Result Allreduce(Context const* ctx, std::vector* data, Op op) { + return Allreduce(ctx, linalg::MakeVec(data->data(), data->size()), op); +} + +/** + * @brief Specialization for scalar value. + */ +template +[[nodiscard]] std::enable_if_t && std::is_trivial_v, Result> +Allreduce(Context const* ctx, T* data, Op op) { + return Allreduce(ctx, linalg::MakeVec(data, 1), op); +} } // namespace xgboost::collective diff --git a/src/collective/broadcast.h b/src/collective/broadcast.h index 28db83815cd4..61cab8cdd8f6 100644 --- a/src/collective/broadcast.h +++ b/src/collective/broadcast.h @@ -1,11 +1,15 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for int32_t, int8_t -#include "comm.h" // for Comm -#include "xgboost/collective/result.h" // for +#include "../common/type.h" +#include "comm.h" // for Comm, EraseType +#include "comm_group.h" // for CommGroup +#include "xgboost/collective/result.h" // for Result +#include "xgboost/context.h" // for Context +#include "xgboost/linalg.h" // for VectorView #include "xgboost/span.h" // for Span namespace xgboost::collective { @@ -23,4 +27,21 @@ template common::Span{reinterpret_cast(data.data()), n_total_bytes}; return cpu_impl::Broadcast(comm, erased, root); } + +template +[[nodiscard]] Result Broadcast(Context const* ctx, CommGroup const& comm, + linalg::VectorView data, std::int32_t root) { + if (!comm.IsDistributed()) { + return Success(); + } + CHECK(data.Contiguous()); + auto erased = common::EraseType(data.Values()); + auto backend = comm.Backend(data.Device()); + return backend->Broadcast(comm.Ctx(ctx, data.Device()), erased, root); +} + +template +[[nodiscard]] Result Broadcast(Context const* ctx, linalg::VectorView data, std::int32_t root) { + return Broadcast(ctx, *GlobalCommGroup(), data, root); +} } // namespace xgboost::collective diff --git a/src/collective/comm.cu b/src/collective/comm.cu index 8788a2436a46..6566f28fad91 100644 --- a/src/collective/comm.cu +++ b/src/collective/comm.cu @@ -27,7 +27,7 @@ Result GetUniqueId(Comm const& comm, std::shared_ptr stub, std::shared ncclUniqueId id; if (comm.Rank() == kRootRank) { auto rc = stub->GetUniqueId(&id); - CHECK(rc.OK()) << rc.Report(); + SafeColl(rc); } auto rc = coll->Broadcast( comm, common::Span{reinterpret_cast(&id), sizeof(ncclUniqueId)}, kRootRank); @@ -81,8 +81,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p GetCudaUUID(s_this_uuid, ctx->Device()); auto rc = pimpl->Allgather(root, common::EraseType(s_uuid)); - - CHECK(rc.OK()) << rc.Report(); + SafeColl(rc); std::vector> converted(root.World()); std::size_t j = 0; @@ -103,7 +102,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p [&] { return this->stub_->CommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank()); }; - CHECK(rc.OK()) << rc.Report(); + SafeColl(rc); for (std::int32_t r = 0; r < root.World(); ++r) { this->channels_.emplace_back( @@ -114,7 +113,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p NCCLComm::~NCCLComm() { if (nccl_comm_) { auto rc = stub_->CommDestroy(nccl_comm_); - CHECK(rc.OK()) << rc.Report(); + SafeColl(rc); } } } // namespace xgboost::collective diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 9223302aabec..f4fce42f84f8 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -12,7 +12,6 @@ #include // make_transform_output_iterator #include #include -#include #include #include #include diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index fe3771924043..cf1043ddb399 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -8,6 +8,7 @@ #define COMMON_HIST_UTIL_CUH_ #include +#include // for sort #include // for size_t diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 718474a3e87a..f9a3819ad6e0 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -1,13 +1,13 @@ /** - * Copyright 2019-2023 by XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #include #include +#include // for sort #include #include #include -#include #include // for size_t #include #include diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index b1a2e0ded333..0bbe5e223d2f 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -71,7 +71,7 @@ target_include_directories(testxgboost ${xgboost_SOURCE_DIR}/rabit/include) target_link_libraries(testxgboost PRIVATE - ${GTEST_LIBRARIES}) + GTest::gtest GTest::gmock) set_output_directory(testxgboost ${xgboost_BINARY_DIR}) diff --git a/tests/cpp/collective/test_allgather.cc b/tests/cpp/collective/test_allgather.cc index b6158693bf1c..b25db54cbed1 100644 --- a/tests/cpp/collective/test_allgather.cc +++ b/tests/cpp/collective/test_allgather.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include // for ASSERT_EQ #include // for Span, oper... @@ -35,7 +35,7 @@ class Worker : public WorkerForTest { data[comm_.Rank()] = comm_.Rank(); auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()}); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); for (std::int32_t r = 0; r < comm_.World(); ++r) { ASSERT_EQ(data[r], r); @@ -52,7 +52,7 @@ class Worker : public WorkerForTest { std::iota(seg.begin(), seg.end(), comm_.Rank()); auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()}); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); for (std::int32_t r = 0; r < comm_.World(); ++r) { auto seg = s_data.subspan(r * n, n); @@ -81,7 +81,7 @@ class Worker : public WorkerForTest { std::vector data(comm_.Rank() + 1, comm_.Rank()); std::vector result; auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2); CheckV(result); } @@ -91,7 +91,7 @@ class Worker : public WorkerForTest { std::int32_t n{comm_.Rank()}; std::vector result; auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); for (std::int32_t i = 0; i < comm_.World(); ++i) { ASSERT_EQ(result[i], i); } @@ -105,7 +105,7 @@ class Worker : public WorkerForTest { std::vector sizes(comm_.World(), 0); sizes[comm_.Rank()] = s_data.size_bytes(); auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); std::shared_ptr pcoll{new Coll{}}; std::vector recv_segments(comm_.World() + 1, 0); diff --git a/tests/cpp/collective/test_allgather.cu b/tests/cpp/collective/test_allgather.cu index 98ece7d17a7a..f145681da46a 100644 --- a/tests/cpp/collective/test_allgather.cu +++ b/tests/cpp/collective/test_allgather.cu @@ -34,7 +34,7 @@ class Worker : public NCCLWorkerForTest { std::vector sizes(comm_.World(), -1); sizes[comm_.Rank()] = s_data.size_bytes(); auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); // create result dh::device_vector result(comm_.World(), -1); auto s_result = common::EraseType(dh::ToSpan(result)); @@ -42,7 +42,7 @@ class Worker : public NCCLWorkerForTest { std::vector recv_seg(nccl_comm_->World() + 1, 0); rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()}, common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); for (std::int32_t i = 0; i < comm_.World(); ++i) { ASSERT_EQ(result[i], i); @@ -58,7 +58,7 @@ class Worker : public NCCLWorkerForTest { std::vector sizes(nccl_comm_->World(), 0); sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes(); auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0); // create result dh::device_vector result(n_bytes / sizeof(std::int32_t), -1); @@ -67,7 +67,7 @@ class Worker : public NCCLWorkerForTest { std::vector recv_seg(nccl_comm_->World() + 1, 0); rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()}, common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); // check segment size if (algo != AllgatherVAlgo::kBcast) { auto size = recv_seg[nccl_comm_->Rank() + 1] - recv_seg[nccl_comm_->Rank()]; diff --git a/tests/cpp/collective/test_allreduce.cc b/tests/cpp/collective/test_allreduce.cc index 457594cd97aa..13a6ca656359 100644 --- a/tests/cpp/collective/test_allreduce.cc +++ b/tests/cpp/collective/test_allreduce.cc @@ -59,7 +59,7 @@ class AllreduceWorker : public WorkerForTest { auto pcoll = std::shared_ptr{new Coll{}}; auto rc = pcoll->Allreduce(comm_, common::EraseType(common::Span{data.data(), data.size()}), ArrayInterfaceHandler::kU4, Op::kBitwiseOR); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); for (auto v : data) { ASSERT_EQ(v, ~std::uint32_t{0}); } diff --git a/tests/cpp/collective/test_allreduce.cu b/tests/cpp/collective/test_allreduce.cu index f7e11dec2d8f..8bda1e0de10e 100644 --- a/tests/cpp/collective/test_allreduce.cu +++ b/tests/cpp/collective/test_allreduce.cu @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #if defined(XGBOOST_USE_NCCL) #include @@ -24,7 +24,7 @@ class Worker : public NCCLWorkerForTest { data[comm_.Rank()] = ~std::uint32_t{0}; auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)), ArrayInterfaceHandler::kU4, Op::kBitwiseOR); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); thrust::host_vector h_data(data.size()); thrust::copy(data.cbegin(), data.cend(), h_data.begin()); for (auto v : h_data) { @@ -36,7 +36,7 @@ class Worker : public NCCLWorkerForTest { dh::device_vector data(314, 1.5); auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)), ArrayInterfaceHandler::kF8, Op::kSum); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); for (std::size_t i = 0; i < data.size(); ++i) { auto v = data[i]; ASSERT_EQ(v, 1.5 * static_cast(comm_.World())) << i; diff --git a/tests/cpp/collective/test_broadcast.cc b/tests/cpp/collective/test_broadcast.cc index 4d0d87e93ae0..1b1d73428be1 100644 --- a/tests/cpp/collective/test_broadcast.cc +++ b/tests/cpp/collective/test_broadcast.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include #include @@ -10,7 +10,6 @@ #include // for vector #include "../../../src/collective/broadcast.h" // for Broadcast -#include "../../../src/collective/tracker.h" // for GetHostAddress #include "test_worker.h" // for WorkerForTest, TestDistributed namespace xgboost::collective { @@ -24,14 +23,14 @@ class Worker : public WorkerForTest { // basic test std::vector data(1, comm_.Rank()); auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); ASSERT_EQ(data[0], r); } for (std::int32_t r = 0; r < comm_.World(); ++r) { std::vector data(1 << 16, comm_.Rank()); auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); ASSERT_EQ(data[0], r); } } @@ -41,11 +40,11 @@ class BroadcastTest : public SocketTest {}; } // namespace TEST_F(BroadcastTest, Basic) { - std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); + std::int32_t n_workers = std::min(2u, std::thread::hardware_concurrency()); TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { Worker worker{host, port, timeout, n_workers, r}; worker.Run(); }); -} // namespace +} } // namespace xgboost::collective diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu index 1d10a48ad1a5..4178e55d8fd8 100644 --- a/tests/cpp/common/test_device_helpers.cu +++ b/tests/cpp/common/test_device_helpers.cu @@ -1,14 +1,16 @@ -/*! - * Copyright 2017-2021 XGBoost contributors +/** + * Copyright 2017-2024, XGBoost contributors */ +#include +#include // for is_sorted +#include + #include #include -#include #include -#include + #include "../../../src/common/device_helpers.cuh" #include "../../../src/common/quantile.h" -#include "../helpers.h" #include "gtest/gtest.h" TEST(SumReduce, Test) { diff --git a/tests/cpp/common/test_io.cc b/tests/cpp/common/test_io.cc index 4c4d4efe035b..face21851e4f 100644 --- a/tests/cpp/common/test_io.cc +++ b/tests/cpp/common/test_io.cc @@ -1,10 +1,11 @@ /** - * Copyright 2019-2023, XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #include #include // for size_t #include // for ofstream +#include // for iota #include "../../../src/common/io.h" #include "../filesystem.h" // dmlc::TemporaryDirectory diff --git a/tests/cpp/common/test_json.cc b/tests/cpp/common/test_json.cc index 3ee041a339ed..e144bdc45b9f 100644 --- a/tests/cpp/common/test_json.cc +++ b/tests/cpp/common/test_json.cc @@ -4,10 +4,10 @@ #include #include -#include // for back_inserter +#include // for numeric_limits #include +#include // for iota -#include "../../../src/common/charconv.h" #include "../../../src/common/io.h" #include "../../../src/common/json_utils.h" #include "../../../src/common/threading_utils.h" // for ParallelFor