Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[coll] Add nccl. #9726

Merged
merged 8 commits into from
Oct 28, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 28 additions & 17 deletions src/collective/allgather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,23 @@
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int32_t, int64_t
#include <memory> // for shared_ptr
#include <numeric> // for partial_sum
#include <vector> // for vector

#include "broadcast.h"
#include "comm.h" // for Comm, Channel
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span

namespace xgboost::collective::cpu_impl {
namespace xgboost::collective {
namespace cpu_impl {
Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size_t segment_size,
std::int32_t worker_off, std::shared_ptr<Channel> prev_ch,
std::shared_ptr<Channel> next_ch) {
auto world = comm.World();
auto rank = comm.Rank();
CHECK_LT(worker_off, world);
if (world == 1) {
return Success();
}

for (std::int32_t r = 0; r < world; ++r) {
auto send_rank = (rank + world - r + worker_off) % world;
Expand All @@ -43,11 +46,29 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
return Success();
}

/**
 * @brief Allgather-v realized as a sequence of broadcasts, one rooted at each rank.
 *
 * Segment boundaries inside @p recv are derived by accumulating the per-rank byte
 * counts in @p sizes; rank r broadcasts the r-th segment to everyone.
 */
Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
                           common::Span<std::int8_t> recv) {
  std::size_t beg = 0;  // byte offset where the current root's segment starts
  for (std::int32_t root = 0; root < comm.World(); ++root) {
    auto const n_bytes = sizes[root];
    // Everyone participates in the broadcast of this root's segment.
    if (auto rc = Broadcast(comm, recv.subspan(beg, n_bytes), root); !rc.OK()) {
      return rc;
    }
    beg += n_bytes;
  }
  return Success();
}
} // namespace cpu_impl

namespace detail {
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int8_t const> data,
common::Span<std::int64_t> offset,
common::Span<std::int64_t const> offset,
common::Span<std::int8_t> erased_result) {
auto world = comm.World();
if (world == 1) {
return Success();
}
auto rank = comm.Rank();

auto prev = BootstrapPrev(rank, comm.World());
Expand All @@ -56,17 +77,6 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
auto prev_ch = comm.Chan(prev);
auto next_ch = comm.Chan(next);

// get worker offset
CHECK_EQ(world + 1, offset.size());
std::fill_n(offset.data(), offset.size(), 0);
std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
CHECK_EQ(*offset.cbegin(), 0);

// copy data
auto current = erased_result.subspan(offset[rank], data.size_bytes());
auto erased_data = EraseType(data);
std::copy_n(erased_data.data(), erased_data.size(), current.data());

for (std::int32_t r = 0; r < world; ++r) {
auto send_rank = (rank + world - r) % world;
auto send_off = offset[send_rank];
Expand All @@ -87,4 +97,5 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
}
return comm.Block();
}
} // namespace xgboost::collective::cpu_impl
} // namespace detail
} // namespace xgboost::collective
45 changes: 35 additions & 10 deletions src/collective/allgather.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,46 @@
#include <type_traits> // for remove_cv_t
#include <vector> // for vector

#include "../common/type.h" // for EraseType
#include "../common/type.h" // for EraseType
#include "comm.h" // for Comm, Channel
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
#include "xgboost/linalg.h"
#include "xgboost/span.h" // for Span

namespace xgboost::collective {
namespace cpu_impl {
/**
* @param worker_off Segment offset. For example, if the rank 2 worker specifis worker_off
* = 1, then it owns the third segment.
* @param worker_off Segment offset. For example, if the rank 2 worker specifies
* worker_off = 1, then it owns the third segment.
*/
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
std::size_t segment_size, std::int32_t worker_off,
std::shared_ptr<Channel> prev_ch,
std::shared_ptr<Channel> next_ch);

/**
* @brief Implement allgather-v using broadcast.
*
* https://arxiv.org/abs/1812.05964
*/
Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int8_t> recv);
} // namespace cpu_impl

namespace detail {
/**
 * @brief Turn per-worker byte counts into exclusive prefix-sum offsets.
 *
 * @param sizes  Number of bytes contributed by each worker (world entries).
 * @param offset Output span with world + 1 entries; on return offset[0] == 0,
 *               offset[i + 1] == offset[i] + sizes[i], and offset.back() is the
 *               total number of bytes.
 */
inline void AllgatherVOffset(common::Span<std::int64_t const> sizes,
                             common::Span<std::int64_t> offset) {
  // Guard restored from the pre-refactor implementation: partial_sum writes
  // sizes.size() elements starting at offset.begin() + 1, so an undersized
  // span would be written past its end.
  CHECK_EQ(sizes.size() + 1, offset.size());
  // get worker offset
  std::fill_n(offset.data(), offset.size(), 0);
  std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
  CHECK_EQ(*offset.cbegin(), 0);
}

// An implementation that's used by both cpu and gpu
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int8_t const> data,
common::Span<std::int64_t> offset,
common::Span<std::int64_t const> offset,
common::Span<std::int8_t> erased_result);
} // namespace cpu_impl
} // namespace detail

template <typename T>
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
Expand Down Expand Up @@ -68,9 +87,15 @@ template <typename T>
auto h_result = common::Span{result.data(), result.size()};
auto erased_result = common::EraseType(h_result);
auto erased_data = common::EraseType(data);
std::vector<std::int64_t> offset(world + 1);
std::vector<std::int64_t> recv_segments(world + 1);
auto s_segments = common::Span{recv_segments.data(), recv_segments.size()};

// get worker offset
detail::AllgatherVOffset(sizes, s_segments);
// copy data
auto current = erased_result.subspan(recv_segments[rank], data.size_bytes());
std::copy_n(erased_data.data(), erased_data.size(), current.data());

return cpu_impl::RingAllgatherV(comm, sizes, erased_data,
common::Span{offset.data(), offset.size()}, erased_result);
return detail::RingAllgatherV(comm, sizes, s_segments, erased_result);
}
} // namespace xgboost::collective
52 changes: 36 additions & 16 deletions src/collective/coll.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@
#include <cstdint> // for int8_t, int64_t
#include <functional> // for bit_and, bit_or, bit_xor, plus

#include "allgather.h" // for RingAllgatherV, RingAllgather
#include "allreduce.h" // for Allreduce
#include "broadcast.h" // for Broadcast
#include "comm.h" // for Comm
#include "xgboost/context.h" // for Context
#include "allgather.h" // for RingAllgatherV, RingAllgather
#include "allreduce.h" // for Allreduce
#include "broadcast.h" // for Broadcast
#include "comm.h" // for Comm

namespace xgboost::collective {
[[nodiscard]] Result Coll::Allreduce(Context const*, Comm const& comm,
common::Span<std::int8_t> data, ArrayInterfaceHandler::Type,
Op op) {
[[nodiscard]] Result Coll::Allreduce(Comm const& comm, common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type, Op op) {
namespace coll = ::xgboost::collective;

auto redop_fn = [](auto lhs, auto out, auto elem_op) {
Expand Down Expand Up @@ -55,21 +53,43 @@ namespace xgboost::collective {
return comm.Block();
}

[[nodiscard]] Result Coll::Broadcast(Context const*, Comm const& comm,
common::Span<std::int8_t> data, std::int32_t root) {
/**
 * @brief Broadcast @p data from worker @p root to all workers; forwards to the
 *        CPU implementation.
 */
[[nodiscard]] Result Coll::Broadcast(Comm const& comm, common::Span<std::int8_t> data,
                                     std::int32_t root) {
  auto rc = cpu_impl::Broadcast(comm, data, root);
  return rc;
}

[[nodiscard]] Result Coll::Allgather(Context const*, Comm const& comm,
common::Span<std::int8_t> data, std::size_t size) {
/**
 * @brief Fixed-size allgather; dispatches to the ring-based CPU implementation.
 */
[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data,
                                     std::int64_t size) {
  auto rc = RingAllgather(comm, data, size);
  return rc;
}

[[nodiscard]] Result Coll::AllgatherV(Context const*, Comm const& comm,
common::Span<std::int8_t const> data,
/**
 * @brief Variable-size allgather: every worker contributes @p data and all workers
 *        end up with the contributions concatenated in rank order inside @p recv.
 *
 * @param sizes         Number of bytes contributed by each worker.
 * @param recv_segments Scratch span (world + 1 entries) filled with segment offsets.
 * @param recv          Output buffer holding all segments.
 * @param algo          Which allgather-v algorithm to use.
 */
[[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
                                      common::Span<std::int64_t const> sizes,
                                      common::Span<std::int64_t> recv_segments,
                                      common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
  // Segment boundaries: one offset per worker plus the grand total.
  detail::AllgatherVOffset(sizes, recv_segments);

  // Place this worker's own contribution into its slot of the result buffer.
  auto self_seg = recv.subspan(recv_segments[comm.Rank()], data.size_bytes());
  std::copy_n(data.data(), data.size(), self_seg.data());

  if (algo == AllgatherVAlgo::kRing) {
    return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
  }
  if (algo == AllgatherVAlgo::kBcast) {
    return cpu_impl::BroadcastAllgatherV(comm, sizes, recv);
  }
  return Fail("Unknown algorithm for allgather-v");
}

#if !defined(XGBOOST_USE_NCCL)
// Stub compiled only when XGBoost is built without NCCL: requesting a CUDA
// communicator variant is a fatal error in that configuration.
Coll* Coll::MakeCUDAVar() {
LOG(FATAL) << "NCCL is required for device communication.";
return nullptr;  // presumably unreachable — LOG(FATAL) should abort/throw; kept to satisfy the return type
}
#endif

} // namespace xgboost::collective
Loading
Loading