From 8abe3f39e9448e42ab24d4404d18321f181456d8 Mon Sep 17 00:00:00 2001
From: Victoria Tsai
Date: Mon, 10 Jul 2023 13:54:09 -0700
Subject: [PATCH] [core] Retrieve the token from GCS server [4/n] (#37003)

Retrieve the token from the GCS server in the GCS client while connecting, to
attach to metadata in requests.

Previous PR (GCS server): #36535
Next PR (auth): #36073

Signed-off-by: Victoria Tsai
---
 src/mock/ray/gcs/gcs_client/gcs_client.h      |  5 ++-
 src/ray/common/asio/instrumented_io_context.h | 42 ++++++++++++++++++-
 src/ray/common/id.h                           | 21 ----------
 src/ray/gcs/gcs_client/gcs_client.cc          | 34 ++++++++++++++-
 src/ray/gcs/gcs_client/gcs_client.h           | 11 ++++-
 .../gcs/gcs_client/test/gcs_client_test.cc    | 33 ++++++++++++++-
 .../test/usage_stats_client_test.cc           |  4 +-
 src/ray/gcs/gcs_client/usage_stats_client.cc  |  5 ++-
 src/ray/gcs/gcs_client/usage_stats_client.h   |  3 +-
 src/ray/gcs/gcs_server/gcs_server.cc          |  6 ++-
 .../ownership_based_object_directory_test.cc  |  3 +-
 src/ray/rpc/client_call.h                     | 22 ++++++----
 12 files changed, 147 insertions(+), 42 deletions(-)

diff --git a/src/mock/ray/gcs/gcs_client/gcs_client.h b/src/mock/ray/gcs/gcs_client/gcs_client.h
index e7b687d04e7d..cf232a712f51 100644
--- a/src/mock/ray/gcs/gcs_client/gcs_client.h
+++ b/src/mock/ray/gcs/gcs_client/gcs_client.h
@@ -31,7 +31,10 @@ namespace gcs {

 class MockGcsClient : public GcsClient {
  public:
-  MOCK_METHOD(Status, Connect, (instrumented_io_context & io_service), (override));
+  MOCK_METHOD(Status,
+              Connect,
+              (instrumented_io_context & io_service, const ClusterID &cluster_id),
+              (override));
   MOCK_METHOD(void, Disconnect, (), (override));
   MOCK_METHOD((std::pair<std::string, int>), GetGcsServerAddress, (), (const, override));
   MOCK_METHOD(std::string, DebugString, (), (const, override));
diff --git a/src/ray/common/asio/instrumented_io_context.h b/src/ray/common/asio/instrumented_io_context.h
index aff11c36b9b2..4f9ddfe42fd5 100644
--- a/src/ray/common/asio/instrumented_io_context.h
+++ b/src/ray/common/asio/instrumented_io_context.h
@@ -28,7 +28,45 @@ class instrumented_io_context : public boost::asio::io_context {
  public:
   /// Initializes the global stats struct after calling the base constructor.
   /// TODO(ekl) allow taking an externally defined event tracker.
-  instrumented_io_context() : event_stats_(std::make_shared<EventTracker>()) {}
+  instrumented_io_context()
+      : is_running_{false}, event_stats_(std::make_shared<EventTracker>()) {}
+
+  /// Run the io_context if and only if no other thread is running it. Blocks
+  /// other threads from running once started. Noop if there is a thread
+  /// running it already. Only used in GcsClient::Connect, to be deprecated
+  /// after the introduction of executors.
+  void run_if_not_running(std::function<void()> prerun_fn) {
+    absl::MutexLock l(&mu_);
+    // Note: this doesn't set is_running_ because it blocks anything else from
+    // running anyway.
+    if (!is_running_) {
+      prerun_fn();
+      is_running_ = true;
+      boost::asio::io_context::run();
+    }
+  }
+
+  /// Assumes the mutex is held. Undefined behavior if not.
+  void stop_without_lock() {
+    is_running_ = false;
+    boost::asio::io_context::stop();
+  }
+
+  void run() {
+    {
+      absl::MutexLock l(&mu_);
+      is_running_ = true;
+    }
+    boost::asio::io_context::run();
+  }
+
+  void stop() {
+    {
+      absl::MutexLock l(&mu_);
+      is_running_ = false;
+    }
+    boost::asio::io_context::stop();
+  }

   /// A proxy post function that collects count, queueing, and execution statistics for
   /// the given handler.
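The run()/stop()/run_if_not_running()/stop_without_lock() additions above exist so that GcsClient::Connect can drive the event loop itself when no other thread is running it yet, and so the reply callback can stop that loop without re-taking the lock the runner still holds. Below is a minimal, standalone sketch of the same coordination pattern, written against plain Boost.Asio and std::mutex rather than Ray's wrappers; the class name, the posted handler, and main() are illustrative only and assume Boost.Asio is available.

#include <boost/asio.hpp>
#include <functional>
#include <iostream>
#include <mutex>

// Pared-down stand-in for instrumented_io_context: run the loop only when no
// other thread already is, and let a completion handler stop it without
// touching the mutex the runner still holds.
class guarded_io_context : public boost::asio::io_context {
 public:
  void run_if_not_running(std::function<void()> prerun_fn) {
    std::lock_guard<std::mutex> lock(mu_);  // held for the whole run
    if (!is_running_) {
      prerun_fn();
      is_running_ = true;                   // cleared by stop_without_lock()
      boost::asio::io_context::run();       // blocks until stopped
    }
  }

  // Called from inside a handler that run() is executing; must not re-lock.
  void stop_without_lock() {
    is_running_ = false;
    boost::asio::io_context::stop();
  }

 private:
  std::mutex mu_;
  bool is_running_ = false;
};

int main() {
  guarded_io_context io;
  // The posted handler plays the role of the GetClusterId reply callback: it
  // stops the loop so the blocked caller in run_if_not_running() can return.
  boost::asio::post(io, [&io]() {
    std::cout << "reply handled, stopping the loop" << std::endl;
    io.stop_without_lock();
  });
  io.run_if_not_running([] { std::cout << "no other runner, running here" << std::endl; });
  return 0;
}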
@@ -56,6 +94,8 @@ class instrumented_io_context : public boost::asio::io_context {
   EventTracker &stats() const { return *event_stats_; };

  private:
+  absl::Mutex mu_;
+  bool is_running_;
   /// The event stats tracker to use to record asio handler stats to.
   std::shared_ptr<EventTracker> event_stats_;
 };
diff --git a/src/ray/common/id.h b/src/ray/common/id.h
index 7c5f430830ab..a6c753a1de35 100644
--- a/src/ray/common/id.h
+++ b/src/ray/common/id.h
@@ -414,27 +414,6 @@ std::ostream &operator<<(std::ostream &os, const PlacementGroupID &id);
 // Restore the compiler alignment to default (8 bytes).
 #pragma pack(pop)

-struct SafeClusterID {
- private:
-  mutable absl::Mutex m_;
-  ClusterID id_ GUARDED_BY(m_);
-
- public:
-  SafeClusterID(const ClusterID &id) : id_(id) {}
-
-  const ClusterID load() const {
-    absl::MutexLock l(&m_);
-    return id_;
-  }
-
-  ClusterID exchange(const ClusterID &newId) {
-    absl::MutexLock l(&m_);
-    ClusterID old = id_;
-    id_ = newId;
-    return old;
-  }
-};
-
 template <typename T>
 BaseID<T>::BaseID() {
   // Using const_cast to directly change data is dangerous. The cached
diff --git a/src/ray/gcs/gcs_client/gcs_client.cc b/src/ray/gcs/gcs_client/gcs_client.cc
index 250034c028d5..7cc025b039e3 100644
--- a/src/ray/gcs/gcs_client/gcs_client.cc
+++ b/src/ray/gcs/gcs_client/gcs_client.cc
@@ -81,12 +81,44 @@ void GcsSubscriberClient::PubsubCommandBatch(
 GcsClient::GcsClient(const GcsClientOptions &options, UniqueID gcs_client_id)
     : options_(options), gcs_client_id_(gcs_client_id) {}

-Status GcsClient::Connect(instrumented_io_context &io_service) {
+Status GcsClient::Connect(instrumented_io_context &io_service,
+                          const ClusterID &cluster_id) {
   // Connect to gcs service.
   client_call_manager_ = std::make_unique<rpc::ClientCallManager>(io_service);
   gcs_rpc_client_ = std::make_shared<rpc::GcsRpcClient>(
       options_.gcs_address_, options_.gcs_port_, *client_call_manager_);

+  if (cluster_id.IsNil()) {
+    rpc::GetClusterIdReply reply;
+    std::promise<void> cluster_known;
+    std::atomic<bool> stop_io_service{false};
+    gcs_rpc_client_->GetClusterId(
+        rpc::GetClusterIdRequest(),
+        [this, &cluster_known, &stop_io_service, &io_service](
+            const Status &status, const rpc::GetClusterIdReply &reply) {
+          RAY_CHECK(status.ok()) << "Failed to get Cluster ID! Status: " << status;
+          auto cluster_id = ClusterID::FromBinary(reply.cluster_id());
+          RAY_LOG(DEBUG) << "Setting cluster ID to " << cluster_id;
+          client_call_manager_->SetClusterId(cluster_id);
+          if (stop_io_service.load()) {
+            io_service.stop_without_lock();
+          } else {
+            cluster_known.set_value();
+          }
+        });
+    // Run the IO service here to make the above call synchronous.
+    // If it is already running, then wait for our particular callback
+    // to be processed.
+    io_service.run_if_not_running([&stop_io_service]() { stop_io_service.store(true); });
+    if (stop_io_service.load()) {
+      io_service.restart();
+    } else {
+      cluster_known.get_future().get();
+    }
+  } else {
+    client_call_manager_->SetClusterId(cluster_id);
+  }
+
   resubscribe_func_ = [this]() {
     job_accessor_->AsyncResubscribe();
     actor_accessor_->AsyncResubscribe();
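To make the control flow above concrete, here is a hedged usage sketch of the two Connect() paths: with a Nil cluster ID the client fetches the ID over a GetClusterId RPC before returning, and with a known ID it skips the RPC entirely. The GcsClientOptions, GcsClient, and Connect signatures are the ones in this patch; the address, the FromRandom() placeholder, and the wrapper function are illustrative and assume the Ray headers are on the include path.

#include "ray/common/asio/instrumented_io_context.h"
#include "ray/gcs/gcs_client/gcs_client.h"

void ConnectExamples(instrumented_io_context &io_service) {
  ray::gcs::GcsClientOptions options("127.0.0.1:6379");  // illustrative address

  // Path 1: cluster ID unknown. Connect() issues GetClusterId and either runs
  // the io_context itself or waits on the promise until the reply arrives.
  ray::gcs::GcsClient client_without_id(options);
  RAY_CHECK_OK(client_without_id.Connect(io_service));  // cluster_id defaults to Nil

  // Path 2: cluster ID already known (e.g. handed down by the GCS server).
  // No RPC is made; the ID is stamped directly onto the ClientCallManager.
  ray::ClusterID known_id = ray::ClusterID::FromRandom();  // placeholder value
  ray::gcs::GcsClient client_with_id(options);
  RAY_CHECK_OK(client_with_id.Connect(io_service, known_id));
}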
diff --git a/src/ray/gcs/gcs_client/gcs_client.h b/src/ray/gcs/gcs_client/gcs_client.h
index b0d4153fe8da..ea61d64a61b5 100644
--- a/src/ray/gcs/gcs_client/gcs_client.h
+++ b/src/ray/gcs/gcs_client/gcs_client.h
@@ -14,6 +14,7 @@

 #pragma once

+#include <gtest/gtest_prod.h>
 #include
 #include

@@ -35,6 +36,9 @@
 namespace ray {

+class GcsClientTest;
+class GcsClientTest_TestCheckAlive_Test;
+
 namespace gcs {

 /// \class GcsClientOptions
@@ -82,9 +86,11 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this<GcsClient> {
   /// Connect to GCS Service. Non-thread safe.
   /// This function must be called before calling other functions.
   /// \param instrumented_io_context IO execution service.
+  /// \param cluster_id Optional cluster ID to provide to the client.
   ///
   /// \return Status
-  virtual Status Connect(instrumented_io_context &io_service);
+  virtual Status Connect(instrumented_io_context &io_service,
+                         const ClusterID &cluster_id = ClusterID::Nil());

   /// Disconnect with GCS Service. Non-thread safe.
   virtual void Disconnect();
@@ -176,6 +182,9 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this<GcsClient> {
   std::unique_ptr<InternalKVAccessor> internal_kv_accessor_;
   std::unique_ptr<TaskInfoAccessor> task_accessor_;

+  friend class ray::GcsClientTest;
+  FRIEND_TEST(ray::GcsClientTest, TestCheckAlive);
+
  private:
   const UniqueID gcs_client_id_ = UniqueID::FromRandom();

diff --git a/src/ray/gcs/gcs_client/test/gcs_client_test.cc b/src/ray/gcs/gcs_client/test/gcs_client_test.cc
index afbf08b6f7a0..68b3eb3c3b04 100644
--- a/src/ray/gcs/gcs_client/test/gcs_client_test.cc
+++ b/src/ray/gcs/gcs_client/test/gcs_client_test.cc
@@ -115,6 +115,13 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
     rpc::ResetServerCallExecutor();
   }

+  void StampContext(grpc::ClientContext &context) {
+    RAY_CHECK(gcs_client_->client_call_manager_)
+        << "Cannot stamp context before initializing client call manager.";
+    context.AddMetadata(kClusterIdKey,
+                        gcs_client_->client_call_manager_->cluster_id_.Hex());
+  }
+
   void RestartGcsServer() {
     RAY_LOG(INFO) << "Stopping GCS service, port = " << gcs_server_->GetPort();
     gcs_server_->Stop();
@@ -141,11 +148,17 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
         grpc::CreateChannel(absl::StrCat("127.0.0.1:", gcs_server_->GetPort()),
                             grpc::InsecureChannelCredentials());
     auto stub = rpc::NodeInfoGcsService::NewStub(std::move(channel));
+    bool in_memory =
+        RayConfig::instance().gcs_storage() == gcs::GcsServer::kInMemoryStorage;
     grpc::ClientContext context;
+    if (!in_memory) {
+      StampContext(context);
+    }
     context.set_deadline(std::chrono::system_clock::now() + 1s);
     const rpc::CheckAliveRequest request;
     rpc::CheckAliveReply reply;
     auto status = stub->CheckAlive(&context, request, &reply);
+    // If it is in memory, we don't have the new token until we connect again.
     if (!status.ok()) {
       RAY_LOG(WARNING) << "Unable to reach GCS: " << status.error_code() << " "
                        << status.error_message();
@@ -315,8 +328,10 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {

   bool RegisterNode(const rpc::GcsNodeInfo &node_info) {
     std::promise<bool> promise;
-    RAY_CHECK_OK(gcs_client_->Nodes().AsyncRegister(
-        node_info, [&promise](Status status) { promise.set_value(status.ok()); }));
+    RAY_CHECK_OK(gcs_client_->Nodes().AsyncRegister(node_info, [&promise](Status status) {
+      RAY_LOG(INFO) << status;
+      promise.set_value(status.ok());
+    }));
     return WaitReady(promise.get_future(), timeout_ms_);
   }

@@ -463,6 +478,7 @@ TEST_P(GcsClientTest, TestCheckAlive) {
   *(request.mutable_raylet_address()->Add()) = "172.1.2.4:31293";
   {
     grpc::ClientContext context;
+    StampContext(context);
     context.set_deadline(std::chrono::system_clock::now() + 1s);
     rpc::CheckAliveReply reply;
     ASSERT_TRUE(stub->CheckAlive(&context, request, &reply).ok());
@@ -474,6 +490,7 @@ TEST_P(GcsClientTest, TestCheckAlive) {
   ASSERT_TRUE(RegisterNode(*node_info1));
   {
     grpc::ClientContext context;
+    StampContext(context);
     context.set_deadline(std::chrono::system_clock::now() + 1s);
     rpc::CheckAliveReply reply;
     ASSERT_TRUE(stub->CheckAlive(&context, request, &reply).ok());
@@ -987,9 +1004,21 @@ TEST_P(GcsClientTest, TestEvictExpiredDestroyedActors) {
   }
 }

+TEST_P(GcsClientTest, TestGcsAuth) {
+  // Restart GCS.
+  RestartGcsServer();
+  auto node_info = Mocker::GenNodeInfo();
+
+  RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_));
+  EXPECT_TRUE(RegisterNode(*node_info));
+}
+
 TEST_P(GcsClientTest, TestEvictExpiredDeadNodes) {
   // Restart GCS.
   RestartGcsServer();
+  if (RayConfig::instance().gcs_storage() == gcs::GcsServer::kInMemoryStorage) {
+    RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_));
+  }

   // Simulate the scenario of node dead.
   int node_count = RayConfig::instance().maximum_gcs_dead_node_cached_count();
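The StampContext() helper in the tests above is the client-side half of the scheme: the cluster ID travels as ordinary gRPC metadata on the ClientContext, mirroring what the client call path does with cluster_id_ on real requests. A small sketch of just that call, in isolation; the key string is a stand-in, since the actual kClusterIdKey value lives elsewhere in the Ray source and is not shown in this patch.

#include <grpcpp/grpcpp.h>

#include <string>

// Stand-in for Ray's kClusterIdKey constant; placeholder value only.
constexpr char kClusterIdKeyStandIn[] = "cluster_id";

// Attach the (hex-encoded) cluster ID to an outgoing request, the same way
// the StampContext() test helper does.
void StampClusterId(grpc::ClientContext &context, const std::string &cluster_id_hex) {
  context.AddMetadata(kClusterIdKeyStandIn, cluster_id_hex);
}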
diff --git a/src/ray/gcs/gcs_client/test/usage_stats_client_test.cc b/src/ray/gcs/gcs_client/test/usage_stats_client_test.cc
index 6049506c790f..55efa98258f7 100644
--- a/src/ray/gcs/gcs_client/test/usage_stats_client_test.cc
+++ b/src/ray/gcs/gcs_client/test/usage_stats_client_test.cc
@@ -82,7 +82,9 @@ class UsageStatsClientTest : public ::testing::Test {
 TEST_F(UsageStatsClientTest, TestRecordExtraUsageTag) {
   gcs::UsageStatsClient usage_stats_client(
-      "127.0.0.1:" + std::to_string(gcs_server_->GetPort()), *client_io_service_);
+      "127.0.0.1:" + std::to_string(gcs_server_->GetPort()),
+      *client_io_service_,
+      ClusterID::Nil());
   usage_stats_client.RecordExtraUsageTag(usage::TagKey::_TEST1, "value1");
   ASSERT_TRUE(WaitForCondition(
       [this]() {
diff --git a/src/ray/gcs/gcs_client/usage_stats_client.cc b/src/ray/gcs/gcs_client/usage_stats_client.cc
index 4a538d606390..41104b3abdbc 100644
--- a/src/ray/gcs/gcs_client/usage_stats_client.cc
+++ b/src/ray/gcs/gcs_client/usage_stats_client.cc
@@ -17,10 +17,11 @@ namespace ray {
 namespace gcs {

 UsageStatsClient::UsageStatsClient(const std::string &gcs_address,
-                                   instrumented_io_context &io_service) {
+                                   instrumented_io_context &io_service,
+                                   const ClusterID &cluster_id) {
   GcsClientOptions options(gcs_address);
   gcs_client_ = std::make_unique<GcsClient>(options);
-  RAY_CHECK_OK(gcs_client_->Connect(io_service));
+  RAY_CHECK_OK(gcs_client_->Connect(io_service, cluster_id));
 }

 void UsageStatsClient::RecordExtraUsageTag(usage::TagKey key, const std::string &value) {
diff --git a/src/ray/gcs/gcs_client/usage_stats_client.h b/src/ray/gcs/gcs_client/usage_stats_client.h
index ccda6d247bec..4b825f1c6530 100644
--- a/src/ray/gcs/gcs_client/usage_stats_client.h
+++ b/src/ray/gcs/gcs_client/usage_stats_client.h
@@ -24,7 +24,8 @@ namespace gcs {
 class UsageStatsClient {
  public:
   explicit UsageStatsClient(const std::string &gcs_address,
-                            instrumented_io_context &io_service);
+                            instrumented_io_context &io_service,
+                            const ClusterID &cluster_id);

   /// C++ version of record_extra_usage_tag in usage_lib.py
   ///
diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc
index fbbec8df6daf..66ed3575775b 100644
--- a/src/ray/gcs/gcs_server/gcs_server.cc
+++ b/src/ray/gcs/gcs_server/gcs_server.cc
@@ -566,8 +566,10 @@ void GcsServer::InitFunctionManager() {
 }

 void GcsServer::InitUsageStatsClient() {
-  usage_stats_client_ = std::make_unique<UsageStatsClient>(
-      "127.0.0.1:" + std::to_string(GetPort()), main_service_);
+  usage_stats_client_ =
+      std::make_unique<UsageStatsClient>("127.0.0.1:" + std::to_string(GetPort()),
+                                         main_service_,
+                                         rpc_server_.GetClusterId());
 }

 void GcsServer::InitKVManager() {
diff --git a/src/ray/object_manager/test/ownership_based_object_directory_test.cc b/src/ray/object_manager/test/ownership_based_object_directory_test.cc
index 3b7624a604a3..a326b2c9eb65 100644
--- a/src/ray/object_manager/test/ownership_based_object_directory_test.cc
+++ b/src/ray/object_manager/test/ownership_based_object_directory_test.cc
@@ -101,7 +101,8 @@ class MockGcsClient : public gcs::GcsClient {
     return *node_accessor_;
   }

-  MOCK_METHOD1(Connect, Status(instrumented_io_context &io_service));
+  MOCK_METHOD2(Connect,
+               Status(instrumented_io_context &io_service, const ClusterID &cluster_id));
   MOCK_METHOD0(Disconnect, void());
 };
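The gcs_server.cc change above shows the other direction: a component that already holds the cluster ID (the server, via rpc_server_.GetClusterId()) threads it straight through to the clients it creates, so their Connect() takes the no-RPC branch. A hedged sketch of that constructor usage; the wrapper function, port parameter, and address are illustrative, while the three-argument UsageStatsClient constructor is the one introduced in this patch.

#include <string>

#include "ray/common/asio/instrumented_io_context.h"
#include "ray/gcs/gcs_client/usage_stats_client.h"

// Construct a UsageStatsClient with a cluster ID that is already known, so the
// embedded GcsClient skips the GetClusterId round trip entirely.
void MakeUsageStatsClient(instrumented_io_context &io_service,
                          const ray::ClusterID &known_cluster_id,
                          int gcs_port) {
  ray::gcs::UsageStatsClient usage_stats_client(
      "127.0.0.1:" + std::to_string(gcs_port), io_service, known_cluster_id);
}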
#include "ray/util/util.h" namespace ray { + +class GcsClientTest; +class GcsClientTest_TestCheckAlive_Test; namespace rpc { /// Represents an outgoing gRPC request. @@ -145,10 +148,10 @@ class ClientCallImpl : public ClientCall { /// The lifecycle of a `ClientCallTag` is as follows. /// /// When a client submits a new gRPC request, a new `ClientCallTag` object will be created -/// by `ClientCallMangager::CreateCall`. Then the object will be used as the tag of +/// by `ClientCallManager::CreateCall`. Then the object will be used as the tag of /// `CompletionQueue`. /// -/// When the reply is received, `ClientCallMangager` will get the address of this object +/// When the reply is received, `ClientCallManager` will get the address of this object /// via `CompletionQueue`'s tag. And the manager should call /// `GetCall()->OnReplyReceived()` and then delete this object. class ClientCallTag { @@ -194,7 +197,7 @@ class ClientCallManager { const ClusterID &cluster_id = ClusterID::Nil(), int num_threads = 1, int64_t call_timeout_ms = -1) - : cluster_id_(ClusterID::Nil()), + : cluster_id_(cluster_id), main_service_(main_service), num_threads_(num_threads), shutdown_(false), @@ -249,7 +252,7 @@ class ClientCallManager { } auto call = std::make_shared>( - callback, cluster_id_.load(), std::move(stats_handle), method_timeout_ms); + callback, cluster_id_, std::move(stats_handle), method_timeout_ms); // Send request. // Find the next completion queue to wait for response. call->response_reader_ = (stub.*prepare_async_function)( @@ -268,16 +271,19 @@ class ClientCallManager { } void SetClusterId(const ClusterID &cluster_id) { - auto old_id = cluster_id_.exchange(ClusterID::Nil()); - if (!old_id.IsNil() && (old_id != cluster_id)) { + if (!cluster_id_.IsNil() && (cluster_id_ != cluster_id)) { RAY_LOG(FATAL) << "Expected cluster ID to be Nil or " << cluster_id << ", but got" - << old_id; + << cluster_id_; } + cluster_id_ = cluster_id; } /// Get the main service of this rpc. instrumented_io_context &GetMainService() { return main_service_; } + friend class ray::GcsClientTest; + FRIEND_TEST(ray::GcsClientTest, TestCheckAlive); + private: /// This function runs in a background thread. It keeps polling events from the /// `CompletionQueue`, and dispatches the event to the callbacks via the `ClientCall` @@ -328,7 +334,7 @@ class ClientCallManager { /// UUID of the cluster. Potential race between creating a ClientCall object /// and setting the cluster ID. - SafeClusterID cluster_id_; + ClusterID cluster_id_; /// The main event loop, to which the callback functions will be posted. instrumented_io_context &main_service_;