Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rely on gRPC to generate random port #9102

Merged
merged 1 commit into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 11 additions & 22 deletions tests/cpp/plugin/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,6 @@
#include "../../../plugin/federated/federated_server.h"
#include "../../../src/collective/communicator-inl.h"

inline int GenerateRandomPort(int low, int high) {
using namespace std::chrono_literals;
// Ensure unique timestamp by introducing a small artificial delay
std::this_thread::sleep_for(100ms);
auto timestamp = static_cast<uint64_t>(std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count());
std::mt19937_64 rng(timestamp);
std::uniform_int_distribution<int> dist(low, high);
int port = dist(rng);
return port;
}

inline std::string GetServerAddress() {
int port = GenerateRandomPort(50000, 60000);
std::string address = std::string("localhost:") + std::to_string(port);
return address;
}

namespace xgboost {

class ServerForTest {
Expand All @@ -41,13 +22,14 @@ class ServerForTest {

public:
explicit ServerForTest(std::int32_t world_size) {
server_address_ = GetServerAddress();
server_thread_.reset(new std::thread([this, world_size] {
grpc::ServerBuilder builder;
xgboost::federated::FederatedService service{world_size};
builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials());
int selected_port;
builder.AddListeningPort("localhost:0", grpc::InsecureServerCredentials(), &selected_port);
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
server_address_ = std::string("localhost:") + std::to_string(selected_port);
server_->Wait();
}));
}
Expand All @@ -56,7 +38,14 @@ class ServerForTest {
server_->Shutdown();
server_thread_->join();
}
auto Address() const { return server_address_; }

auto Address() const {
using namespace std::chrono_literals;
while (server_address_.empty()) {
std::this_thread::sleep_for(100ms);
}
return server_address_;
}
};

class BaseFederatedTest : public ::testing::Test {
Expand Down
31 changes: 9 additions & 22 deletions tests/cpp/plugin/test_federated_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,34 +62,24 @@ class FederatedCommunicatorTest : public BaseFederatedTest {
};

TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) {
std::string server_address{GetServerAddress()};
auto construct = [server_address]() {
FederatedCommunicator comm{0, 0, server_address, "", "", ""};
};
auto construct = [] { FederatedCommunicator comm{0, 0, "localhost:0", "", "", ""}; };
EXPECT_THROW(construct(), dmlc::Error);
}

TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooSmall) {
std::string server_address{GetServerAddress()};
auto construct = [server_address]() {
FederatedCommunicator comm{1, -1, server_address, "", "", ""};
};
auto construct = [] { FederatedCommunicator comm{1, -1, "localhost:0", "", "", ""}; };
EXPECT_THROW(construct(), dmlc::Error);
}

TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooBig) {
std::string server_address{GetServerAddress()};
auto construct = [server_address]() {
FederatedCommunicator comm{1, 1, server_address, "", "", ""};
};
auto construct = [] { FederatedCommunicator comm{1, 1, "localhost:0", "", "", ""}; };
EXPECT_THROW(construct(), dmlc::Error);
}

TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) {
std::string server_address{GetServerAddress()};
auto construct = [server_address]() {
auto construct = [] {
Json config{JsonObject()};
config["federated_server_address"] = server_address;
config["federated_server_address"] = std::string("localhost:0");
config["federated_world_size"] = std::string("1");
config["federated_rank"] = Integer(0);
FederatedCommunicator::Create(config);
Expand All @@ -98,10 +88,9 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) {
}

TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) {
std::string server_address{GetServerAddress()};
auto construct = [server_address]() {
auto construct = [] {
Json config{JsonObject()};
config["federated_server_address"] = server_address;
config["federated_server_address"] = std::string("localhost:0");
config["federated_world_size"] = 1;
config["federated_rank"] = std::string("0");
FederatedCommunicator::Create(config);
Expand All @@ -110,15 +99,13 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) {
}

TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) {
std::string server_address{GetServerAddress()};
FederatedCommunicator comm{6, 3, server_address};
FederatedCommunicator comm{6, 3, "localhost:0"};
EXPECT_EQ(comm.GetWorldSize(), 6);
EXPECT_EQ(comm.GetRank(), 3);
}

TEST(FederatedCommunicatorSimpleTest, IsDistributed) {
std::string server_address{GetServerAddress()};
FederatedCommunicator comm{2, 1, server_address};
FederatedCommunicator comm{2, 1, "localhost:0"};
EXPECT_TRUE(comm.IsDistributed());
}

Expand Down