diff --git a/plugin/federated/federated.proto b/plugin/federated/federated.proto index d8ef5bd92f43..8450659fd180 100644 --- a/plugin/federated/federated.proto +++ b/plugin/federated/federated.proto @@ -7,6 +7,7 @@ package xgboost.federated; service Federated { rpc Allgather(AllgatherRequest) returns (AllgatherReply) {} + rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {} rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {} rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {} } @@ -42,6 +43,17 @@ message AllgatherReply { bytes receive_buffer = 1; } +message AllgatherVRequest { + // An incrementing counter that is unique to each round to operations. + uint64 sequence_number = 1; + int32 rank = 2; + bytes send_buffer = 3; +} + +message AllgatherVReply { + bytes receive_buffer = 1; +} + message AllreduceRequest { // An incrementing counter that is unique to each round to operations. uint64 sequence_number = 1; diff --git a/plugin/federated/federated_client.h b/plugin/federated/federated_client.h index d104cb2319b8..ac1fbd57d9dd 100644 --- a/plugin/federated/federated_client.h +++ b/plugin/federated/federated_client.h @@ -44,11 +44,11 @@ class FederatedClient { }()}, rank_{rank} {} - std::string Allgather(std::string const &send_buffer) { + std::string Allgather(std::string_view send_buffer) { AllgatherRequest request; request.set_sequence_number(sequence_number_++); request.set_rank(rank_); - request.set_send_buffer(send_buffer); + request.set_send_buffer(send_buffer.data(), send_buffer.size()); AllgatherReply reply; grpc::ClientContext context; @@ -63,6 +63,25 @@ class FederatedClient { } } + std::string AllgatherV(std::string_view send_buffer) { + AllgatherVRequest request; + request.set_sequence_number(sequence_number_++); + request.set_rank(rank_); + request.set_send_buffer(send_buffer.data(), send_buffer.size()); + + AllgatherVReply reply; + grpc::ClientContext context; + context.set_wait_for_ready(true); + grpc::Status status = 
stub_->AllgatherV(&context, request, &reply); + + if (status.ok()) { + return reply.receive_buffer(); + } else { + std::cout << status.error_code() << ": " << status.error_message() << '\n'; + throw std::runtime_error("AllgatherV RPC failed"); + } + } + std::string Allreduce(std::string const &send_buffer, DataType data_type, ReduceOperation reduce_operation) { AllreduceRequest request; diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index 996b433cb2ea..46c6b0fda672 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -125,14 +125,19 @@ class FederatedCommunicator : public Communicator { [[nodiscard]] bool IsFederated() const override { return true; } /** - * \brief Perform in-place allgather. - * \param send_receive_buffer Buffer for both sending and receiving data. - * \param size Number of bytes to be gathered. + * \brief Perform allgather. + * \param input Buffer for sending data. + */ + std::string AllGather(std::string_view input) override { + return client_->Allgather(input); + } + + /** + * \brief Perform variable-length allgather. + * \param input Buffer for sending data. 
*/ - void AllGather(void *send_receive_buffer, std::size_t size) override { - std::string const send_buffer(reinterpret_cast(send_receive_buffer), size); - auto const received = client_->Allgather(send_buffer); - received.copy(reinterpret_cast(send_receive_buffer), size); + std::string AllGatherV(std::string_view input) override { + return client_->AllgatherV(input); } /** diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index ae42f6d28920..ad6cf6022e6d 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -19,6 +19,13 @@ grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest return grpc::Status::OK; } +grpc::Status FederatedService::AllgatherV(grpc::ServerContext*, AllgatherVRequest const* request, + AllgatherVReply* reply) { + handler_.AllgatherV(request->send_buffer().data(), request->send_buffer().size(), + reply->mutable_receive_buffer(), request->sequence_number(), request->rank()); + return grpc::Status::OK; +} + grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest const* request, AllreduceReply* reply) { handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(), @@ -36,8 +43,8 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest return grpc::Status::OK; } -void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file, - char const* client_cert_file) { +void RunServer(int port, std::size_t world_size, char const* server_key_file, + char const* server_cert_file, char const* client_cert_file) { std::string const server_address = "0.0.0.0:" + std::to_string(port); FederatedService service{world_size}; @@ -59,7 +66,7 @@ void RunServer(int port, int world_size, char const* server_key_file, char const server->Wait(); } -void RunInsecureServer(int port, int world_size) { +void RunInsecureServer(int port, std::size_t world_size) { std::string 
const server_address = "0.0.0.0:" + std::to_string(port); FederatedService service{world_size}; diff --git a/plugin/federated/federated_server.h b/plugin/federated/federated_server.h index 7738248ea729..711ef55880b3 100644 --- a/plugin/federated/federated_server.h +++ b/plugin/federated/federated_server.h @@ -12,11 +12,14 @@ namespace federated { class FederatedService final : public Federated::Service { public: - explicit FederatedService(int const world_size) : handler_{world_size} {} + explicit FederatedService(std::size_t const world_size) : handler_{world_size} {} grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, AllgatherReply* reply) override; + grpc::Status AllgatherV(grpc::ServerContext* context, AllgatherVRequest const* request, + AllgatherVReply* reply) override; + grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, AllreduceReply* reply) override; @@ -27,10 +30,10 @@ class FederatedService final : public Federated::Service { xgboost::collective::InMemoryHandler handler_; }; -void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file, - char const* client_cert_file); +void RunServer(int port, std::size_t world_size, char const* server_key_file, + char const* server_cert_file, char const* client_cert_file); -void RunInsecureServer(int port, int world_size); +void RunInsecureServer(int port, std::size_t world_size); } // namespace federated } // namespace xgboost diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 858047af84ca..4fb6d90ff432 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1724,7 +1724,7 @@ XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int } #if defined(XGBOOST_USE_FEDERATED) -XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_key_path, +XGB_DLL int XGBRunFederatedServer(int port, std::size_t world_size, char const *server_key_path, char const 
*server_cert_path, char const *client_cert_path) { API_BEGIN(); federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path); @@ -1732,7 +1732,7 @@ XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_k } // Run a server without SSL for local testing. -XGB_DLL int XGBRunInsecureFederatedServer(int port, int world_size) { +XGB_DLL int XGBRunInsecureFederatedServer(int port, std::size_t world_size) { API_BEGIN(); federated::RunInsecureServer(port, world_size); API_END(); diff --git a/src/collective/communicator-inl.h b/src/collective/communicator-inl.h index 59cc4cc45736..c58a9f3bcf83 100644 --- a/src/collective/communicator-inl.h +++ b/src/collective/communicator-inl.h @@ -57,9 +57,7 @@ namespace collective { * - federated_client_key: Client key file path. Only needed for the SSL mode. * - federated_client_cert: Client certificate file path. Only needed for the SSL mode. */ -inline void Init(Json const& config) { - Communicator::Init(config); -} +inline void Init(Json const &config) { Communicator::Init(config); } /*! * \brief Finalize the collective communicator. @@ -141,17 +139,89 @@ inline void Broadcast(std::string *sendrecv_data, int root) { } } +/** + * @brief Gathers a single value all processes and distributes the result to all processes. + * + * @param input The single value. + */ +template +inline std::vector Allgather(T const &input) { + std::string_view str_input{reinterpret_cast(&input), sizeof(T)}; + auto const output = Communicator::Get()->AllGather(str_input); + CHECK_EQ(output.size() % sizeof(T), 0); + std::vector result(output.size() / sizeof(T)); + std::memcpy(reinterpret_cast(result.data()), output.data(), output.size()); + return result; +} + /** * @brief Gathers data from all processes and distributes it to all processes. * - * This assumes all ranks have the same size, and input data has been sliced into the - * corresponding position. + * This assumes all ranks have the same size. 
* - * @param send_receive_buffer Buffer storing the data. - * @param size Size of the data in bytes. + * @param input Buffer storing the data. */ -inline void Allgather(void *send_receive_buffer, std::size_t size) { - Communicator::Get()->AllGather(send_receive_buffer, size); +template +inline std::vector Allgather(std::vector const &input) { + if (input.empty()) { + return input; + } + std::string_view str_input{reinterpret_cast(input.data()), + input.size() * sizeof(T)}; + auto const output = Communicator::Get()->AllGather(str_input); + CHECK_EQ(output.size() % sizeof(T), 0); + std::vector result(output.size() / sizeof(T)); + std::memcpy(reinterpret_cast(result.data()), output.data(), output.size()); + return result; +} + +/** + * @brief Gathers variable-length data from all processes and distributes it to all processes. + * @param input Buffer storing the data. + */ +template +inline std::vector AllgatherV(std::vector const &input) { + std::string_view str_input{reinterpret_cast(input.data()), + input.size() * sizeof(T)}; + auto const output = Communicator::Get()->AllGatherV(str_input); + CHECK_EQ(output.size() % sizeof(T), 0); + std::vector result(output.size() / sizeof(T)); + if (!output.empty()) { + std::memcpy(reinterpret_cast(result.data()), output.data(), output.size()); + } + return result; +} + +/** + * @brief Gathers variable-length strings from all processes and distributes them to all processes. + * @param input Variable-length list of variable-length strings. 
+ */ +inline std::vector AllgatherStrings(std::vector const &input) { + std::size_t total_size{0}; + for (auto const &s : input) { + total_size += s.length() + 1; // +1 for null-terminators + } + std::string flat_string; + flat_string.reserve(total_size); + for (auto const &s : input) { + flat_string.append(s); + flat_string.push_back('\0'); // Append a null-terminator after each string + } + + auto const output = Communicator::Get()->AllGatherV(flat_string); + + std::vector result; + std::size_t start_index = 0; + // Iterate through the output, find each null-terminated substring. + for (std::size_t i = 0; i < output.size(); i++) { + if (output[i] == '\0') { + // Construct a std::string from the char* substring + result.emplace_back(&output[start_index]); + // Move to the next substring + start_index = i + 1; + } + } + return result; } /*! @@ -226,7 +296,7 @@ inline void Allreduce(double *send_receive_buffer, size_t count) { } template -struct AllgatherVResult { +struct SpecialAllgatherVResult { std::vector offsets; std::vector sizes; std::vector result; @@ -241,14 +311,10 @@ struct AllgatherVResult { * @param sizes Sizes of each input. */ template -inline AllgatherVResult AllgatherV(std::vector const &inputs, - std::vector const &sizes) { - auto num_inputs = sizes.size(); - +inline SpecialAllgatherVResult SpecialAllgatherV(std::vector const &inputs, + std::vector const &sizes) { // Gather the sizes across all workers. - std::vector all_sizes(num_inputs * GetWorldSize()); - std::copy_n(sizes.cbegin(), sizes.size(), all_sizes.begin() + num_inputs * GetRank()); - collective::Allgather(all_sizes.data(), all_sizes.size() * sizeof(std::size_t)); + auto const all_sizes = Allgather(sizes); // Calculate input offsets (std::exclusive_scan). std::vector offsets(all_sizes.size()); @@ -257,11 +323,7 @@ inline AllgatherVResult AllgatherV(std::vector const &inputs, } // Gather all the inputs. 
- auto total_input_size = offsets.back() + all_sizes.back(); - std::vector all_inputs(total_input_size); - std::copy_n(inputs.cbegin(), inputs.size(), all_inputs.begin() + offsets[num_inputs * GetRank()]); - // We cannot use allgather here, since each worker might have a different size. - Allreduce(all_inputs.data(), all_inputs.size()); + auto const all_inputs = AllgatherV(inputs); return {offsets, all_sizes, all_inputs}; } diff --git a/src/collective/communicator.h b/src/collective/communicator.h index def9615135df..feb446355b5d 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -125,13 +125,17 @@ class Communicator { /** * @brief Gathers data from all processes and distributes it to all processes. * - * This assumes all ranks have the same size, and input data has been sliced into the - * corresponding position. + * This assumes all ranks have the same size. * - * @param send_receive_buffer Buffer storing the data. - * @param size Size of the data in bytes. + * @param input Buffer storing the data. + */ + virtual std::string AllGather(std::string_view input) = 0; + + /** + * @brief Gathers variable-length data from all processes and distributes it to all processes. + * @param input Buffer storing the data. */ - virtual void AllGather(void *send_receive_buffer, std::size_t size) = 0; + virtual std::string AllGatherV(std::string_view input) = 0; /** * @brief Combines values from all processes and distributes the result back to all processes. 
diff --git a/src/collective/device_communicator_adapter.cuh b/src/collective/device_communicator_adapter.cuh index d10b10486693..7d3e836a0ec9 100644 --- a/src/collective/device_communicator_adapter.cuh +++ b/src/collective/device_communicator_adapter.cuh @@ -40,12 +40,10 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator { } dh::safe_cuda(cudaSetDevice(device_ordinal_)); - host_buffer_.resize(send_size * world_size_); - dh::safe_cuda(cudaMemcpy(host_buffer_.data() + rank_ * send_size, send_buffer, send_size, - cudaMemcpyDefault)); - Allgather(host_buffer_.data(), host_buffer_.size()); - dh::safe_cuda( - cudaMemcpy(receive_buffer, host_buffer_.data(), host_buffer_.size(), cudaMemcpyDefault)); + host_buffer_.resize(send_size); + dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_buffer, send_size, cudaMemcpyDefault)); + auto const output = Allgather(host_buffer_); + dh::safe_cuda(cudaMemcpy(receive_buffer, output.data(), output.size(), cudaMemcpyDefault)); } void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector *segments, diff --git a/src/collective/in_memory_communicator.h b/src/collective/in_memory_communicator.h index f41029af1dea..c712d32a8006 100644 --- a/src/collective/in_memory_communicator.h +++ b/src/collective/in_memory_communicator.h @@ -60,11 +60,16 @@ class InMemoryCommunicator : public Communicator { bool IsDistributed() const override { return true; } bool IsFederated() const override { return false; } - void AllGather(void* in_out, std::size_t size) override { + std::string AllGather(std::string_view input) override { std::string output; - handler_.Allgather(static_cast(in_out), size, &output, sequence_number_++, - GetRank()); - output.copy(static_cast(in_out), size); + handler_.Allgather(input.data(), input.size(), &output, sequence_number_++, GetRank()); + return output; + } + + std::string AllGatherV(std::string_view input) override { + std::string output; + handler_.AllgatherV(input.data(), input.size(), &output, 
sequence_number_++, GetRank()); + return output; } void AllReduce(void* in_out, std::size_t size, DataType data_type, Operation operation) override { diff --git a/src/collective/in_memory_handler.cc b/src/collective/in_memory_handler.cc index a45fe3e7dd78..944e5077b068 100644 --- a/src/collective/in_memory_handler.cc +++ b/src/collective/in_memory_handler.cc @@ -16,23 +16,49 @@ class AllgatherFunctor { public: std::string const name{"Allgather"}; - AllgatherFunctor(int world_size, int rank) : world_size_{world_size}, rank_{rank} {} + AllgatherFunctor(std::size_t world_size, std::size_t rank) + : world_size_{world_size}, rank_{rank} {} void operator()(char const* input, std::size_t bytes, std::string* buffer) const { if (buffer->empty()) { - // Copy the input if this is the first request. - buffer->assign(input, bytes); - } else { - // Splice the input into the common buffer. - auto const per_rank = bytes / world_size_; - auto const index = rank_ * per_rank; - buffer->replace(index, per_rank, input + index, per_rank); + // Resize the buffer if this is the first request. + buffer->resize(bytes * world_size_); + } + + // Splice the input into the common buffer. + buffer->replace(rank_ * bytes, bytes, input, bytes); + } + + private: + std::size_t world_size_; + std::size_t rank_; +}; + +/** + * @brief Functor for variable-length allgather. 
+ */ +class AllgatherVFunctor { + public: + std::string const name{"AllgatherV"}; + + AllgatherVFunctor(std::size_t world_size, std::size_t rank, + std::map* data) + : world_size_{world_size}, rank_{rank}, data_{data} {} + + void operator()(char const* input, std::size_t bytes, std::string* buffer) const { + data_->emplace(rank_, std::string_view{input, bytes}); + if (data_->size() == world_size_) { + for (auto const& kv : *data_) { + buffer->append(kv.second); + } + data_->clear(); } } private: - int world_size_; - int rank_; + std::size_t world_size_; + std::size_t rank_; + std::map* data_; }; /** @@ -154,7 +180,7 @@ class BroadcastFunctor { public: std::string const name{"Broadcast"}; - BroadcastFunctor(int rank, int root) : rank_{rank}, root_{root} {} + BroadcastFunctor(std::size_t rank, std::size_t root) : rank_{rank}, root_{root} {} void operator()(char const* input, std::size_t bytes, std::string* buffer) const { if (rank_ == root_) { @@ -164,11 +190,11 @@ class BroadcastFunctor { } private: - int rank_; - int root_; + std::size_t rank_; + std::size_t root_; }; -void InMemoryHandler::Init(int world_size, int) { +void InMemoryHandler::Init(std::size_t world_size, std::size_t) { CHECK(world_size_ < world_size) << "In memory handler already initialized."; std::unique_lock lock(mutex_); @@ -178,7 +204,7 @@ void InMemoryHandler::Init(int world_size, int) { cv_.notify_all(); } -void InMemoryHandler::Shutdown(uint64_t sequence_number, int) { +void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) { CHECK(world_size_ > 0) << "In memory handler already shutdown."; std::unique_lock lock(mutex_); @@ -194,24 +220,30 @@ void InMemoryHandler::Shutdown(uint64_t sequence_number, int) { } void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, int rank) { + std::size_t sequence_number, std::size_t rank) { Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank}); } 
+void InMemoryHandler::AllgatherV(char const* input, std::size_t bytes, std::string* output, + std::size_t sequence_number, std::size_t rank) { + Handle(input, bytes, output, sequence_number, rank, AllgatherVFunctor{world_size_, rank, &aux_}); +} + void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, int rank, DataType data_type, + std::size_t sequence_number, std::size_t rank, DataType data_type, Operation op) { Handle(input, bytes, output, sequence_number, rank, AllreduceFunctor{data_type, op}); } void InMemoryHandler::Broadcast(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, int rank, int root) { + std::size_t sequence_number, std::size_t rank, std::size_t root) { Handle(input, bytes, output, sequence_number, rank, BroadcastFunctor{rank, root}); } template void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, int rank, HandlerFunctor const& functor) { + std::size_t sequence_number, std::size_t rank, + HandlerFunctor const& functor) { // Pass through if there is only 1 client. if (world_size_ == 1) { if (input != output->data()) { diff --git a/src/collective/in_memory_handler.h b/src/collective/in_memory_handler.h index 4182c7b3ddb2..f9ac520079fd 100644 --- a/src/collective/in_memory_handler.h +++ b/src/collective/in_memory_handler.h @@ -3,6 +3,7 @@ */ #pragma once #include +#include #include #include "communicator.h" @@ -31,7 +32,7 @@ class InMemoryHandler { * * This is used when the handler only needs to be initialized once with a known world size. */ - explicit InMemoryHandler(int worldSize) : world_size_{worldSize} {} + explicit InMemoryHandler(std::size_t worldSize) : world_size_{worldSize} {} /** * @brief Initialize the handler with the world size and rank. 
@@ -41,7 +42,7 @@ class InMemoryHandler { * This is used when multiple objects/threads are accessing the same handler and need to * initialize it collectively. */ - void Init(int world_size, int rank); + void Init(std::size_t world_size, std::size_t rank); /** * @brief Shut down the handler. @@ -51,7 +52,7 @@ class InMemoryHandler { * This is used when multiple objects/threads are accessing the same handler and need to * shut it down collectively. */ - void Shutdown(uint64_t sequence_number, int rank); + void Shutdown(uint64_t sequence_number, std::size_t rank); /** * @brief Perform allgather. @@ -62,7 +63,18 @@ class InMemoryHandler { * @param rank Index of the worker. */ void Allgather(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, int rank); + std::size_t sequence_number, std::size_t rank); + + /** + * @brief Perform variable-length allgather. + * @param input The input buffer. + * @param bytes Number of bytes in the input buffer. + * @param output The output buffer. + * @param sequence_number Call sequence number. + * @param rank Index of the worker. + */ + void AllgatherV(char const* input, std::size_t bytes, std::string* output, + std::size_t sequence_number, std::size_t rank); /** * @brief Perform allreduce. @@ -75,7 +87,7 @@ class InMemoryHandler { * @param op The reduce operation. */ void Allreduce(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, int rank, DataType data_type, Operation op); + std::size_t sequence_number, std::size_t rank, DataType data_type, Operation op); /** * @brief Perform broadcast. @@ -87,7 +99,7 @@ class InMemoryHandler { * @param root Index of the worker to broadcast from. 
*/ void Broadcast(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, int rank, int root); + std::size_t sequence_number, std::size_t rank, std::size_t root); private: /** @@ -102,15 +114,16 @@ class InMemoryHandler { */ template void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number, - int rank, HandlerFunctor const& functor); + std::size_t rank, HandlerFunctor const& functor); - int world_size_{}; /// Number of workers. - int received_{}; /// Number of calls received with the current sequence. - int sent_{}; /// Number of calls completed with the current sequence. - std::string buffer_{}; /// A shared common buffer. - uint64_t sequence_number_{}; /// Call sequence number. - mutable std::mutex mutex_; /// Lock. - mutable std::condition_variable cv_; /// Conditional variable to wait on. + std::size_t world_size_{}; /// Number of workers. + std::size_t received_{}; /// Number of calls received with the current sequence. + std::size_t sent_{}; /// Number of calls completed with the current sequence. + std::string buffer_{}; /// A shared common buffer. + std::map aux_{}; /// A shared auxiliary map. + uint64_t sequence_number_{}; /// Call sequence number. + mutable std::mutex mutex_; /// Lock. + mutable std::condition_variable cv_; /// Conditional variable to wait on. 
}; } // namespace collective diff --git a/src/collective/noop_communicator.h b/src/collective/noop_communicator.h index 28a0a1cada4d..2d88fd8024d2 100644 --- a/src/collective/noop_communicator.h +++ b/src/collective/noop_communicator.h @@ -17,10 +17,11 @@ class NoOpCommunicator : public Communicator { NoOpCommunicator() : Communicator(1, 0) {} bool IsDistributed() const override { return false; } bool IsFederated() const override { return false; } - void AllGather(void *, std::size_t) override {} + std::string AllGather(std::string_view) override { return {}; } + std::string AllGatherV(std::string_view) override { return {}; } void AllReduce(void *, std::size_t, DataType, Operation) override {} void Broadcast(void *, std::size_t, int) override {} - std::string GetProcessorName() override { return ""; } + std::string GetProcessorName() override { return {}; } void Print(const std::string &message) override { LOG(CONSOLE) << message; } protected: diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h index 9b79624a2718..59a4bbbd889a 100644 --- a/src/collective/rabit_communicator.h +++ b/src/collective/rabit_communicator.h @@ -7,6 +7,7 @@ #include #include +#include "communicator-inl.h" #include "communicator.h" #include "xgboost/json.h" @@ -55,10 +56,29 @@ class RabitCommunicator : public Communicator { bool IsFederated() const override { return false; } - void AllGather(void *send_receive_buffer, std::size_t size) override { - auto const per_rank = size / GetWorldSize(); + std::string AllGather(std::string_view input) override { + auto const per_rank = input.size(); + auto const total_size = per_rank * GetWorldSize(); auto const index = per_rank * GetRank(); - rabit::Allgather(static_cast(send_receive_buffer), size, index, per_rank, per_rank); + std::string result(total_size, '\0'); + // rabit::Allgather is in-place: this rank's slice must hold `input` before the call. + result.replace(index, per_rank, input); + rabit::Allgather(result.data(), total_size, index, per_rank, per_rank); + return result; + } + + std::string AllGatherV(std::string_view input) override 
{ + auto const size_node_slice = input.size(); + auto const all_sizes = collective::Allgather(size_node_slice); + auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul); + auto const begin_index = + std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul); + auto const size_prev_slice = GetRank() == 0 ? 0 : all_sizes[GetRank() - 1]; + + std::string result(total_size, '\0'); + result.replace(begin_index, size_node_slice, input); + rabit::Allgather(result.data(), total_size, begin_index, size_node_slice, size_prev_slice); + return result; } void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 1989f68a9444..48e764986966 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -76,10 +76,8 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) { void SimpleDMatrix::ReindexFeatures(Context const* ctx) { if (info_.IsColumnSplit()) { - std::vector buffer(collective::GetWorldSize()); - buffer[collective::GetRank()] = info_.num_col_; - collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t)); - auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0ul); + auto const cols = collective::Allgather(info_.num_col_); + auto const offset = std::accumulate(cols.cbegin(), cols.cbegin() + collective::GetRank(), 0ul); if (offset == 0) { return; } diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index b4612e24c552..680c50398b42 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -292,20 +292,19 @@ class HistEvaluator { */ std::vector Allgather(std::vector const &entries) { auto const world = collective::GetWorldSize(); - auto const rank = collective::GetRank(); auto const num_entries = entries.size(); // First, gather all the primitive fields. 
- std::vector all_entries(num_entries * world); + std::vector local_entries(num_entries); std::vector cat_bits; std::vector cat_bits_sizes; for (std::size_t i = 0; i < num_entries; i++) { - all_entries[num_entries * rank + i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes); + local_entries[i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes); } - collective::Allgather(all_entries.data(), all_entries.size() * sizeof(CPUExpandEntry)); + auto all_entries = collective::Allgather(local_entries); // Gather all the cat_bits. - auto gathered = collective::AllgatherV(cat_bits, cat_bits_sizes); + auto gathered = collective::SpecialAllgatherV(cat_bits, cat_bits_sizes); common::ParallelFor(num_entries * world, ctx_->Threads(), [&] (auto i) { // Copy the cat_bits back into all expand entries. @@ -579,28 +578,24 @@ class HistMultiEvaluator { */ std::vector Allgather(std::vector const &entries) { auto const world = collective::GetWorldSize(); - auto const rank = collective::GetRank(); auto const num_entries = entries.size(); // First, gather all the primitive fields. - std::vector all_entries(num_entries * world); + std::vector local_entries(num_entries); std::vector cat_bits; std::vector cat_bits_sizes; std::vector gradients; for (std::size_t i = 0; i < num_entries; i++) { - all_entries[num_entries * rank + i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes, - &gradients); + local_entries[i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes, &gradients); } - collective::Allgather(all_entries.data(), all_entries.size() * sizeof(MultiExpandEntry)); + auto all_entries = collective::Allgather(local_entries); // Gather all the cat_bits. - auto gathered_cat_bits = collective::AllgatherV(cat_bits, cat_bits_sizes); + auto gathered_cat_bits = collective::SpecialAllgatherV(cat_bits, cat_bits_sizes); // Gather all the gradients. 
auto const num_gradients = gradients.size(); - std::vector all_gradients(num_gradients * world); - std::copy_n(gradients.cbegin(), num_gradients, all_gradients.begin() + num_gradients * rank); - collective::Allgather(all_gradients.data(), all_gradients.size() * sizeof(GradientPairPrecise)); + auto const all_gradients = collective::Allgather(gradients); auto const total_entries = num_entries * world; auto const gradients_per_entry = num_gradients / num_entries; diff --git a/tests/cpp/collective/test_in_memory_communicator.cc b/tests/cpp/collective/test_in_memory_communicator.cc index f36e30e3391d..69c427a4e642 100644 --- a/tests/cpp/collective/test_in_memory_communicator.cc +++ b/tests/cpp/collective/test_in_memory_communicator.cc @@ -29,6 +29,11 @@ class InMemoryCommunicatorTest : public ::testing::Test { VerifyAllgather(comm, rank); } + static void AllgatherV(int rank) { + InMemoryCommunicator comm{kWorldSize, rank}; + VerifyAllgatherV(comm, rank); + } + static void AllreduceMax(int rank) { InMemoryCommunicator comm{kWorldSize, rank}; VerifyAllreduceMax(comm, rank); @@ -80,14 +85,19 @@ class InMemoryCommunicatorTest : public ::testing::Test { protected: static void VerifyAllgather(InMemoryCommunicator &comm, int rank) { - char buffer[kWorldSize] = {'a', 'b', 'c'}; - buffer[rank] = '0' + rank; - comm.AllGather(buffer, kWorldSize); + std::string input{static_cast('0' + rank)}; + auto output = comm.AllGather(input); for (auto i = 0; i < kWorldSize; i++) { - EXPECT_EQ(buffer[i], '0' + i); + EXPECT_EQ(output[i], static_cast('0' + i)); } } + static void VerifyAllgatherV(InMemoryCommunicator &comm, int rank) { + std::vector inputs{"a", "bb", "ccc"}; + auto output = comm.AllGatherV(inputs[rank]); + EXPECT_EQ(output, "abbccc"); + } + static void VerifyAllreduceMax(InMemoryCommunicator &comm, int rank) { int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank}; comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kMax); @@ -205,6 
+215,8 @@ TEST(InMemoryCommunicatorSimpleTest, IsDistributed) { TEST_F(InMemoryCommunicatorTest, Allgather) { Verify(&Allgather); } +TEST_F(InMemoryCommunicatorTest, AllgatherV) { Verify(&AllgatherV); } + TEST_F(InMemoryCommunicatorTest, AllreduceMax) { Verify(&AllreduceMax); } TEST_F(InMemoryCommunicatorTest, AllreduceMin) { Verify(&AllreduceMin); } diff --git a/tests/cpp/plugin/helpers.h b/tests/cpp/plugin/helpers.h index 20b4afc3026b..b756adefd9b8 100644 --- a/tests/cpp/plugin/helpers.h +++ b/tests/cpp/plugin/helpers.h @@ -23,7 +23,7 @@ class ServerForTest { std::unique_ptr server_; public: - explicit ServerForTest(std::int32_t world_size) { + explicit ServerForTest(std::size_t world_size) { server_thread_.reset(new std::thread([this, world_size] { grpc::ServerBuilder builder; xgboost::federated::FederatedService service{world_size}; diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc index 8b0e1039adff..68b112f1c7b1 100644 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -19,6 +19,11 @@ class FederatedCommunicatorTest : public BaseFederatedTest { CheckAllgather(comm, rank); } + static void VerifyAllgatherV(int rank, const std::string &server_address) { + FederatedCommunicator comm{kWorldSize, rank, server_address}; + CheckAllgatherV(comm, rank); + } + static void VerifyAllreduce(int rank, const std::string &server_address) { FederatedCommunicator comm{kWorldSize, rank, server_address}; CheckAllreduce(comm); @@ -31,14 +36,19 @@ class FederatedCommunicatorTest : public BaseFederatedTest { protected: static void CheckAllgather(FederatedCommunicator &comm, int rank) { - int buffer[kWorldSize] = {0, 0}; - buffer[rank] = rank; - comm.AllGather(buffer, sizeof(buffer)); + std::string input{static_cast('0' + rank)}; + auto output = comm.AllGather(input); for (auto i = 0; i < kWorldSize; i++) { - EXPECT_EQ(buffer[i], i); + EXPECT_EQ(output[i], 
static_cast('0' + i)); } } + static void CheckAllgatherV(FederatedCommunicator &comm, int rank) { + std::vector inputs{"Federated", " Learning!!!"}; + auto output = comm.AllGatherV(inputs[rank]); + EXPECT_EQ(output, "Federated Learning!!!"); + } + static void CheckAllreduce(FederatedCommunicator &comm) { int buffer[] = {1, 2, 3, 4, 5}; comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum); @@ -119,6 +129,16 @@ TEST_F(FederatedCommunicatorTest, Allgather) { } } +TEST_F(FederatedCommunicatorTest, AllgatherV) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgatherV, rank, server_->Address()); + } + for (auto &thread : threads) { + thread.join(); + } +} + TEST_F(FederatedCommunicatorTest, Allreduce) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { diff --git a/tests/cpp/plugin/test_federated_server.cc b/tests/cpp/plugin/test_federated_server.cc index 633d64df10f8..c40e58fa388f 100644 --- a/tests/cpp/plugin/test_federated_server.cc +++ b/tests/cpp/plugin/test_federated_server.cc @@ -18,6 +18,11 @@ class FederatedServerTest : public BaseFederatedTest { CheckAllgather(client, rank); } + static void VerifyAllgatherV(int rank, const std::string& server_address) { + federated::FederatedClient client{server_address, rank}; + CheckAllgatherV(client, rank); + } + static void VerifyAllreduce(int rank, const std::string& server_address) { federated::FederatedClient client{server_address, rank}; CheckAllreduce(client); @@ -39,8 +44,7 @@ class FederatedServerTest : public BaseFederatedTest { protected: static void CheckAllgather(federated::FederatedClient& client, int rank) { - int data[kWorldSize] = {0, 0}; - data[rank] = rank; + int data[] = {rank}; std::string send_buffer(reinterpret_cast(data), sizeof(data)); auto reply = client.Allgather(send_buffer); auto const* result = reinterpret_cast(reply.data()); @@ -49,6 +53,12 @@ class 
FederatedServerTest : public BaseFederatedTest { } } + static void CheckAllgatherV(federated::FederatedClient& client, int rank) { + std::vector inputs{"Hello,", " World!"}; + auto reply = client.AllgatherV(inputs[rank]); + EXPECT_EQ(reply, "Hello, World!"); + } + static void CheckAllreduce(federated::FederatedClient& client) { int data[] = {1, 2, 3, 4, 5}; std::string send_buffer(reinterpret_cast(data), sizeof(data)); @@ -80,6 +90,16 @@ TEST_F(FederatedServerTest, Allgather) { } } +TEST_F(FederatedServerTest, AllgatherV) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(&FederatedServerTest::VerifyAllgatherV, rank, server_->Address()); + } + for (auto& thread : threads) { + thread.join(); + } +} + TEST_F(FederatedServerTest, Allreduce) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) {