Skip to content

Commit

Permalink
squash for rebase
Browse files Browse the repository at this point in the history
[core] Add ClusterID token to GRPC server [1/n] (ray-project#36517)

First of a stack of changes to plumb through token exchange between GCS client and server. This adds a ClusterID token that can be passed to a GRPC server, which then initializes each component GRPC service with the token by passing to the ServerCallFactory objects when they are set up. When the factories create ServerCall objects for the GRPC service completion queue, this token is also passed to the ServerCall to check against inbound request metadata. The actual authentication check does not take place in this PR.

Note: This change also minorly cleans up some code in GCS server (changes a string check to use an enum).

Next change (client-side analogue): ray-project#36526

[core] Generate GCS server token

Signed-off-by: vitsai <victoria@anyscale.com>

Add client-side logic for setting cluster ID.

Signed-off-by: vitsai <victoria@anyscale.com>

bug fixes

Signed-off-by: vitsai <victoria@anyscale.com>

comments

Signed-off-by: vitsai <victoria@anyscale.com>

bug workaround

Signed-off-by: vitsai <victoria@anyscale.com>

Fix windows build

Signed-off-by: vitsai <victoria@anyscale.com>

fix bug

Signed-off-by: vitsai <victoria@anyscale.com>

remove auth stuff from this pr

Signed-off-by: vitsai <victoria@anyscale.com>

fix mock build

Signed-off-by: vitsai <victoria@anyscale.com>

comments

Signed-off-by: vitsai <victoria@anyscale.com>

remove future

Signed-off-by: vitsai <victoria@anyscale.com>

Remove top-level changes

Signed-off-by: vitsai <victoria@anyscale.com>

comments

Signed-off-by: vitsai <victoria@anyscale.com>

Peel back everything that's not grpc-layer changes

Signed-off-by: vitsai <victoria@anyscale.com>

Change atomic to mutex

Signed-off-by: vitsai <victoria@anyscale.com>

Fix alignment of SafeClusterID

Signed-off-by: vitsai <victoria@anyscale.com>

comments

Signed-off-by: vitsai <victoria@anyscale.com>

Add back everything in GCS server except RPC definition

Signed-off-by: vitsai <victoria@anyscale.com>

fix bug

Signed-off-by: vitsai <victoria@anyscale.com>

comments

Signed-off-by: vitsai <victoria@anyscale.com>

comments

Signed-off-by: vitsai <victoria@anyscale.com>

Add client-side stuff

Signed-off-by: vitsai <victoria@anyscale.com>

hack workaround to simulate async direct dispatch

love when things hang

Signed-off-by: vitsai <victoria@anyscale.com>
  • Loading branch information
vitsai committed Jul 6, 2023
1 parent a1d59fd commit 1c43c6a
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 13 deletions.
5 changes: 4 additions & 1 deletion src/mock/ray/gcs/gcs_client/gcs_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ namespace gcs {

class MockGcsClient : public GcsClient {
public:
MOCK_METHOD(Status, Connect, (instrumented_io_context & io_service), (override));
MOCK_METHOD(Status,
Connect,
(instrumented_io_context & io_service, const ClusterID &cluster_id),
(override));
MOCK_METHOD(void, Disconnect, (), (override));
MOCK_METHOD((std::pair<std::string, int>), GetGcsServerAddress, (), (const, override));
MOCK_METHOD(std::string, DebugString, (), (const, override));
Expand Down
21 changes: 20 additions & 1 deletion src/ray/gcs/gcs_client/gcs_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,31 @@ void GcsSubscriberClient::PubsubCommandBatch(
GcsClient::GcsClient(const GcsClientOptions &options, UniqueID gcs_client_id)
: options_(options), gcs_client_id_(gcs_client_id) {}

Status GcsClient::Connect(instrumented_io_context &io_service) {
Status GcsClient::Connect(instrumented_io_context &io_service,
const ClusterID &cluster_id) {
// Connect to gcs service.
client_call_manager_ = std::make_unique<rpc::ClientCallManager>(io_service);
gcs_rpc_client_ = std::make_shared<rpc::GcsRpcClient>(
options_.gcs_address_, options_.gcs_port_, *client_call_manager_);

if (cluster_id.IsNil()) {
rpc::GetClusterIdReply reply;
gcs_rpc_client_->GetClusterId(
rpc::GetClusterIdRequest(),
[this, &io_service](const Status &status, const rpc::GetClusterIdReply &reply) {
RAY_CHECK(status.ok()) << "Failed to get Cluster ID! Status: " << status;
auto cluster_id = ClusterID::FromBinary(reply.cluster_id());
RAY_LOG(DEBUG) << "Setting cluster ID to " << cluster_id;
client_call_manager_->SetClusterId(cluster_id);
io_service.stop();
});
// Run the IO service here to make the above call synchronous.
io_service.run();
io_service.restart();
} else {
client_call_manager_->SetClusterId(cluster_id);
}

resubscribe_func_ = [this]() {
job_accessor_->AsyncResubscribe();
actor_accessor_->AsyncResubscribe();
Expand Down
11 changes: 10 additions & 1 deletion src/ray/gcs/gcs_client/gcs_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include <gtest/gtest.h>
#include <gtest/gtest_prod.h>

#include <boost/asio.hpp>
Expand All @@ -35,6 +36,9 @@

namespace ray {

class GcsClientTest;
class GcsClientTest_TestCheckAlive_Test;

namespace gcs {

/// \class GcsClientOptions
Expand Down Expand Up @@ -82,9 +86,11 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this<GcsClient> {
/// Connect to GCS Service. Non-thread safe.
/// This function must be called before calling other functions.
/// \param instrumented_io_context IO execution service.
/// \param cluster_id Optional cluster ID to provide to the client.
///
/// \return Status
virtual Status Connect(instrumented_io_context &io_service);
virtual Status Connect(instrumented_io_context &io_service,
const ClusterID &cluster_id = ClusterID::Nil());

/// Disconnect with GCS Service. Non-thread safe.
virtual void Disconnect();
Expand Down Expand Up @@ -176,6 +182,9 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this<GcsClient> {
std::unique_ptr<InternalKVAccessor> internal_kv_accessor_;
std::unique_ptr<TaskInfoAccessor> task_accessor_;

friend class ray::GcsClientTest;
FRIEND_TEST(ray::GcsClientTest, TestCheckAlive);

private:
const UniqueID gcs_client_id_ = UniqueID::FromRandom();

Expand Down
33 changes: 31 additions & 2 deletions src/ray/gcs/gcs_client/test/gcs_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
rpc::ResetServerCallExecutor();
}

void StampContext(grpc::ClientContext &context) {
RAY_CHECK(gcs_client_->client_call_manager_)
<< "Cannot stamp context before initializing client call manager.";
context.AddMetadata(kClusterIdKey,
gcs_client_->client_call_manager_->cluster_id_.load().Hex());
}

void RestartGcsServer() {
RAY_LOG(INFO) << "Stopping GCS service, port = " << gcs_server_->GetPort();
gcs_server_->Stop();
Expand All @@ -141,11 +148,17 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
grpc::CreateChannel(absl::StrCat("127.0.0.1:", gcs_server_->GetPort()),
grpc::InsecureChannelCredentials());
auto stub = rpc::NodeInfoGcsService::NewStub(std::move(channel));
bool in_memory =
RayConfig::instance().gcs_storage() == gcs::GcsServer::kInMemoryStorage;
grpc::ClientContext context;
if (!in_memory) {
StampContext(context);
}
context.set_deadline(std::chrono::system_clock::now() + 1s);
const rpc::CheckAliveRequest request;
rpc::CheckAliveReply reply;
auto status = stub->CheckAlive(&context, request, &reply);
// If it is in memory, we don't have the new token until we connect again.
if (!status.ok()) {
RAY_LOG(WARNING) << "Unable to reach GCS: " << status.error_code() << " "
<< status.error_message();
Expand Down Expand Up @@ -315,8 +328,10 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {

bool RegisterNode(const rpc::GcsNodeInfo &node_info) {
std::promise<bool> promise;
RAY_CHECK_OK(gcs_client_->Nodes().AsyncRegister(
node_info, [&promise](Status status) { promise.set_value(status.ok()); }));
RAY_CHECK_OK(gcs_client_->Nodes().AsyncRegister(node_info, [&promise](Status status) {
RAY_LOG(INFO) << status;
promise.set_value(status.ok());
}));
return WaitReady(promise.get_future(), timeout_ms_);
}

Expand Down Expand Up @@ -463,6 +478,7 @@ TEST_P(GcsClientTest, TestCheckAlive) {
*(request.mutable_raylet_address()->Add()) = "172.1.2.4:31293";
{
grpc::ClientContext context;
StampContext(context);
context.set_deadline(std::chrono::system_clock::now() + 1s);
rpc::CheckAliveReply reply;
ASSERT_TRUE(stub->CheckAlive(&context, request, &reply).ok());
Expand All @@ -474,6 +490,7 @@ TEST_P(GcsClientTest, TestCheckAlive) {
ASSERT_TRUE(RegisterNode(*node_info1));
{
grpc::ClientContext context;
StampContext(context);
context.set_deadline(std::chrono::system_clock::now() + 1s);
rpc::CheckAliveReply reply;
ASSERT_TRUE(stub->CheckAlive(&context, request, &reply).ok());
Expand Down Expand Up @@ -987,9 +1004,21 @@ TEST_P(GcsClientTest, TestEvictExpiredDestroyedActors) {
}
}

TEST_P(GcsClientTest, TestGcsAuth) {
// Restart GCS.
RestartGcsServer();
auto node_info = Mocker::GenNodeInfo();

RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_));
EXPECT_TRUE(RegisterNode(*node_info));
}

TEST_P(GcsClientTest, TestEvictExpiredDeadNodes) {
// Restart GCS.
RestartGcsServer();
if (RayConfig::instance().gcs_storage() == gcs::GcsServer::kInMemoryStorage) {
RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_));
}

// Simulate the scenario of node dead.
int node_count = RayConfig::instance().maximum_gcs_dead_node_cached_count();
Expand Down
4 changes: 3 additions & 1 deletion src/ray/gcs/gcs_client/test/usage_stats_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ class UsageStatsClientTest : public ::testing::Test {

TEST_F(UsageStatsClientTest, TestRecordExtraUsageTag) {
gcs::UsageStatsClient usage_stats_client(
"127.0.0.1:" + std::to_string(gcs_server_->GetPort()), *client_io_service_);
"127.0.0.1:" + std::to_string(gcs_server_->GetPort()),
*client_io_service_,
ClusterID::Nil());
usage_stats_client.RecordExtraUsageTag(usage::TagKey::_TEST1, "value1");
ASSERT_TRUE(WaitForCondition(
[this]() {
Expand Down
5 changes: 3 additions & 2 deletions src/ray/gcs/gcs_client/usage_stats_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
namespace ray {
namespace gcs {
UsageStatsClient::UsageStatsClient(const std::string &gcs_address,
instrumented_io_context &io_service) {
instrumented_io_context &io_service,
const ClusterID &cluster_id) {
GcsClientOptions options(gcs_address);
gcs_client_ = std::make_unique<GcsClient>(options);
RAY_CHECK_OK(gcs_client_->Connect(io_service));
RAY_CHECK_OK(gcs_client_->Connect(io_service, cluster_id));
}

void UsageStatsClient::RecordExtraUsageTag(usage::TagKey key, const std::string &value) {
Expand Down
3 changes: 2 additions & 1 deletion src/ray/gcs/gcs_client/usage_stats_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ namespace gcs {
class UsageStatsClient {
public:
explicit UsageStatsClient(const std::string &gcs_address,
instrumented_io_context &io_service);
instrumented_io_context &io_service,
const ClusterID &cluster_id);

/// C++ version of record_extra_usage_tag in usage_lib.py
///
Expand Down
6 changes: 4 additions & 2 deletions src/ray/gcs/gcs_server/gcs_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,10 @@ void GcsServer::InitFunctionManager() {
}

void GcsServer::InitUsageStatsClient() {
usage_stats_client_ = std::make_unique<UsageStatsClient>(
"127.0.0.1:" + std::to_string(GetPort()), main_service_);
usage_stats_client_ =
std::make_unique<UsageStatsClient>("127.0.0.1:" + std::to_string(GetPort()),
main_service_,
rpc_server_.GetClusterId());
}

void GcsServer::InitKVManager() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ class MockGcsClient : public gcs::GcsClient {
return *node_accessor_;
}

MOCK_METHOD1(Connect, Status(instrumented_io_context &io_service));
MOCK_METHOD2(Connect,
Status(instrumented_io_context &io_service, const ClusterID &cluster_id));

MOCK_METHOD0(Disconnect, void());
};
Expand Down
8 changes: 7 additions & 1 deletion src/ray/rpc/client_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
#include "ray/util/util.h"

namespace ray {

class GcsClientTest;
class GcsClientTest_TestCheckAlive_Test;
namespace rpc {

/// Represents an outgoing gRPC request.
Expand Down Expand Up @@ -194,7 +197,7 @@ class ClientCallManager {
const ClusterID &cluster_id = ClusterID::Nil(),
int num_threads = 1,
int64_t call_timeout_ms = -1)
: cluster_id_(ClusterID::Nil()),
: cluster_id_(cluster_id),
main_service_(main_service),
num_threads_(num_threads),
shutdown_(false),
Expand Down Expand Up @@ -278,6 +281,9 @@ class ClientCallManager {
/// Get the main service of this rpc.
instrumented_io_context &GetMainService() { return main_service_; }

friend class ray::GcsClientTest;
FRIEND_TEST(ray::GcsClientTest, TestCheckAlive);

private:
/// This function runs in a background thread. It keeps polling events from the
/// `CompletionQueue`, and dispatches the event to the callbacks via the `ClientCall`
Expand Down

0 comments on commit 1c43c6a

Please sign in to comment.