Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve allgather functions #9649

Merged
merged 18 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions plugin/federated/federated.proto
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package xgboost.federated;

service Federated {
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {}
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
}
Expand Down Expand Up @@ -42,6 +43,17 @@ message AllgatherReply {
bytes receive_buffer = 1;
}

message AllgatherVRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
}

message AllgatherVReply {
bytes receive_buffer = 1;
}

message AllreduceRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
Expand Down
23 changes: 21 additions & 2 deletions plugin/federated/federated_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ class FederatedClient {
}()},
rank_{rank} {}

std::string Allgather(std::string const &send_buffer) {
std::string Allgather(std::string_view send_buffer) {
AllgatherRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(rank_);
request.set_send_buffer(send_buffer);
request.set_send_buffer(send_buffer.data(), send_buffer.size());

AllgatherReply reply;
grpc::ClientContext context;
Expand All @@ -63,6 +63,25 @@ class FederatedClient {
}
}

std::string AllgatherV(std::string_view send_buffer) {
AllgatherVRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(rank_);
request.set_send_buffer(send_buffer.data(), send_buffer.size());

AllgatherVReply reply;
grpc::ClientContext context;
context.set_wait_for_ready(true);
grpc::Status status = stub_->AllgatherV(&context, request, &reply);

if (status.ok()) {
return reply.receive_buffer();
} else {
std::cout << status.error_code() << ": " << status.error_message() << '\n';
throw std::runtime_error("AllgatherV RPC failed");
}
}

std::string Allreduce(std::string const &send_buffer, DataType data_type,
ReduceOperation reduce_operation) {
AllreduceRequest request;
Expand Down
19 changes: 12 additions & 7 deletions plugin/federated/federated_communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,19 @@ class FederatedCommunicator : public Communicator {
[[nodiscard]] bool IsFederated() const override { return true; }

/**
* \brief Perform in-place allgather.
* \param send_receive_buffer Buffer for both sending and receiving data.
* \param size Number of bytes to be gathered.
* \brief Perform allgather.
* \param input Buffer for sending data.
*/
std::string AllGather(std::string_view input) override {
return client_->Allgather(input);
}

/**
* \brief Perform variable-length allgather.
* \param input Buffer for sending data.
*/
void AllGather(void *send_receive_buffer, std::size_t size) override {
std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer), size);
auto const received = client_->Allgather(send_buffer);
received.copy(reinterpret_cast<char *>(send_receive_buffer), size);
std::string AllGatherV(std::string_view input) override {
return client_->AllgatherV(input);
}

/**
Expand Down
13 changes: 10 additions & 3 deletions plugin/federated/federated_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest
return grpc::Status::OK;
}

grpc::Status FederatedService::AllgatherV(grpc::ServerContext*, AllgatherVRequest const* request,
AllgatherVReply* reply) {
handler_.AllgatherV(request->send_buffer().data(), request->send_buffer().size(),
reply->mutable_receive_buffer(), request->sequence_number(), request->rank());
return grpc::Status::OK;
}

grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest const* request,
AllreduceReply* reply) {
handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(),
Expand All @@ -36,8 +43,8 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest
return grpc::Status::OK;
}

void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
char const* client_cert_file) {
void RunServer(int port, std::size_t world_size, char const* server_key_file,
char const* server_cert_file, char const* client_cert_file) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
FederatedService service{world_size};

Expand All @@ -59,7 +66,7 @@ void RunServer(int port, int world_size, char const* server_key_file, char const
server->Wait();
}

void RunInsecureServer(int port, int world_size) {
void RunInsecureServer(int port, std::size_t world_size) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
FederatedService service{world_size};

Expand Down
11 changes: 7 additions & 4 deletions plugin/federated/federated_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ namespace federated {

class FederatedService final : public Federated::Service {
public:
explicit FederatedService(int const world_size) : handler_{world_size} {}
explicit FederatedService(std::size_t const world_size) : handler_{world_size} {}

grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;

grpc::Status AllgatherV(grpc::ServerContext* context, AllgatherVRequest const* request,
AllgatherVReply* reply) override;

grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
AllreduceReply* reply) override;

Expand All @@ -27,10 +30,10 @@ class FederatedService final : public Federated::Service {
xgboost::collective::InMemoryHandler handler_;
};

void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
char const* client_cert_file);
void RunServer(int port, std::size_t world_size, char const* server_key_file,
char const* server_cert_file, char const* client_cert_file);

void RunInsecureServer(int port, int world_size);
void RunInsecureServer(int port, std::size_t world_size);

} // namespace federated
} // namespace xgboost
4 changes: 2 additions & 2 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1724,15 +1724,15 @@ XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int
}

#if defined(XGBOOST_USE_FEDERATED)
XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_key_path,
XGB_DLL int XGBRunFederatedServer(int port, std::size_t world_size, char const *server_key_path,
char const *server_cert_path, char const *client_cert_path) {
API_BEGIN();
federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path);
API_END();
}

// Run a server without SSL for local testing.
XGB_DLL int XGBRunInsecureFederatedServer(int port, int world_size) {
XGB_DLL int XGBRunInsecureFederatedServer(int port, std::size_t world_size) {
API_BEGIN();
federated::RunInsecureServer(port, world_size);
API_END();
Expand Down
106 changes: 84 additions & 22 deletions src/collective/communicator-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ namespace collective {
* - federated_client_key: Client key file path. Only needed for the SSL mode.
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
*/
inline void Init(Json const& config) {
Communicator::Init(config);
}
inline void Init(Json const &config) { Communicator::Init(config); }

/*!
* \brief Finalize the collective communicator.
Expand Down Expand Up @@ -141,17 +139,89 @@ inline void Broadcast(std::string *sendrecv_data, int root) {
}
}

/**
* @brief Gathers a single value all processes and distributes the result to all processes.
*
* @param input The single value.
*/
template <typename T>
inline std::vector<T> Allgather(T const &input) {
std::string_view str_input{reinterpret_cast<char const *>(&input), sizeof(T)};
auto const output = Communicator::Get()->AllGather(str_input);
CHECK_EQ(output.size() % sizeof(T), 0);
std::vector<T> result(output.size() / sizeof(T));
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
return result;
}

/**
* @brief Gathers data from all processes and distributes it to all processes.
*
* This assumes all ranks have the same size, and input data has been sliced into the
* corresponding position.
* This assumes all ranks have the same size.
*
* @param send_receive_buffer Buffer storing the data.
* @param size Size of the data in bytes.
* @param input Buffer storing the data.
*/
inline void Allgather(void *send_receive_buffer, std::size_t size) {
Communicator::Get()->AllGather(send_receive_buffer, size);
template <typename T>
inline std::vector<T> Allgather(std::vector<T> const &input) {
if (input.empty()) {
return input;
}
std::string_view str_input{reinterpret_cast<char const *>(input.data()),
input.size() * sizeof(T)};
auto const output = Communicator::Get()->AllGather(str_input);
CHECK_EQ(output.size() % sizeof(T), 0);
std::vector<T> result(output.size() / sizeof(T));
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
return result;
}

/**
* @brief Gathers variable-length data from all processes and distributes it to all processes.
* @param input Buffer storing the data.
*/
template <typename T>
inline std::vector<T> AllgatherV(std::vector<T> const &input) {
std::string_view str_input{reinterpret_cast<char const *>(input.data()),
input.size() * sizeof(T)};
auto const output = Communicator::Get()->AllGatherV(str_input);
CHECK_EQ(output.size() % sizeof(T), 0);
std::vector<T> result(output.size() / sizeof(T));
if (!output.empty()) {
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
}
return result;
}

/**
* @brief Gathers variable-length strings from all processes and distributes them to all processes.
* @param input Variable-length list of variable-length strings.
*/
inline std::vector<std::string> AllgatherStrings(std::vector<std::string> const &input) {
std::size_t total_size{0};
for (auto const &s : input) {
total_size += s.length() + 1; // +1 for null-terminators
}
std::string flat_string;
flat_string.reserve(total_size);
for (auto const &s : input) {
flat_string.append(s);
flat_string.push_back('\0'); // Append a null-terminator after each string
}

auto const output = Communicator::Get()->AllGatherV(flat_string);

std::vector<std::string> result;
std::size_t start_index = 0;
// Iterate through the output, find each null-terminated substring.
for (std::size_t i = 0; i < output.size(); i++) {
if (output[i] == '\0') {
// Construct a std::string from the char* substring
result.emplace_back(&output[start_index]);
// Move to the next substring
start_index = i + 1;
}
}
return result;
}

/*!
Expand Down Expand Up @@ -226,7 +296,7 @@ inline void Allreduce(double *send_receive_buffer, size_t count) {
}

template <typename T>
struct AllgatherVResult {
struct SpecialAllgatherVResult {
std::vector<std::size_t> offsets;
std::vector<std::size_t> sizes;
std::vector<T> result;
Expand All @@ -241,14 +311,10 @@ struct AllgatherVResult {
* @param sizes Sizes of each input.
*/
template <typename T>
inline AllgatherVResult<T> AllgatherV(std::vector<T> const &inputs,
std::vector<std::size_t> const &sizes) {
auto num_inputs = sizes.size();

inline SpecialAllgatherVResult<T> SpecialAllgatherV(std::vector<T> const &inputs,
std::vector<std::size_t> const &sizes) {
// Gather the sizes across all workers.
std::vector<std::size_t> all_sizes(num_inputs * GetWorldSize());
std::copy_n(sizes.cbegin(), sizes.size(), all_sizes.begin() + num_inputs * GetRank());
collective::Allgather(all_sizes.data(), all_sizes.size() * sizeof(std::size_t));
auto const all_sizes = Allgather(sizes);

// Calculate input offsets (std::exclusive_scan).
std::vector<std::size_t> offsets(all_sizes.size());
Expand All @@ -257,11 +323,7 @@ inline AllgatherVResult<T> AllgatherV(std::vector<T> const &inputs,
}

// Gather all the inputs.
auto total_input_size = offsets.back() + all_sizes.back();
std::vector<T> all_inputs(total_input_size);
std::copy_n(inputs.cbegin(), inputs.size(), all_inputs.begin() + offsets[num_inputs * GetRank()]);
// We cannot use allgather here, since each worker might have a different size.
Allreduce<Operation::kMax>(all_inputs.data(), all_inputs.size());
auto const all_inputs = AllgatherV(inputs);

return {offsets, all_sizes, all_inputs};
}
Expand Down
14 changes: 9 additions & 5 deletions src/collective/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,17 @@ class Communicator {
/**
* @brief Gathers data from all processes and distributes it to all processes.
*
* This assumes all ranks have the same size, and input data has been sliced into the
* corresponding position.
* This assumes all ranks have the same size.
*
* @param send_receive_buffer Buffer storing the data.
* @param size Size of the data in bytes.
* @param input Buffer storing the data.
*/
virtual std::string AllGather(std::string_view input) = 0;

/**
* @brief Gathers variable-length data from all processes and distributes it to all processes.
* @param input Buffer storing the data.
*/
virtual void AllGather(void *send_receive_buffer, std::size_t size) = 0;
virtual std::string AllGatherV(std::string_view input) = 0;

/**
* @brief Combines values from all processes and distributes the result back to all processes.
Expand Down
10 changes: 4 additions & 6 deletions src/collective/device_communicator_adapter.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,10 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
}

dh::safe_cuda(cudaSetDevice(device_ordinal_));
host_buffer_.resize(send_size * world_size_);
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + rank_ * send_size, send_buffer, send_size,
cudaMemcpyDefault));
Allgather(host_buffer_.data(), host_buffer_.size());
dh::safe_cuda(
cudaMemcpy(receive_buffer, host_buffer_.data(), host_buffer_.size(), cudaMemcpyDefault));
host_buffer_.resize(send_size);
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_buffer, send_size, cudaMemcpyDefault));
auto const output = Allgather(host_buffer_);
dh::safe_cuda(cudaMemcpy(receive_buffer, output.data(), output.size(), cudaMemcpyDefault));
}

void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
Expand Down
13 changes: 9 additions & 4 deletions src/collective/in_memory_communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,16 @@ class InMemoryCommunicator : public Communicator {
bool IsDistributed() const override { return true; }
bool IsFederated() const override { return false; }

void AllGather(void* in_out, std::size_t size) override {
std::string AllGather(std::string_view input) override {
std::string output;
handler_.Allgather(static_cast<const char*>(in_out), size, &output, sequence_number_++,
GetRank());
output.copy(static_cast<char*>(in_out), size);
handler_.Allgather(input.data(), input.size(), &output, sequence_number_++, GetRank());
return output;
}

std::string AllGatherV(std::string_view input) override {
std::string output;
handler_.AllgatherV(input.data(), input.size(), &output, sequence_number_++, GetRank());
return output;
}

void AllReduce(void* in_out, std::size_t size, DataType data_type, Operation operation) override {
Expand Down
Loading
Loading