Skip to content

Commit

Permalink
[core] Retrieve the token from GCS server [4/n] (#37003)
Browse files Browse the repository at this point in the history
Retrieve the token from the GCS server in the GCS client while connecting, to attach to metadata in requests.

Previous PR (GCS server): #36535
Next PR (auth): #36073
  • Loading branch information
vitsai authored Jul 10, 2023
1 parent 1334be9 commit 4d8c2c9
Show file tree
Hide file tree
Showing 12 changed files with 147 additions and 42 deletions.
5 changes: 4 additions & 1 deletion src/mock/ray/gcs/gcs_client/gcs_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ namespace gcs {

class MockGcsClient : public GcsClient {
public:
MOCK_METHOD(Status, Connect, (instrumented_io_context & io_service), (override));
MOCK_METHOD(Status,
Connect,
(instrumented_io_context & io_service, const ClusterID &cluster_id),
(override));
MOCK_METHOD(void, Disconnect, (), (override));
MOCK_METHOD((std::pair<std::string, int>), GetGcsServerAddress, (), (const, override));
MOCK_METHOD(std::string, DebugString, (), (const, override));
Expand Down
42 changes: 41 additions & 1 deletion src/ray/common/asio/instrumented_io_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,45 @@ class instrumented_io_context : public boost::asio::io_context {
public:
/// Initializes the global stats struct after calling the base constructor.
/// TODO(ekl) allow taking an externally defined event tracker.
instrumented_io_context() : event_stats_(std::make_shared<EventTracker>()) {}
instrumented_io_context()
: is_running_{false}, event_stats_(std::make_shared<EventTracker>()) {}

/// Run the io_context on the calling thread if and only if is_running_ is
/// false, i.e. no other thread has been recorded as running it. Noop if it
/// is already marked as running. Only used in GcsClient::Connect, to be
/// deprecated after the introduction of executors.
///
/// mu_ is held for the entire call, including the blocking
/// boost::asio::io_context::run(), so other threads calling run(), stop() or
/// run_if_not_running() on this object block until this function returns.
/// A handler executed by this run that needs to stop the io_context must use
/// stop_without_lock(); calling stop() from such a handler would try to
/// acquire mu_ again on the same thread.
///
/// \param prerun_fn Invoked with mu_ held, right before the io_context starts
/// running. Not invoked when this call is a noop.
void run_if_not_running(std::function<void()> prerun_fn) {
absl::MutexLock l(&mu_);
// Note: mu_ stays locked until run() below returns; that lock (rather than
// the is_running_ flag alone) is what blocks anything else from running or
// stopping the io_context while this call is in progress.
if (!is_running_) {
prerun_fn();
is_running_ = true;
boost::asio::io_context::run();
// NOTE(review): is_running_ is not cleared here; only stop() and
// stop_without_lock() reset it. Confirm this is intended for the case where
// run() returns because the io_context ran out of work.
}
}

/// Clear is_running_ and stop the io_context without acquiring mu_.
/// Assumes the mutex is held. Undefined behavior if not.
/// Meant for handlers that execute inside run_if_not_running(), which keeps
/// mu_ locked on the running thread for the whole duration of the run.
void stop_without_lock() {
is_running_ = false;
boost::asio::io_context::stop();
}

/// Mark the io_context as running, then run it on the calling thread.
/// mu_ is held only while is_running_ is updated and is released before the
/// blocking boost::asio::io_context::run() call.
void run() {
{
// Scoped so the lock is dropped before the blocking run() below.
absl::MutexLock l(&mu_);
is_running_ = true;
}
boost::asio::io_context::run();
}

/// Mark the io_context as not running, then stop it.
/// Acquires mu_ to update is_running_, so this blocks while another thread is
/// inside run_if_not_running() (which holds mu_ for its whole run). Handlers
/// executing under run_if_not_running() should call stop_without_lock()
/// instead.
void stop() {
{
// Scoped so the lock is released before stopping the underlying io_context.
absl::MutexLock l(&mu_);
is_running_ = false;
}
boost::asio::io_context::stop();
}

/// A proxy post function that collects count, queueing, and execution statistics for
/// the given handler.
Expand Down Expand Up @@ -56,6 +94,8 @@ class instrumented_io_context : public boost::asio::io_context {
EventTracker &stats() const { return *event_stats_; };

private:
absl::Mutex mu_;
bool is_running_;
/// The event stats tracker to use to record asio handler stats to.
std::shared_ptr<EventTracker> event_stats_;
};
21 changes: 0 additions & 21 deletions src/ray/common/id.h
Original file line number Diff line number Diff line change
Expand Up @@ -414,27 +414,6 @@ std::ostream &operator<<(std::ostream &os, const PlacementGroupID &id);
// Restore the compiler alignment to default (8 bytes).
#pragma pack(pop)

/// Holder for a ClusterID whose reads and writes are serialized by an
/// internal mutex.
struct SafeClusterID {
 public:
  /// Store an initial copy of \p id.
  SafeClusterID(const ClusterID &id) : id_(id) {}

  /// Return a copy of the currently stored ID, taken under the lock.
  const ClusterID load() const {
    absl::MutexLock lock(&m_);
    return id_;
  }

  /// Replace the stored ID with \p newId under the lock.
  ///
  /// \return The ID that was stored before the replacement.
  ClusterID exchange(const ClusterID &newId) {
    absl::MutexLock lock(&m_);
    ClusterID previous = id_;
    id_ = newId;
    return previous;
  }

 private:
  /// Mutable so that load() can lock it from a const context.
  mutable absl::Mutex m_;
  ClusterID id_ GUARDED_BY(m_);
};

template <typename T>
BaseID<T>::BaseID() {
// Using const_cast to directly change data is dangerous. The cached
Expand Down
34 changes: 33 additions & 1 deletion src/ray/gcs/gcs_client/gcs_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,44 @@ void GcsSubscriberClient::PubsubCommandBatch(
GcsClient::GcsClient(const GcsClientOptions &options, UniqueID gcs_client_id)
: options_(options), gcs_client_id_(gcs_client_id) {}

Status GcsClient::Connect(instrumented_io_context &io_service) {
Status GcsClient::Connect(instrumented_io_context &io_service,
const ClusterID &cluster_id) {
// Connect to gcs service.
client_call_manager_ = std::make_unique<rpc::ClientCallManager>(io_service);
gcs_rpc_client_ = std::make_shared<rpc::GcsRpcClient>(
options_.gcs_address_, options_.gcs_port_, *client_call_manager_);

if (cluster_id.IsNil()) {
rpc::GetClusterIdReply reply;
std::promise<void> cluster_known;
std::atomic<bool> stop_io_service{false};
gcs_rpc_client_->GetClusterId(
rpc::GetClusterIdRequest(),
[this, &cluster_known, &stop_io_service, &io_service](
const Status &status, const rpc::GetClusterIdReply &reply) {
RAY_CHECK(status.ok()) << "Failed to get Cluster ID! Status: " << status;
auto cluster_id = ClusterID::FromBinary(reply.cluster_id());
RAY_LOG(DEBUG) << "Setting cluster ID to " << cluster_id;
client_call_manager_->SetClusterId(cluster_id);
if (stop_io_service.load()) {
io_service.stop_without_lock();
} else {
cluster_known.set_value();
}
});
// Run the IO service here to make the above call synchronous.
// If it is already running, then wait for our particular callback
// to be processed.
io_service.run_if_not_running([&stop_io_service]() { stop_io_service.store(true); });
if (stop_io_service.load()) {
io_service.restart();
} else {
cluster_known.get_future().get();
}
} else {
client_call_manager_->SetClusterId(cluster_id);
}

resubscribe_func_ = [this]() {
job_accessor_->AsyncResubscribe();
actor_accessor_->AsyncResubscribe();
Expand Down
11 changes: 10 additions & 1 deletion src/ray/gcs/gcs_client/gcs_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include <gtest/gtest.h>
#include <gtest/gtest_prod.h>

#include <boost/asio.hpp>
Expand All @@ -35,6 +36,9 @@

namespace ray {

class GcsClientTest;
class GcsClientTest_TestCheckAlive_Test;

namespace gcs {

/// \class GcsClientOptions
Expand Down Expand Up @@ -82,9 +86,11 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this<GcsClient> {
/// Connect to GCS Service. Non-thread safe.
/// This function must be called before calling other functions.
/// \param instrumented_io_context IO execution service.
/// \param cluster_id Optional cluster ID to provide to the client.
///
/// \return Status
virtual Status Connect(instrumented_io_context &io_service);
virtual Status Connect(instrumented_io_context &io_service,
const ClusterID &cluster_id = ClusterID::Nil());

/// Disconnect with GCS Service. Non-thread safe.
virtual void Disconnect();
Expand Down Expand Up @@ -176,6 +182,9 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this<GcsClient> {
std::unique_ptr<InternalKVAccessor> internal_kv_accessor_;
std::unique_ptr<TaskInfoAccessor> task_accessor_;

friend class ray::GcsClientTest;
FRIEND_TEST(ray::GcsClientTest, TestCheckAlive);

private:
const UniqueID gcs_client_id_ = UniqueID::FromRandom();

Expand Down
33 changes: 31 additions & 2 deletions src/ray/gcs/gcs_client/test/gcs_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
rpc::ResetServerCallExecutor();
}

/// Attach the client's cluster ID (hex-encoded) to \p context as gRPC
/// metadata under kClusterIdKey. Aborts via RAY_CHECK if the GCS client has
/// no client call manager yet.
void StampContext(grpc::ClientContext &context) {
  const auto &call_manager = gcs_client_->client_call_manager_;
  RAY_CHECK(call_manager)
      << "Cannot stamp context before initializing client call manager.";
  const auto cluster_id_hex = call_manager->cluster_id_.Hex();
  context.AddMetadata(kClusterIdKey, cluster_id_hex);
}

void RestartGcsServer() {
RAY_LOG(INFO) << "Stopping GCS service, port = " << gcs_server_->GetPort();
gcs_server_->Stop();
Expand All @@ -141,11 +148,17 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
grpc::CreateChannel(absl::StrCat("127.0.0.1:", gcs_server_->GetPort()),
grpc::InsecureChannelCredentials());
auto stub = rpc::NodeInfoGcsService::NewStub(std::move(channel));
bool in_memory =
RayConfig::instance().gcs_storage() == gcs::GcsServer::kInMemoryStorage;
grpc::ClientContext context;
if (!in_memory) {
StampContext(context);
}
context.set_deadline(std::chrono::system_clock::now() + 1s);
const rpc::CheckAliveRequest request;
rpc::CheckAliveReply reply;
auto status = stub->CheckAlive(&context, request, &reply);
// If it is in memory, we don't have the new token until we connect again.
if (!status.ok()) {
RAY_LOG(WARNING) << "Unable to reach GCS: " << status.error_code() << " "
<< status.error_message();
Expand Down Expand Up @@ -315,8 +328,10 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {

bool RegisterNode(const rpc::GcsNodeInfo &node_info) {
std::promise<bool> promise;
RAY_CHECK_OK(gcs_client_->Nodes().AsyncRegister(
node_info, [&promise](Status status) { promise.set_value(status.ok()); }));
RAY_CHECK_OK(gcs_client_->Nodes().AsyncRegister(node_info, [&promise](Status status) {
RAY_LOG(INFO) << status;
promise.set_value(status.ok());
}));
return WaitReady(promise.get_future(), timeout_ms_);
}

Expand Down Expand Up @@ -463,6 +478,7 @@ TEST_P(GcsClientTest, TestCheckAlive) {
*(request.mutable_raylet_address()->Add()) = "172.1.2.4:31293";
{
grpc::ClientContext context;
StampContext(context);
context.set_deadline(std::chrono::system_clock::now() + 1s);
rpc::CheckAliveReply reply;
ASSERT_TRUE(stub->CheckAlive(&context, request, &reply).ok());
Expand All @@ -474,6 +490,7 @@ TEST_P(GcsClientTest, TestCheckAlive) {
ASSERT_TRUE(RegisterNode(*node_info1));
{
grpc::ClientContext context;
StampContext(context);
context.set_deadline(std::chrono::system_clock::now() + 1s);
rpc::CheckAliveReply reply;
ASSERT_TRUE(stub->CheckAlive(&context, request, &reply).ok());
Expand Down Expand Up @@ -987,9 +1004,21 @@ TEST_P(GcsClientTest, TestEvictExpiredDestroyedActors) {
}
}

// Restarts the GCS server, reconnects the client with the default (Nil)
// cluster ID, and expects a node registration to succeed afterwards.
TEST_P(GcsClientTest, TestGcsAuth) {
  RestartGcsServer();
  auto generated_node = Mocker::GenNodeInfo();

  // Connect again after the restart before issuing any request.
  RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_));

  const bool registered = RegisterNode(*generated_node);
  EXPECT_TRUE(registered);
}

TEST_P(GcsClientTest, TestEvictExpiredDeadNodes) {
// Restart GCS.
RestartGcsServer();
if (RayConfig::instance().gcs_storage() == gcs::GcsServer::kInMemoryStorage) {
RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_));
}

// Simulate the scenario of node dead.
int node_count = RayConfig::instance().maximum_gcs_dead_node_cached_count();
Expand Down
4 changes: 3 additions & 1 deletion src/ray/gcs/gcs_client/test/usage_stats_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ class UsageStatsClientTest : public ::testing::Test {

TEST_F(UsageStatsClientTest, TestRecordExtraUsageTag) {
gcs::UsageStatsClient usage_stats_client(
"127.0.0.1:" + std::to_string(gcs_server_->GetPort()), *client_io_service_);
"127.0.0.1:" + std::to_string(gcs_server_->GetPort()),
*client_io_service_,
ClusterID::Nil());
usage_stats_client.RecordExtraUsageTag(usage::TagKey::_TEST1, "value1");
ASSERT_TRUE(WaitForCondition(
[this]() {
Expand Down
5 changes: 3 additions & 2 deletions src/ray/gcs/gcs_client/usage_stats_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
namespace ray {
namespace gcs {
UsageStatsClient::UsageStatsClient(const std::string &gcs_address,
instrumented_io_context &io_service) {
instrumented_io_context &io_service,
const ClusterID &cluster_id) {
GcsClientOptions options(gcs_address);
gcs_client_ = std::make_unique<GcsClient>(options);
RAY_CHECK_OK(gcs_client_->Connect(io_service));
RAY_CHECK_OK(gcs_client_->Connect(io_service, cluster_id));
}

void UsageStatsClient::RecordExtraUsageTag(usage::TagKey key, const std::string &value) {
Expand Down
3 changes: 2 additions & 1 deletion src/ray/gcs/gcs_client/usage_stats_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ namespace gcs {
class UsageStatsClient {
public:
explicit UsageStatsClient(const std::string &gcs_address,
instrumented_io_context &io_service);
instrumented_io_context &io_service,
const ClusterID &cluster_id);

/// C++ version of record_extra_usage_tag in usage_lib.py
///
Expand Down
6 changes: 4 additions & 2 deletions src/ray/gcs/gcs_server/gcs_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,10 @@ void GcsServer::InitFunctionManager() {
}

void GcsServer::InitUsageStatsClient() {
usage_stats_client_ = std::make_unique<UsageStatsClient>(
"127.0.0.1:" + std::to_string(GetPort()), main_service_);
usage_stats_client_ =
std::make_unique<UsageStatsClient>("127.0.0.1:" + std::to_string(GetPort()),
main_service_,
rpc_server_.GetClusterId());
}

void GcsServer::InitKVManager() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ class MockGcsClient : public gcs::GcsClient {
return *node_accessor_;
}

MOCK_METHOD1(Connect, Status(instrumented_io_context &io_service));
MOCK_METHOD2(Connect,
Status(instrumented_io_context &io_service, const ClusterID &cluster_id));

MOCK_METHOD0(Disconnect, void());
};
Expand Down
22 changes: 14 additions & 8 deletions src/ray/rpc/client_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
#include "ray/util/util.h"

namespace ray {

class GcsClientTest;
class GcsClientTest_TestCheckAlive_Test;
namespace rpc {

/// Represents an outgoing gRPC request.
Expand Down Expand Up @@ -145,10 +148,10 @@ class ClientCallImpl : public ClientCall {
/// The lifecycle of a `ClientCallTag` is as follows.
///
/// When a client submits a new gRPC request, a new `ClientCallTag` object will be created
/// by `ClientCallMangager::CreateCall`. Then the object will be used as the tag of
/// by `ClientCallManager::CreateCall`. Then the object will be used as the tag of
/// `CompletionQueue`.
///
/// When the reply is received, `ClientCallMangager` will get the address of this object
/// When the reply is received, `ClientCallManager` will get the address of this object
/// via `CompletionQueue`'s tag. And the manager should call
/// `GetCall()->OnReplyReceived()` and then delete this object.
class ClientCallTag {
Expand Down Expand Up @@ -194,7 +197,7 @@ class ClientCallManager {
const ClusterID &cluster_id = ClusterID::Nil(),
int num_threads = 1,
int64_t call_timeout_ms = -1)
: cluster_id_(ClusterID::Nil()),
: cluster_id_(cluster_id),
main_service_(main_service),
num_threads_(num_threads),
shutdown_(false),
Expand Down Expand Up @@ -249,7 +252,7 @@ class ClientCallManager {
}

auto call = std::make_shared<ClientCallImpl<Reply>>(
callback, cluster_id_.load(), std::move(stats_handle), method_timeout_ms);
callback, cluster_id_, std::move(stats_handle), method_timeout_ms);
// Send request.
// Find the next completion queue to wait for response.
call->response_reader_ = (stub.*prepare_async_function)(
Expand All @@ -268,16 +271,19 @@ class ClientCallManager {
}

void SetClusterId(const ClusterID &cluster_id) {
auto old_id = cluster_id_.exchange(ClusterID::Nil());
if (!old_id.IsNil() && (old_id != cluster_id)) {
if (!cluster_id_.IsNil() && (cluster_id_ != cluster_id)) {
RAY_LOG(FATAL) << "Expected cluster ID to be Nil or " << cluster_id << ", but got"
<< old_id;
<< cluster_id_;
}
cluster_id_ = cluster_id;
}

/// Get the main service of this rpc.
instrumented_io_context &GetMainService() { return main_service_; }

friend class ray::GcsClientTest;
FRIEND_TEST(ray::GcsClientTest, TestCheckAlive);

private:
/// This function runs in a background thread. It keeps polling events from the
/// `CompletionQueue`, and dispatches the event to the callbacks via the `ClientCall`
Expand Down Expand Up @@ -328,7 +334,7 @@ class ClientCallManager {

/// UUID of the cluster. Potential race between creating a ClientCall object
/// and setting the cluster ID.
SafeClusterID cluster_id_;
ClusterID cluster_id_;

/// The main event loop, to which the callback functions will be posted.
instrumented_io_context &main_service_;
Expand Down

0 comments on commit 4d8c2c9

Please sign in to comment.