Skip to content

Commit

Permalink
[core] Add ClusterID token to GRPC server [1/n] (#36517)
Browse files Browse the repository at this point in the history
First of a stack of changes to plumb token exchange through between the GCS client and server. This adds a ClusterID token that can be passed to a GRPC server, which then initializes each component GRPC service with the token by passing it to the ServerCallFactory objects when they are set up. When the factories create ServerCall objects for the GRPC service completion queue, the token is also passed to each ServerCall so that it can be checked against inbound request metadata. The actual authentication check does not take place in this PR.

Note: This change also includes a minor cleanup of the GCS server code (a string comparison is replaced with an enum).

Next change (client-side analogue): #36526
  • Loading branch information
vitsai authored Jun 20, 2023
1 parent 14e4272 commit eb6c61c
Show file tree
Hide file tree
Showing 16 changed files with 172 additions and 59 deletions.
15 changes: 11 additions & 4 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,18 @@ config_setting(
# GRPC common lib.
cc_library(
name = "grpc_common_lib",
srcs = glob([
"src/ray/rpc/*.cc",
]),
srcs = [
"src/ray/rpc/common.cc",
"src/ray/rpc/grpc_server.cc",
"src/ray/rpc/server_call.cc",
],
hdrs = glob([
"src/ray/rpc/*.h",
"src/ray/rpc/client_call.h",
"src/ray/rpc/common.h",
"src/ray/rpc/grpc_client.h",
"src/ray/rpc/grpc_server.h",
"src/ray/rpc/metrics_agent_client.h",
"src/ray/rpc/server_call.h",
"src/ray/raylet_client/*.h",
]),
copts = COPTS,
Expand Down
1 change: 1 addition & 0 deletions src/ray/common/id_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ DEFINE_UNIQUE_ID(ActorClassID)
DEFINE_UNIQUE_ID(WorkerID)
DEFINE_UNIQUE_ID(ConfigID)
DEFINE_UNIQUE_ID(NodeID)
DEFINE_UNIQUE_ID(ClusterID)
2 changes: 2 additions & 0 deletions src/ray/gcs/gcs_client/gcs_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "ray/util/logging.h"

namespace ray {

namespace gcs {

/// \class GcsClientOptions
Expand Down Expand Up @@ -79,6 +80,7 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this<GcsClient> {

/// Connect to GCS Service. Non-thread safe.
/// This function must be called before calling other functions.
/// \param io_service IO execution service.
///
/// \return Status
virtual Status Connect(instrumented_io_context &io_service);
Expand Down
2 changes: 1 addition & 1 deletion src/ray/gcs/gcs_client/test/gcs_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ TEST_P(GcsClientTest, DISABLED_TestGetActorPerf) {

TEST_P(GcsClientTest, TestEvictExpiredDestroyedActors) {
// Restart doesn't work with in memory storage
if (RayConfig::instance().gcs_storage() == "memory") {
if (RayConfig::instance().gcs_storage() == gcs::GcsServer::kInMemoryStorage) {
return;
}
// Register actors and the actors will be destroyed.
Expand Down
55 changes: 39 additions & 16 deletions src/ray/gcs/gcs_server/gcs_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,23 @@
namespace ray {
namespace gcs {

/// Writes a human-readable name for a GcsServer::StorageType to a stream.
///
/// \param str Output stream to write to.
/// \param val Storage type to describe.
/// \return The stream `str`, after the name has been written to it.
inline std::ostream &operator<<(std::ostream &str, GcsServer::StorageType val) {
  const char *label = nullptr;
  switch (val) {
  case GcsServer::StorageType::IN_MEMORY:
    label = "StorageType::IN_MEMORY";
    break;
  case GcsServer::StorageType::REDIS_PERSIST:
    label = "StorageType::REDIS_PERSIST";
    break;
  case GcsServer::StorageType::UNKNOWN:
    label = "StorageType::UNKNOWN";
    break;
  default:
    // All enumerators are handled above. UNREACHABLE is defined outside this
    // view; presumably it marks this path as never taken -- TODO confirm.
    UNREACHABLE;
  }
  return str << label;
}

GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config,
instrumented_io_context &main_service)
: config_(config),
storage_type_(StorageType()),
storage_type_(GetStorageType()),
main_service_(main_service),
rpc_server_(config.grpc_server_name,
config.grpc_server_port,
Expand All @@ -57,10 +70,15 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config,
is_stopped_(false) {
// Init GCS table storage.
RAY_LOG(INFO) << "GCS storage type is " << storage_type_;
if (storage_type_ == "redis") {
gcs_table_storage_ = std::make_shared<gcs::RedisGcsTableStorage>(GetOrConnectRedis());
} else if (storage_type_ == "memory") {
switch (storage_type_) {
case StorageType::IN_MEMORY:
gcs_table_storage_ = std::make_shared<InMemoryGcsTableStorage>(main_service_);
break;
case StorageType::REDIS_PERSIST:
gcs_table_storage_ = std::make_shared<gcs::RedisGcsTableStorage>(GetOrConnectRedis());
break;
default:
RAY_LOG(FATAL) << "Unexpected storage type: " << storage_type_;
}

auto on_done = [this](const ray::Status &status) {
Expand All @@ -73,7 +91,7 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config,
ray::UniqueID::Nil(), stored_config, on_done));
// Here we need to make sure the Put of internal config is happening in sync
// way. But since the storage API is async, we need to run the main_service_
// to block currenct thread.
// to block current thread.
// This will run async operations from InternalConfigTable().Put() above
// inline.
main_service_.run();
Expand Down Expand Up @@ -452,22 +470,22 @@ void GcsServer::InitGcsPlacementGroupManager(const GcsInitData &gcs_init_data) {
rpc_server_.RegisterService(*placement_group_info_service_);
}

std::string GcsServer::StorageType() const {
if (RayConfig::instance().gcs_storage() == "memory") {
GcsServer::StorageType GcsServer::GetStorageType() const {
if (RayConfig::instance().gcs_storage() == kInMemoryStorage) {
if (!config_.redis_address.empty()) {
RAY_LOG(INFO) << "Using external Redis for KV storage: " << config_.redis_address
<< ":" << config_.redis_port;
return "redis";
return StorageType::REDIS_PERSIST;
}
return "memory";
return StorageType::IN_MEMORY;
}
if (RayConfig::instance().gcs_storage() == "redis") {
if (RayConfig::instance().gcs_storage() == kRedisStorage) {
RAY_CHECK(!config_.redis_address.empty());
return "redis";
return StorageType::REDIS_PERSIST;
}
RAY_LOG(FATAL) << "Unsupported GCS storage type: "
<< RayConfig::instance().gcs_storage();
return RayConfig::instance().gcs_storage();
return StorageType::UNKNOWN;
}

void GcsServer::InitRaySyncer(const GcsInitData &gcs_init_data) {
Expand Down Expand Up @@ -507,21 +525,26 @@ void GcsServer::InitUsageStatsClient() {
}

void GcsServer::InitKVManager() {
std::unique_ptr<InternalKVInterface> instance;
// TODO (yic): Use a factory with configs
if (storage_type_ == "redis") {
std::unique_ptr<InternalKVInterface> instance;
switch (storage_type_) {
case (StorageType::REDIS_PERSIST):
instance = std::make_unique<StoreClientInternalKV>(
std::make_unique<RedisStoreClient>(GetOrConnectRedis()));
} else if (storage_type_ == "memory") {
break;
case (StorageType::IN_MEMORY):
instance =
std::make_unique<StoreClientInternalKV>(std::make_unique<ObservableStoreClient>(
std::make_unique<InMemoryStoreClient>(main_service_)));
break;
default:
RAY_LOG(FATAL) << "Unexpected storage type! " << storage_type_;
}

kv_manager_ = std::make_unique<GcsInternalKVManager>(std::move(instance));
kv_service_ = std::make_unique<rpc::InternalKVGrpcService>(main_service_, *kv_manager_);
// Register service.
rpc_server_.RegisterService(*kv_service_);
rpc_server_.RegisterService(*kv_service_, false /* token_auth */);
}

void GcsServer::InitPubSubHandler() {
Expand Down
14 changes: 12 additions & 2 deletions src/ray/gcs/gcs_server/gcs_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ class GcsServer {
/// Check if gcs server is stopped.
bool IsStopped() const { return is_stopped_; }

// TODO(vitsai): string <=> enum generator macro
/// Storage backend the GCS server uses for its tables and internal KV.
enum class StorageType {
// Returned by GetStorageType() after it logs FATAL for an unsupported
// gcs_storage config value.
UNKNOWN = 0,
// gcs_storage is "memory" and no Redis address is configured.
IN_MEMORY = 1,
// gcs_storage is "redis", or it is "memory" but a Redis address is configured.
REDIS_PERSIST = 2,
};

/// Values of the RayConfig gcs_storage setting that GetStorageType() recognizes.
static constexpr char kInMemoryStorage[] = "memory";
static constexpr char kRedisStorage[] = "redis";

protected:
/// Generate the redis client options
RedisClientOptions GetRedisClientOptions() const;
Expand Down Expand Up @@ -161,7 +171,7 @@ class GcsServer {

private:
/// Gets the type of KV storage to use from config.
std::string StorageType() const;
StorageType GetStorageType() const;

/// Print debug info periodically.
std::string GetDebugState() const;
Expand All @@ -183,7 +193,7 @@ class GcsServer {
/// Gcs server configuration.
const GcsServerConfig config_;
// Type of storage to use.
const std::string storage_type_;
const StorageType storage_type_;
/// The main io service to drive event posted from grpc threads.
instrumented_io_context &main_service_;
/// The io service used by Pubsub, for isolation from other workload.
Expand Down
3 changes: 2 additions & 1 deletion src/ray/rpc/agent_manager/agent_manager_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ class AgentManagerGrpcService : public GrpcService {

void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
const ClusterID &cluster_id) override {
RAY_AGENT_MANAGER_RPC_HANDLERS
}

Expand Down
37 changes: 25 additions & 12 deletions src/ray/rpc/gcs_server/gcs_rpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ using namespace rpc::autoscaler;
HANDLER, \
RayConfig::instance().gcs_max_active_rpcs_per_handler())

// TODO(vitsai): Set auth for everything except GCS.
#define INTERNAL_KV_SERVICE_RPC_HANDLER(HANDLER) \
RPC_SERVICE_HANDLER(InternalKVGcsService, HANDLER, -1)

Expand Down Expand Up @@ -134,7 +135,8 @@ class JobInfoGrpcService : public GrpcService {

void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
const ClusterID &cluster_id) override {
JOB_INFO_SERVICE_RPC_HANDLER(AddJob);
JOB_INFO_SERVICE_RPC_HANDLER(MarkJobFinished);
JOB_INFO_SERVICE_RPC_HANDLER(GetAllJobInfo);
Expand Down Expand Up @@ -197,7 +199,8 @@ class ActorInfoGrpcService : public GrpcService {

void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
const ClusterID &cluster_id) override {
/// Register/Create Actor RPC takes long time, we shouldn't limit them to avoid
/// distributed deadlock.
ACTOR_INFO_SERVICE_RPC_HANDLER(RegisterActor, -1);
Expand Down Expand Up @@ -255,7 +258,8 @@ class MonitorGrpcService : public GrpcService {

void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
const ClusterID &cluster_id) override {
MONITOR_SERVICE_RPC_HANDLER(GetRayVersion);
MONITOR_SERVICE_RPC_HANDLER(DrainAndKillNode);
MONITOR_SERVICE_RPC_HANDLER(GetSchedulingStatus);
Expand Down Expand Up @@ -308,7 +312,8 @@ class NodeInfoGrpcService : public GrpcService {

void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
const ClusterID &cluster_id) override {
NODE_INFO_SERVICE_RPC_HANDLER(RegisterNode);
NODE_INFO_SERVICE_RPC_HANDLER(DrainNode);
NODE_INFO_SERVICE_RPC_HANDLER(GetAllNodeInfo);
Expand Down Expand Up @@ -360,7 +365,8 @@ class NodeResourceInfoGrpcService : public GrpcService {

void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
const ClusterID &cluster_id) override {
NODE_RESOURCE_INFO_SERVICE_RPC_HANDLER(GetResources);
NODE_RESOURCE_INFO_SERVICE_RPC_HANDLER(GetAllAvailableResources);
NODE_RESOURCE_INFO_SERVICE_RPC_HANDLER(ReportResourceUsage);
Expand Down Expand Up @@ -410,7 +416,8 @@ class WorkerInfoGrpcService : public GrpcService {

void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
const ClusterID &cluster_id) override {
WORKER_INFO_SERVICE_RPC_HANDLER(ReportWorkerFailure);
WORKER_INFO_SERVICE_RPC_HANDLER(GetWorkerInfo);
WORKER_INFO_SERVICE_RPC_HANDLER(GetAllWorkerInfo);
Expand Down Expand Up @@ -456,7 +463,8 @@ class AutoscalerStateGrpcService : public GrpcService {
grpc::Service &GetGrpcService() override { return service_; }
void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
const ClusterID &cluster_id) override {
AUTOSCALER_STATE_SERVICE_RPC_HANDLER(GetClusterResourceState);
AUTOSCALER_STATE_SERVICE_RPC_HANDLER(ReportAutoscalingState);
AUTOSCALER_STATE_SERVICE_RPC_HANDLER(RequestClusterResourceConstraint);
Expand Down Expand Up @@ -514,7 +522,8 @@ class PlacementGroupInfoGrpcService : public GrpcService {

void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
const ClusterID &cluster_id) override {
PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(CreatePlacementGroup);
PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(RemovePlacementGroup);
PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(GetPlacementGroup);
Expand Down Expand Up @@ -568,7 +577,8 @@ class InternalKVGrpcService : public GrpcService {
grpc::Service &GetGrpcService() override { return service_; }
void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
const ClusterID &cluster_id) override {
INTERNAL_KV_SERVICE_RPC_HANDLER(InternalKVGet);
INTERNAL_KV_SERVICE_RPC_HANDLER(InternalKVMultiGet);
INTERNAL_KV_SERVICE_RPC_HANDLER(InternalKVPut);
Expand Down Expand Up @@ -600,7 +610,8 @@ class RuntimeEnvGrpcService : public GrpcService {
grpc::Service &GetGrpcService() override { return service_; }
void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
const ClusterID &cluster_id) override {
RUNTIME_ENV_SERVICE_RPC_HANDLER(PinRuntimeEnvURI);
}

Expand Down Expand Up @@ -638,7 +649,8 @@ class TaskInfoGrpcService : public GrpcService {

void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
const ClusterID &cluster_id) override {
TASK_INFO_SERVICE_RPC_HANDLER(AddTaskEventData);
TASK_INFO_SERVICE_RPC_HANDLER(GetTaskEvents);
}
Expand Down Expand Up @@ -677,7 +689,8 @@ class InternalPubSubGrpcService : public GrpcService {
grpc::Service &GetGrpcService() override { return service_; }
void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
const ClusterID &cluster_id) override {
INTERNAL_PUBSUB_SERVICE_RPC_HANDLER(GcsPublish);
INTERNAL_PUBSUB_SERVICE_RPC_HANDLER(GcsSubscriberPoll);
INTERNAL_PUBSUB_SERVICE_RPC_HANDLER(GcsSubscriberCommandBatch);
Expand Down
17 changes: 3 additions & 14 deletions src/ray/rpc/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,13 @@

#include "ray/common/ray_config.h"
#include "ray/rpc/common.h"
#include "ray/rpc/grpc_server.h"
#include "ray/stats/metric.h"
#include "ray/util/util.h"

namespace ray {
namespace rpc {

GrpcServer::GrpcServer(std::string name,
const uint32_t port,
bool listen_to_localhost_only,
int num_threads,
int64_t keepalive_time_ms)
: name_(std::move(name)),
port_(port),
listen_to_localhost_only_(listen_to_localhost_only),
is_closed_(true),
num_threads_(num_threads),
keepalive_time_ms_(keepalive_time_ms) {
void GrpcServer::Init() {
RAY_CHECK(num_threads_ > 0) << "Num of threads in gRPC must be greater than 0";
cqs_.resize(num_threads_);
// Enable built in health check implemented by gRPC:
Expand Down Expand Up @@ -166,11 +155,11 @@ void GrpcServer::RegisterService(grpc::Service &service) {
services_.emplace_back(service);
}

void GrpcServer::RegisterService(GrpcService &service) {
void GrpcServer::RegisterService(GrpcService &service, bool token_auth) {
services_.emplace_back(service.GetGrpcService());

for (int i = 0; i < num_threads_; i++) {
service.InitServerCallFactories(cqs_[i], &server_call_factories_);
service.InitServerCallFactories(cqs_[i], &server_call_factories_, cluster_id_.load());
}
}

Expand Down
Loading

0 comments on commit eb6c61c

Please sign in to comment.