Skip to content

Commit

Permalink
[coll] Add global functions. (#10203)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Apr 18, 2024
1 parent 551fa6e commit 3f64b4f
Show file tree
Hide file tree
Showing 21 changed files with 283 additions and 69 deletions.
1 change: 1 addition & 0 deletions R-package/src/Makevars.in
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/comm_group.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \
Expand Down
1 change: 1 addition & 0 deletions R-package/src/Makevars.win
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/comm_group.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \
Expand Down
2 changes: 1 addition & 1 deletion src/collective/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ template <typename T>
T GlobalRatio(Context const* ctx, MetaInfo const& info, T dividend, T divisor) {
std::array<T, 2> results{dividend, divisor};
auto rc = GlobalSum(ctx, info, linalg::MakeVec(results.data(), results.size()));
collective::SafeColl(rc);
SafeColl(rc);
std::tie(dividend, divisor) = std::tuple_cat(results);
if (divisor <= 0) {
return std::numeric_limits<T>::quiet_NaN();
Expand Down
47 changes: 46 additions & 1 deletion src/collective/allgather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,18 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
bool is_last_segment = send_rank == (world - 1);
auto send_nbytes = is_last_segment ? (data.size_bytes() - send_off) : segment_size;
auto send_seg = data.subspan(send_off, send_nbytes);
CHECK_NE(send_seg.size(), 0);
return next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
} << [&] {
auto recv_rank = (rank + world - r - 1 + worker_off) % world;
auto recv_off = recv_rank * segment_size;
bool is_last_segment = recv_rank == (world - 1);
auto recv_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : segment_size;
auto recv_seg = data.subspan(recv_off, recv_nbytes);
CHECK_NE(recv_seg.size(), 0);
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
} << [&] {
return prev_ch->Block();
return comm.Block();
};
if (!rc.OK()) {
return rc;
Expand Down Expand Up @@ -106,4 +108,47 @@ namespace detail {
return comm.Block();
}
} // namespace detail

[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, CommGroup const& comm, std::vector<std::vector<char>> const& input) {
auto n_inputs = input.size();
std::vector<std::int64_t> sizes(n_inputs);
std::transform(input.cbegin(), input.cend(), sizes.begin(),
[](auto const& vec) { return vec.size(); });

std::vector<std::int64_t> recv_segments(comm.World() + 1, 0);

HostDeviceVector<std::int8_t> recv;
auto rc =
AllgatherV(ctx, comm, linalg::MakeVec(sizes.data(), sizes.size()), &recv_segments, &recv);
SafeColl(rc);

auto global_sizes = common::RestoreType<std::int64_t const>(recv.ConstHostSpan());
std::vector<std::int64_t> offset(global_sizes.size() + 1);
offset[0] = 0;
for (std::size_t i = 1; i < offset.size(); i++) {
offset[i] = offset[i - 1] + global_sizes[i - 1];
}

std::vector<char> collected;
for (auto const& vec : input) {
collected.insert(collected.end(), vec.cbegin(), vec.cend());
}
rc = AllgatherV(ctx, comm, linalg::MakeVec(collected.data(), collected.size()), &recv_segments,
&recv);
SafeColl(rc);
auto out = common::RestoreType<char const>(recv.ConstHostSpan());

std::vector<std::vector<char>> result;
for (std::size_t i = 1; i < offset.size(); ++i) {
std::vector<char> local(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]);
result.emplace_back(std::move(local));
}
return result;
}

[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, std::vector<std::vector<char>> const& input) {
return VectorAllgatherV(ctx, *GlobalCommGroup(), input);
}
} // namespace xgboost::collective
111 changes: 111 additions & 0 deletions src/collective/allgather.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,115 @@ template <typename T>

return detail::RingAllgatherV(comm, sizes, s_segments, erased_result);
}

template <typename T>
[[nodiscard]] Result Allgather(Context const* ctx, CommGroup const& comm,
linalg::VectorView<T> data) {
if (!comm.IsDistributed()) {
return Success();
}
CHECK(data.Contiguous());
auto erased = common::EraseType(data.Values());

auto const& cctx = comm.Ctx(ctx, data.Device());
auto backend = comm.Backend(data.Device());
return backend->Allgather(cctx, erased);
}

/**
* @brief Gather all data from all workers.
*
* @param data The input and output buffer, needs to be pre-allocated by the caller.
*/
template <typename T>
[[nodiscard]] Result Allgather(Context const* ctx, linalg::VectorView<T> data) {
auto const& cg = *GlobalCommGroup();
if (data.Size() % cg.World() != 0) {
return Fail("The total number of elements should be multiple of the number of workers.");
}
return Allgather(ctx, cg, data);
}

template <typename T>
[[nodiscard]] Result AllgatherV(Context const* ctx, CommGroup const& comm,
linalg::VectorView<T> data,
std::vector<std::int64_t>* recv_segments,
HostDeviceVector<std::int8_t>* recv) {
if (!comm.IsDistributed()) {
return Success();
}
std::vector<std::int64_t> sizes(comm.World(), 0);
sizes[comm.Rank()] = data.Values().size_bytes();
auto erased_sizes = common::EraseType(common::Span{sizes.data(), sizes.size()});
auto rc = comm.Backend(DeviceOrd::CPU())
->Allgather(comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes);
if (!rc.OK()) {
return rc;
}

recv_segments->resize(sizes.size() + 1);
detail::AllgatherVOffset(sizes, common::Span{recv_segments->data(), recv_segments->size()});
auto total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0LL);
recv->SetDevice(data.Device());
recv->Resize(total_bytes);

auto s_segments = common::Span{recv_segments->data(), recv_segments->size()};

auto backend = comm.Backend(data.Device());
auto erased = common::EraseType(data.Values());

return backend->AllgatherV(
comm.Ctx(ctx, data.Device()), erased, common::Span{sizes.data(), sizes.size()}, s_segments,
data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(), AllgatherVAlgo::kBcast);
}

/**
* @brief Allgather with variable length data.
*
* @param data The input data.
* @param recv_segments segment size for each worker. [0, 2, 5] means [0, 2) elements are
* from the first worker, [2, 5) elements are from the second one.
* @param recv The buffer storing the result.
*/
template <typename T>
[[nodiscard]] Result AllgatherV(Context const* ctx, linalg::VectorView<T> data,
std::vector<std::int64_t>* recv_segments,
HostDeviceVector<std::int8_t>* recv) {
return AllgatherV(ctx, *GlobalCommGroup(), data, recv_segments, recv);
}

[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, CommGroup const& comm, std::vector<std::vector<char>> const& input);

/**
* @brief Gathers variable-length data from all processes and distributes it to all processes.
*
* @param inputs All the inputs from the local worker. The number of inputs can vary
* across different workers. Along with which, the size of each vector in
* the input can also vary.
*
* @return The AllgatherV result, containing vectors from all workers.
*/
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, std::vector<std::vector<char>> const& input);

/**
* @brief Gathers variable-length strings from all processes and distributes them to all processes.
* @param input Variable-length list of variable-length strings.
*/
[[nodiscard]] inline Result AllgatherStrings(std::vector<std::string> const& input,
std::vector<std::string>* p_result) {
std::vector<std::vector<char>> inputs(input.size());
for (std::size_t i = 0; i < input.size(); ++i) {
inputs[i] = {input[i].cbegin(), input[i].cend()};
}
Context ctx;
auto out = VectorAllgatherV(&ctx, *GlobalCommGroup(), inputs);
auto& result = *p_result;
result.resize(out.size());
for (std::size_t i = 0; i < out.size(); ++i) {
result[i] = {out[i].cbegin(), out[i].cend()};
}
return Success();
}
} // namespace xgboost::collective
42 changes: 19 additions & 23 deletions src/collective/allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,39 +68,35 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
auto s_buf = common::Span{buffer.data(), buffer.size()};

for (std::int32_t r = 0; r < world - 1; ++r) {
// send to ring next
auto send_rank = (rank + world - r) % world;
auto send_off = send_rank * n_bytes_in_seg;
common::Span<std::int8_t> seg, recv_seg;
auto rc = Success() << [&] {
// send to ring next
auto send_rank = (rank + world - r) % world;
auto send_off = send_rank * n_bytes_in_seg;

bool is_last_segment = send_rank == (world - 1);
bool is_last_segment = send_rank == (world - 1);

auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg;
auto send_seg = data.subspan(send_off, seg_nbytes);
auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg;
CHECK_EQ(seg_nbytes % sizeof(T), 0);

auto rc = next_ch->SendAll(send_seg);
if (!rc.OK()) {
return rc;
}

// receive from ring prev
auto recv_rank = (rank + world - r - 1) % world;
auto recv_off = recv_rank * n_bytes_in_seg;
auto send_seg = data.subspan(send_off, seg_nbytes);
return next_ch->SendAll(send_seg);
} << [&] {
// receive from ring prev
auto recv_rank = (rank + world - r - 1) % world;
auto recv_off = recv_rank * n_bytes_in_seg;

is_last_segment = recv_rank == (world - 1);
bool is_last_segment = recv_rank == (world - 1);

seg_nbytes = is_last_segment ? data.size_bytes() - recv_off : n_bytes_in_seg;
CHECK_EQ(seg_nbytes % sizeof(T), 0);
auto recv_seg = data.subspan(recv_off, seg_nbytes);
auto seg = s_buf.subspan(0, recv_seg.size());
auto seg_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : n_bytes_in_seg;
CHECK_EQ(seg_nbytes % sizeof(T), 0);

rc = std::move(rc) << [&] {
recv_seg = data.subspan(recv_off, seg_nbytes);
seg = s_buf.subspan(0, recv_seg.size());
return prev_ch->RecvAll(seg);
} << [&] {
return comm.Block();
};
if (!rc.OK()) {
return rc;
}

// accumulate to recv_seg
CHECK_EQ(seg.size(), recv_seg.size());
Expand Down
46 changes: 42 additions & 4 deletions src/collective/allreduce.h
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int8_t
#include <functional> // for function
#include <type_traits> // for is_invocable_v, enable_if_t
#include <vector> // for vector

#include "../common/type.h" // for EraseType, RestoreType
#include "../data/array_interface.h" // for ArrayInterfaceHandler
#include "../data/array_interface.h" // for ToDType, ArrayInterfaceHandler
#include "comm.h" // for Comm, RestoreType
#include "comm_group.h" // for GlobalCommGroup
#include "xgboost/collective/result.h" // for Result
#include "xgboost/context.h" // for Context
#include "xgboost/span.h" // for Span

namespace xgboost::collective {
Expand All @@ -27,8 +30,7 @@ std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>
auto erased = common::EraseType(data);
auto type = ToDType<T>::kType;

auto erased_fn = [type, redop](common::Span<std::int8_t const> lhs,
common::Span<std::int8_t> out) {
auto erased_fn = [redop](common::Span<std::int8_t const> lhs, common::Span<std::int8_t> out) {
CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
auto lhs_t = common::RestoreType<T const>(lhs);
auto rhs_t = common::RestoreType<T>(out);
Expand All @@ -37,4 +39,40 @@ std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>

return cpu_impl::RingAllreduce(comm, erased, erased_fn, type);
}

template <typename T, std::int32_t kDim>
[[nodiscard]] Result Allreduce(Context const* ctx, CommGroup const& comm,
linalg::TensorView<T, kDim> data, Op op) {
if (!comm.IsDistributed()) {
return Success();
}
CHECK(data.Contiguous());
auto erased = common::EraseType(data.Values());
auto type = ToDType<T>::kType;

auto backend = comm.Backend(data.Device());
return backend->Allreduce(comm.Ctx(ctx, data.Device()), erased, type, op);
}

template <typename T, std::int32_t kDim>
[[nodiscard]] Result Allreduce(Context const* ctx, linalg::TensorView<T, kDim> data, Op op) {
return Allreduce(ctx, *GlobalCommGroup(), data, op);
}

/**
* @brief Specialization for std::vector.
*/
template <typename T, typename Alloc>
[[nodiscard]] Result Allreduce(Context const* ctx, std::vector<T, Alloc>* data, Op op) {
return Allreduce(ctx, linalg::MakeVec(data->data(), data->size()), op);
}

/**
* @brief Specialization for scalar value.
*/
template <typename T>
[[nodiscard]] std::enable_if_t<std::is_standard_layout_v<T> && std::is_trivial_v<T>, Result>
Allreduce(Context const* ctx, T* data, Op op) {
return Allreduce(ctx, linalg::MakeVec(data, 1), op);
}
} // namespace xgboost::collective
27 changes: 24 additions & 3 deletions src/collective/broadcast.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int32_t, int8_t

#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for
#include "../common/type.h"
#include "comm.h" // for Comm, EraseType
#include "comm_group.h" // for CommGroup
#include "xgboost/collective/result.h" // for Result
#include "xgboost/context.h" // for Context
#include "xgboost/linalg.h" // for VectorView
#include "xgboost/span.h" // for Span

namespace xgboost::collective {
Expand All @@ -23,4 +27,21 @@ template <typename T>
common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes};
return cpu_impl::Broadcast(comm, erased, root);
}

template <typename T>
[[nodiscard]] Result Broadcast(Context const* ctx, CommGroup const& comm,
linalg::VectorView<T> data, std::int32_t root) {
if (!comm.IsDistributed()) {
return Success();
}
CHECK(data.Contiguous());
auto erased = common::EraseType(data.Values());
auto backend = comm.Backend(data.Device());
return backend->Broadcast(comm.Ctx(ctx, data.Device()), erased, root);
}

template <typename T>
[[nodiscard]] Result Broadcast(Context const* ctx, linalg::VectorView<T> data, std::int32_t root) {
return Broadcast(ctx, *GlobalCommGroup(), data, root);
}
} // namespace xgboost::collective
Loading

0 comments on commit 3f64b4f

Please sign in to comment.