From 511d4996b56424b51e9fef4ca6e7e71f48ca237b Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 26 Apr 2023 18:48:26 -0700 Subject: [PATCH] Rely on gRPC to generate random port (#9102) --- tests/cpp/plugin/helpers.h | 33 +++++++------------ .../cpp/plugin/test_federated_communicator.cc | 31 +++++------------ 2 files changed, 20 insertions(+), 44 deletions(-) diff --git a/tests/cpp/plugin/helpers.h b/tests/cpp/plugin/helpers.h index 41e5a63e553f..0dbdeeca416f 100644 --- a/tests/cpp/plugin/helpers.h +++ b/tests/cpp/plugin/helpers.h @@ -13,25 +13,6 @@ #include "../../../plugin/federated/federated_server.h" #include "../../../src/collective/communicator-inl.h" -inline int GenerateRandomPort(int low, int high) { - using namespace std::chrono_literals; - // Ensure unique timestamp by introducing a small artificial delay - std::this_thread::sleep_for(100ms); - auto timestamp = static_cast(std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - std::mt19937_64 rng(timestamp); - std::uniform_int_distribution dist(low, high); - int port = dist(rng); - return port; -} - -inline std::string GetServerAddress() { - int port = GenerateRandomPort(50000, 60000); - std::string address = std::string("localhost:") + std::to_string(port); - return address; -} - namespace xgboost { class ServerForTest { @@ -41,13 +22,14 @@ class ServerForTest { public: explicit ServerForTest(std::int32_t world_size) { - server_address_ = GetServerAddress(); server_thread_.reset(new std::thread([this, world_size] { grpc::ServerBuilder builder; xgboost::federated::FederatedService service{world_size}; - builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials()); + int selected_port; + builder.AddListeningPort("localhost:0", grpc::InsecureServerCredentials(), &selected_port); builder.RegisterService(&service); server_ = builder.BuildAndStart(); + server_address_ = std::string("localhost:") + std::to_string(selected_port); server_->Wait(); })); } @@ -56,7 +38,14 @@ class ServerForTest { server_->Shutdown(); server_thread_->join(); } - auto Address() const { return server_address_; } + + auto Address() const { + using namespace std::chrono_literals; + while (server_address_.empty()) { + std::this_thread::sleep_for(100ms); + } + return server_address_; + } }; class BaseFederatedTest : public ::testing::Test { diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc index 340849606256..62f33d5ee29a 100644 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -62,34 +62,24 @@ class FederatedCommunicatorTest : public BaseFederatedTest { }; TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) { - std::string server_address{GetServerAddress()}; - auto construct = [server_address]() { - FederatedCommunicator comm{0, 0, server_address, "", "", ""}; - }; + auto construct = [] { FederatedCommunicator comm{0, 0, "localhost:0", "", "", ""}; }; EXPECT_THROW(construct(), dmlc::Error); } TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooSmall) { - std::string server_address{GetServerAddress()}; - auto construct = [server_address]() { - FederatedCommunicator comm{1, -1, server_address, "", "", ""}; - }; + auto construct = [] { FederatedCommunicator comm{1, -1, "localhost:0", "", "", ""}; }; EXPECT_THROW(construct(), dmlc::Error); } TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooBig) { - std::string server_address{GetServerAddress()}; - auto construct = [server_address]() { - FederatedCommunicator comm{1, 1, server_address, "", "", ""}; - }; + auto construct = [] { FederatedCommunicator comm{1, 1, "localhost:0", "", "", ""}; }; EXPECT_THROW(construct(), dmlc::Error); } TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) { - std::string server_address{GetServerAddress()}; - auto construct = [server_address]() { + auto construct = [] { Json config{JsonObject()}; - config["federated_server_address"] = server_address; + config["federated_server_address"] = std::string("localhost:0"); config["federated_world_size"] = std::string("1"); config["federated_rank"] = Integer(0); FederatedCommunicator::Create(config); @@ -98,10 +88,9 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) { } TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) { - std::string server_address{GetServerAddress()}; - auto construct = [server_address]() { + auto construct = [] { Json config{JsonObject()}; - config["federated_server_address"] = server_address; + config["federated_server_address"] = std::string("localhost:0"); config["federated_world_size"] = 1; config["federated_rank"] = std::string("0"); FederatedCommunicator::Create(config); @@ -110,15 +99,13 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) { } TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) { - std::string server_address{GetServerAddress()}; - FederatedCommunicator comm{6, 3, server_address}; + FederatedCommunicator comm{6, 3, "localhost:0"}; EXPECT_EQ(comm.GetWorldSize(), 6); EXPECT_EQ(comm.GetRank(), 3); } TEST(FederatedCommunicatorSimpleTest, IsDistributed) { - std::string server_address{GetServerAddress()}; - FederatedCommunicator comm{2, 1, server_address}; + FederatedCommunicator comm{2, 1, "localhost:0"}; EXPECT_TRUE(comm.IsDistributed()); }