Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix device communicator dependency #9346

Merged
merged 4 commits into from
Jun 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/collective/communicator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
old_world_size = communicator_->GetWorldSize();
#ifdef XGBOOST_USE_NCCL
if (type_ != CommunicatorType::kFederated) {
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get()));
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal));
} else {
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
}
#else
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
#endif
}
return device_communicator_.get();
Expand Down
33 changes: 14 additions & 19 deletions src/collective/device_communicator_adapter.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -11,59 +11,53 @@ namespace collective {

class DeviceCommunicatorAdapter : public DeviceCommunicator {
public:
DeviceCommunicatorAdapter(int device_ordinal, Communicator *communicator)
: device_ordinal_{device_ordinal}, communicator_{communicator} {
explicit DeviceCommunicatorAdapter(int device_ordinal)
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
if (communicator_ == nullptr) {
LOG(FATAL) << "Communicator cannot be null.";
}
}

~DeviceCommunicatorAdapter() override = default;

void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}

dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto size = count * GetTypeSize(data_type);
host_buffer_.reserve(size);
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
communicator_->AllReduce(host_buffer_.data(), count, data_type, op);
Allreduce(host_buffer_.data(), count, data_type, op);
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
}

void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}

dh::safe_cuda(cudaSetDevice(device_ordinal_));
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();

segments->clear();
segments->resize(world_size, 0);
segments->at(rank) = length_bytes;
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
Operation::kMax);
segments->resize(world_size_, 0);
segments->at(rank_) = length_bytes;
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);

host_buffer_.reserve(total_bytes);
size_t offset = 0;
for (int32_t i = 0; i < world_size; ++i) {
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
if (i == rank) {
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank),
if (i == rank_) {
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
cudaMemcpyDefault));
}
communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i);
Broadcast(host_buffer_.data() + offset, as_bytes, i);
offset += as_bytes;
}
dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
Expand All @@ -76,7 +70,8 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {

private:
int const device_ordinal_;
Communicator *communicator_;
int const world_size_;
int const rank_;
/// Host buffer used to call communicator functions.
std::vector<char> host_buffer_{};
};
Expand Down
52 changes: 21 additions & 31 deletions src/collective/nccl_device_communicator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,24 @@
namespace xgboost {
namespace collective {

NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, Communicator *communicator)
: device_ordinal_{device_ordinal}, communicator_{communicator} {
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal)
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
if (communicator_ == nullptr) {
LOG(FATAL) << "Communicator cannot be null.";
}

int32_t const rank = communicator_->GetRank();
int32_t const world = communicator_->GetWorldSize();

if (world == 1) {
if (world_size_ == 1) {
return;
}

std::vector<uint64_t> uuids(world * kUuidLength, 0);
std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
auto s_this_uuid = s_uuid.subspan(rank_ * kUuidLength, kUuidLength);
GetCudaUUID(s_this_uuid);

// TODO(rongou): replace this with allgather.
communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);

std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world_size_);
size_t j = 0;
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
Expand All @@ -41,18 +34,18 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, Communicator
auto iter = std::unique(converted.begin(), converted.end());
auto n_uniques = std::distance(converted.begin(), iter);

CHECK_EQ(n_uniques, world)
CHECK_EQ(n_uniques, world_size_)
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";

nccl_unique_id_ = GetUniqueId();
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
}

NcclDeviceCommunicator::~NcclDeviceCommunicator() {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}
if (cuda_stream_) {
Expand Down Expand Up @@ -139,9 +132,8 @@ void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func,

void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
DataType data_type, Operation op) {
auto const world_size = communicator_->GetWorldSize();
auto const size = count * GetTypeSize(data_type);
dh::caching_device_vector<char> buffer(size * world_size);
dh::caching_device_vector<char> buffer(size * world_size_);
auto *device_buffer = buffer.data().get();

// First gather data from all the workers.
Expand All @@ -152,15 +144,15 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
auto *out_buffer = static_cast<char *>(send_receive_buffer);
switch (op) {
case Operation::kBitwiseAND:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size, size,
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size,
cuda_stream_);
break;
case Operation::kBitwiseOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size, size,
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size,
cuda_stream_);
break;
case Operation::kBitwiseXOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size, size,
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size,
cuda_stream_);
break;
default:
Expand All @@ -170,7 +162,7 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si

void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
DataType data_type, Operation op) {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}

Expand All @@ -189,24 +181,22 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}

dh::safe_cuda(cudaSetDevice(device_ordinal_));
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();

segments->clear();
segments->resize(world_size, 0);
segments->at(rank) = length_bytes;
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
segments->resize(world_size_, 0);
segments->at(rank_) = length_bytes;
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);

size_t offset = 0;
dh::safe_nccl(ncclGroupStart());
for (int32_t i = 0; i < world_size; ++i) {
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
ncclChar, i, nccl_comm_, cuda_stream_));
Expand All @@ -216,7 +206,7 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
}

void NcclDeviceCommunicator::Synchronize() {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
Expand Down
10 changes: 5 additions & 5 deletions src/collective/nccl_device_communicator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace collective {

class NcclDeviceCommunicator : public DeviceCommunicator {
public:
NcclDeviceCommunicator(int device_ordinal, Communicator *communicator);
explicit NcclDeviceCommunicator(int device_ordinal);
~NcclDeviceCommunicator() override;
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override;
Expand Down Expand Up @@ -49,19 +49,19 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
ncclUniqueId GetUniqueId() {
static const int kRootRank = 0;
ncclUniqueId id;
if (communicator_->GetRank() == kRootRank) {
if (rank_ == kRootRank) {
dh::safe_nccl(ncclGetUniqueId(&id));
}
communicator_->Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId),
static_cast<int>(kRootRank));
Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
return id;
}

void BitwiseAllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op);

int const device_ordinal_;
Communicator *communicator_;
int const world_size_;
int const rank_;
ncclComm_t nccl_comm_{};
cudaStream_t cuda_stream_{};
ncclUniqueId nccl_unique_id_{};
Expand Down
7 changes: 1 addition & 6 deletions tests/cpp/collective/test_nccl_device_communicator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,7 @@ namespace xgboost {
namespace collective {

TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
auto construct = []() { NcclDeviceCommunicator comm{-1, nullptr}; };
EXPECT_THROW(construct(), dmlc::Error);
}

TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidCommunicator) {
auto construct = []() { NcclDeviceCommunicator comm{0, nullptr}; };
auto construct = []() { NcclDeviceCommunicator comm{-1}; };
EXPECT_THROW(construct(), dmlc::Error);
}

Expand Down
9 changes: 8 additions & 1 deletion tests/cpp/plugin/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,14 @@ class ServerForTest {
}

~ServerForTest() {
using namespace std::chrono_literals;
while (!server_) {
std::this_thread::sleep_for(100ms);
}
server_->Shutdown();
while (!server_thread_) {
std::this_thread::sleep_for(100ms);
}
server_thread_->join();
}

Expand All @@ -56,7 +63,7 @@ class BaseFederatedTest : public ::testing::Test {

void TearDown() override { server_.reset(nullptr); }

static int constexpr kWorldSize{3};
static int constexpr kWorldSize{2};
std::unique_ptr<ServerForTest> server_;
};

Expand Down
Loading