diff --git a/src/ray/common/status.cc b/src/ray/common/status.cc
index fb66ef4acfee..68b3cd39b3e7 100644
--- a/src/ray/common/status.cc
+++ b/src/ray/common/status.cc
@@ -54,6 +54,7 @@ namespace ray {
 #define STATUS_CODE_NOT_FOUND "NotFound"
 #define STATUS_CODE_DISCONNECTED "Disconnected"
 #define STATUS_CODE_SCHEDULING_CANCELLED "SchedulingCancelled"
+#define STATUS_CODE_AUTH_ERROR "AuthError"
 // object store status
 #define STATUS_CODE_OBJECT_EXISTS "ObjectExists"
 #define STATUS_CODE_OBJECT_NOT_FOUND "ObjectNotFound"
@@ -114,6 +115,7 @@ std::string Status::CodeAsString() const {
       {StatusCode::TransientObjectStoreFull, STATUS_CODE_TRANSIENT_OBJECT_STORE_FULL},
       {StatusCode::GrpcUnavailable, STATUS_CODE_GRPC_UNAVAILABLE},
       {StatusCode::GrpcUnknown, STATUS_CODE_GRPC_UNKNOWN},
+      {StatusCode::AuthError, STATUS_CODE_AUTH_ERROR},
   };

   auto it = code_to_str.find(code());
@@ -149,6 +151,7 @@ StatusCode Status::StringToCode(const std::string &str) {
       {STATUS_CODE_OBJECT_UNKNOWN_OWNER, StatusCode::ObjectUnknownOwner},
       {STATUS_CODE_OBJECT_STORE_FULL, StatusCode::ObjectStoreFull},
       {STATUS_CODE_TRANSIENT_OBJECT_STORE_FULL, StatusCode::TransientObjectStoreFull},
+      {STATUS_CODE_AUTH_ERROR, StatusCode::AuthError},
   };

   auto it = str_to_code.find(str);
diff --git a/src/ray/common/status.h b/src/ray/common/status.h
index cfbcff3dfc89..591321c7f462 100644
--- a/src/ray/common/status.h
+++ b/src/ray/common/status.h
@@ -115,8 +115,8 @@ enum class StatusCode : char {
   ObjectUnknownOwner = 29,
   RpcError = 30,
   OutOfResource = 31,
-  // Meaning the ObjectRefStream reaches to the end of stream.
-  ObjectRefEndOfStream = 32
+  ObjectRefEndOfStream = 32,
+  AuthError = 33,
 };

 #if defined(__clang__)
@@ -252,6 +252,10 @@ class RAY_EXPORT Status {
     return Status(StatusCode::OutOfResource, msg);
   }

+  static Status AuthError(const std::string &msg) {
+    return Status(StatusCode::AuthError, msg);
+  }
+
   static StatusCode StringToCode(const std::string &str);

   // Returns true iff the status indicates success.
@@ -303,6 +307,8 @@ class RAY_EXPORT Status {
   bool IsOutOfResource() const { return code() == StatusCode::OutOfResource; }

+  bool IsAuthError() const { return code() == StatusCode::AuthError; }
+
   // Return a string representation of this status suitable for printing.
   // Returns the string "OK" for success.
   std::string ToString() const;
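A minimal usage sketch, not part of the diff: how a caller can branch on the new status code. `HandleStatus` is a hypothetical helper, and the include assumes the Ray source tree.

```cpp
#include <iostream>
#include "ray/common/status.h"

void HandleStatus(const ray::Status &status) {
  if (status.IsAuthError()) {
    // A wrong or missing cluster ID token surfaces here, e.g. "WrongClusterToken".
    std::cerr << "Authentication failed: " << status.ToString() << std::endl;
  } else if (!status.ok()) {
    std::cerr << "Other failure: " << status.ToString() << std::endl;
  }
}

int main() {
  HandleStatus(ray::Status::AuthError("WrongClusterToken"));
  // The CodeAsString()/StringToCode() round-trip uses the new "AuthError" entry.
  std::cout << ray::Status::AuthError("x").CodeAsString() << std::endl;  // AuthError
  return 0;
}
```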
diff --git a/src/ray/gcs/gcs_client/accessor.h b/src/ray/gcs/gcs_client/accessor.h
index 93697d11cf0c..dada909f7fd0 100644
--- a/src/ray/gcs/gcs_client/accessor.h
+++ b/src/ray/gcs/gcs_client/accessor.h
@@ -269,7 +269,7 @@ class JobInfoAccessor {
 class NodeInfoAccessor {
  public:
   NodeInfoAccessor() = default;
-  explicit NodeInfoAccessor(GcsClient *client_impl);
+  NodeInfoAccessor(GcsClient *client_impl);
   virtual ~NodeInfoAccessor() = default;
   /// Register local node to GCS asynchronously.
   ///
diff --git a/src/ray/gcs/gcs_client/gcs_client.cc b/src/ray/gcs/gcs_client/gcs_client.cc
index 4e75a1799467..3eeefb624573 100644
--- a/src/ray/gcs/gcs_client/gcs_client.cc
+++ b/src/ray/gcs/gcs_client/gcs_client.cc
@@ -149,7 +149,6 @@ Status GcsClient::Connect(instrumented_io_context &io_service,
   // Init GCS subscriber instance.
   gcs_subscriber_ = std::make_unique<GcsSubscriber>(gcs_address, std::move(subscriber));
-
   job_accessor_ = std::make_unique<JobInfoAccessor>(this);
   actor_accessor_ = std::make_unique<ActorInfoAccessor>(this);
   node_accessor_ = std::make_unique<NodeInfoAccessor>(this);
diff --git a/src/ray/gcs/gcs_client/test/gcs_client_test.cc b/src/ray/gcs/gcs_client/test/gcs_client_test.cc
index 6a315f91e4fc..1c80f3035a25 100644
--- a/src/ray/gcs/gcs_client/test/gcs_client_test.cc
+++ b/src/ray/gcs/gcs_client/test/gcs_client_test.cc
@@ -158,7 +158,8 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
       rpc::CheckAliveReply reply;
       auto status = stub->CheckAlive(&context, request, &reply);
       // If it is in memory, we don't have the new token until we connect again.
-      if (!status.ok()) {
+      if (!((!in_memory && !status.ok()) ||
+            (in_memory && GrpcStatusToRayStatus(status).IsAuthError()))) {
         RAY_LOG(WARNING) << "Unable to reach GCS: " << status.error_code() << " "
                          << status.error_message();
         continue;
@@ -885,6 +886,7 @@ TEST_P(GcsClientTest, TestGcsTableReload) {
   // Restart GCS.
   RestartGcsServer();
+  RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_));

   // Get information of nodes from GCS.
   std::vector<rpc::GcsNodeInfo> node_list = GetNodeInfoList();
@@ -981,6 +983,7 @@ TEST_P(GcsClientTest, TestEvictExpiredDestroyedActors) {
   // Restart GCS.
   RestartGcsServer();
+  RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_));

   for (int index = 0; index < actor_count; ++index) {
     auto actor_table_data = Mocker::GenActorTableData(job_id);
@@ -1008,6 +1011,7 @@ TEST_P(GcsClientTest, TestGcsAuth) {
   RestartGcsServer();

   auto node_info = Mocker::GenNodeInfo();
+  EXPECT_FALSE(RegisterNode(*node_info));
   RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_));
   EXPECT_TRUE(RegisterNode(*node_info));
 }
@@ -1015,6 +1019,7 @@ TEST_P(GcsClientTest, TestEvictExpiredDeadNodes) {
   // Restart GCS.
   RestartGcsServer();
+  RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_));
   if (RayConfig::instance().gcs_storage() == gcs::GcsServer::kInMemoryStorage) {
     RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_));
   }
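The test changes above all follow one pattern: after an in-memory GCS restart the server mints a fresh cluster ID token, so a client still holding the old one is rejected with `AuthError` until it calls `Connect()` again. A standalone toy model of that handshake (plain C++, no Ray dependencies; `Server` and `Client` are illustrative stand-ins, not Ray types):

```cpp
#include <cassert>
#include <string>

struct Server {
  std::string token = "token-1";
  void Restart() { token = "token-2"; }  // in-memory storage: old token is gone
  bool Accept(const std::string &t) const { return t == token; }
};

struct Client {
  std::string token;
  void Connect(const Server &s) { token = s.token; }  // GetClusterId on connect
};

int main() {
  Server gcs;
  Client client;
  client.Connect(gcs);
  assert(gcs.Accept(client.token));   // authorized
  gcs.Restart();
  assert(!gcs.Accept(client.token));  // stale token -> AuthError in the real code
  client.Connect(gcs);                // re-run the handshake
  assert(gcs.Accept(client.token));   // authorized again
  return 0;
}
```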
diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc
index 6a6c1a7f0f9b..4fc056bcf4b7 100644
--- a/src/ray/gcs/gcs_server/gcs_server.cc
+++ b/src/ray/gcs/gcs_server/gcs_server.cc
@@ -15,6 +15,7 @@
 #include "ray/gcs/gcs_server/gcs_server.h"

 #include
+#include

 #include "ray/common/asio/asio_util.h"
 #include "ray/common/asio/instrumented_io_context.h"
@@ -59,6 +60,7 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config,
       rpc_server_(config.grpc_server_name,
                   config.grpc_server_port,
                   config.node_ip_address == "127.0.0.1",
+                  ClusterID::Nil(),
                   config.grpc_server_thread_num,
                   /*keepalive_time_ms=*/RayConfig::instance().grpc_keepalive_time_ms()),
       client_call_manager_(main_service,
@@ -157,11 +159,11 @@ void GcsServer::GetOrGenerateClusterId(
   kv_manager_->GetInstance().Get(
       kTokenNamespace,
       kClusterIdKey,
-      [this, continuation = std::move(continuation)](
-          std::optional<std::string> provided_cluster_id) mutable {
-        if (!provided_cluster_id.has_value()) {
+      [this,
+       continuation = std::move(continuation)](std::optional<std::string> token) mutable {
+        if (!token.has_value()) {
           ClusterID cluster_id = ClusterID::FromRandom();
-          RAY_LOG(INFO) << "No existing server cluster ID found. Generating new ID: "
+          RAY_LOG(INFO) << "No existing server token found. Generating new token: "
                         << cluster_id.Hex();
           kv_manager_->GetInstance().Put(
               kTokenNamespace,
@@ -170,11 +172,11 @@ void GcsServer::GetOrGenerateClusterId(
               false,
               [&cluster_id,
                continuation = std::move(continuation)](bool added_entry) mutable {
-                RAY_CHECK(added_entry) << "Failed to persist new cluster ID!";
+                RAY_CHECK(added_entry) << "Failed to persist new token!";
                 continuation(cluster_id);
               });
         } else {
-          ClusterID cluster_id = ClusterID::FromBinary(provided_cluster_id.value());
+          ClusterID cluster_id = ClusterID::FromBinary(token.value());
           RAY_LOG(INFO) << "Found existing server token: " << cluster_id;
           continuation(cluster_id);
         }
diff --git a/src/ray/gcs/gcs_server/gcs_server.h b/src/ray/gcs/gcs_server/gcs_server.h
index b80f1f906f6d..5dfe284af5ea 100644
--- a/src/ray/gcs/gcs_server/gcs_server.h
+++ b/src/ray/gcs/gcs_server/gcs_server.h
@@ -14,6 +14,8 @@

 #pragma once

+#include
+
 #include "ray/common/asio/instrumented_io_context.h"
 #include "ray/common/ray_syncer/ray_syncer.h"
 #include "ray/common/runtime_env_manager.h"
diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc
index 44e6ad7c1914..17880750173c 100644
--- a/src/ray/object_manager/object_manager.cc
+++ b/src/ray/object_manager/object_manager.cc
@@ -100,6 +100,7 @@ ObjectManager::ObjectManager(
       object_manager_server_("ObjectManager",
                              config_.object_manager_port,
                              config_.object_manager_address == "127.0.0.1",
+                             ClusterID::Nil(),
                              config_.rpc_service_threads_number),
       object_manager_service_(rpc_service_, *this),
       client_call_manager_(
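`GetOrGenerateClusterId` above is a get-or-generate lookup against the GCS internal KV store: reuse a persisted token if one exists, otherwise mint, persist, and hand it to the continuation. A standalone sketch of the same control flow (the `std::map` stands in for the KV manager; all names here are illustrative, not Ray's):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <optional>
#include <string>

std::map<std::string, std::string> kv;  // stand-in for the internal KV store

void GetOrGenerateToken(std::function<void(std::string)> continuation) {
  std::optional<std::string> token;
  if (auto it = kv.find("CLUSTER_ID"); it != kv.end()) token = it->second;
  if (!token.has_value()) {
    std::string fresh = "deadbeef";   // ClusterID::FromRandom() in the real code
    kv.emplace("CLUSTER_ID", fresh);  // Put(..., overwrite=false) in the real code
    continuation(fresh);
  } else {
    continuation(*token);             // reuse the persisted token across restarts
  }
}

int main() {
  GetOrGenerateToken([](std::string id) { std::cout << "first:  " << id << "\n"; });
  GetOrGenerateToken([](std::string id) { std::cout << "again:  " << id << "\n"; });
  return 0;
}
```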
diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc
index 24047ea43279..aeb111432ce6 100644
--- a/src/ray/raylet/main.cc
+++ b/src/ray/raylet/main.cc
@@ -170,152 +170,155 @@ int main(int argc, char *argv[]) {
   RAY_CHECK_OK(gcs_client->Connect(main_service));
   std::unique_ptr<ray::raylet::Raylet> raylet;

-  RAY_CHECK_OK(gcs_client->Nodes().AsyncGetInternalConfig(
-      [&](::ray::Status status,
-          const boost::optional<std::string> &stored_raylet_config) {
-        RAY_CHECK_OK(status);
-        RAY_CHECK(stored_raylet_config.has_value());
-        RayConfig::instance().initialize(stored_raylet_config.get());
+  auto f = std::async(std::launch::async, [&]() {
+    RAY_CHECK_OK(gcs_client->Nodes().AsyncGetInternalConfig(
+        [&](::ray::Status status,
+            const boost::optional<std::string> &stored_raylet_config) {
+          RAY_CHECK_OK(status);
+          RAY_CHECK(stored_raylet_config.has_value());
+          RayConfig::instance().initialize(stored_raylet_config.get());

-        // Parse the worker port list.
-        std::istringstream worker_port_list_string(worker_port_list);
-        std::string worker_port;
-        std::vector<int> worker_ports;
+          // Parse the worker port list.
+          std::istringstream worker_port_list_string(worker_port_list);
+          std::string worker_port;
+          std::vector<int> worker_ports;

-        while (std::getline(worker_port_list_string, worker_port, ',')) {
-          worker_ports.push_back(std::stoi(worker_port));
-        }
+          while (std::getline(worker_port_list_string, worker_port, ',')) {
+            worker_ports.push_back(std::stoi(worker_port));
+          }

-        // Parse the resource list.
-        std::istringstream resource_string(static_resource_list);
-        std::string resource_name;
-        std::string resource_quantity;
+          // Parse the resource list.
+          std::istringstream resource_string(static_resource_list);
+          std::string resource_name;
+          std::string resource_quantity;

-        while (std::getline(resource_string, resource_name, ',')) {
-          RAY_CHECK(std::getline(resource_string, resource_quantity, ','));
-          static_resource_conf[resource_name] = std::stod(resource_quantity);
-        }
-        auto num_cpus_it = static_resource_conf.find("CPU");
-        int num_cpus = num_cpus_it != static_resource_conf.end()
-                           ? static_cast<int>(num_cpus_it->second)
-                           : 0;
+          while (std::getline(resource_string, resource_name, ',')) {
+            RAY_CHECK(std::getline(resource_string, resource_quantity, ','));
+            static_resource_conf[resource_name] = std::stod(resource_quantity);
+          }
+          auto num_cpus_it = static_resource_conf.find("CPU");
+          int num_cpus = num_cpus_it != static_resource_conf.end()
+                             ? static_cast<int>(num_cpus_it->second)
+                             : 0;

-        node_manager_config.raylet_config = stored_raylet_config.get();
-        node_manager_config.resource_config =
-            ray::ResourceMapToResourceRequest(std::move(static_resource_conf), false);
-        RAY_LOG(DEBUG) << "Starting raylet with static resource configuration: "
-                       << node_manager_config.resource_config.DebugString();
-        node_manager_config.node_manager_address = node_ip_address;
-        node_manager_config.node_manager_port = node_manager_port;
-        node_manager_config.num_workers_soft_limit =
-            RayConfig::instance().num_workers_soft_limit();
-        node_manager_config.num_prestart_python_workers = num_prestart_python_workers;
-        node_manager_config.maximum_startup_concurrency = maximum_startup_concurrency;
-        node_manager_config.min_worker_port = min_worker_port;
-        node_manager_config.max_worker_port = max_worker_port;
-        node_manager_config.worker_ports = worker_ports;
-        node_manager_config.labels = parse_node_labels(labels_json_str);
+          node_manager_config.raylet_config = stored_raylet_config.get();
+          node_manager_config.resource_config =
+              ray::ResourceMapToResourceRequest(std::move(static_resource_conf), false);
+          RAY_LOG(DEBUG) << "Starting raylet with static resource configuration: "
+                         << node_manager_config.resource_config.DebugString();
+          node_manager_config.node_manager_address = node_ip_address;
+          node_manager_config.node_manager_port = node_manager_port;
+          node_manager_config.num_workers_soft_limit =
+              RayConfig::instance().num_workers_soft_limit();
+          node_manager_config.num_prestart_python_workers = num_prestart_python_workers;
+          node_manager_config.maximum_startup_concurrency = maximum_startup_concurrency;
+          node_manager_config.min_worker_port = min_worker_port;
+          node_manager_config.max_worker_port = max_worker_port;
+          node_manager_config.worker_ports = worker_ports;
+          node_manager_config.labels = parse_node_labels(labels_json_str);

-        if (!python_worker_command.empty()) {
-          node_manager_config.worker_commands.emplace(
-              make_pair(ray::Language::PYTHON, ParseCommandLine(python_worker_command)));
-        }
-        if (!java_worker_command.empty()) {
-          node_manager_config.worker_commands.emplace(
-              make_pair(ray::Language::JAVA, ParseCommandLine(java_worker_command)));
-        }
-        if (!cpp_worker_command.empty()) {
-          node_manager_config.worker_commands.emplace(
-              make_pair(ray::Language::CPP, ParseCommandLine(cpp_worker_command)));
-        }
-        node_manager_config.native_library_path = native_library_path;
-        if (python_worker_command.empty() && java_worker_command.empty() &&
-            cpp_worker_command.empty()) {
-          RAY_LOG(FATAL) << "At least one of Python/Java/CPP worker command "
-                         << "should be provided";
-        }
-        if (!agent_command.empty()) {
-          node_manager_config.agent_command = agent_command;
-        } else {
-          RAY_LOG(DEBUG) << "Agent command is empty. Not starting agent.";
-        }
+          if (!python_worker_command.empty()) {
+            node_manager_config.worker_commands.emplace(make_pair(
+                ray::Language::PYTHON, ParseCommandLine(python_worker_command)));
+          }
+          if (!java_worker_command.empty()) {
+            node_manager_config.worker_commands.emplace(
+                make_pair(ray::Language::JAVA, ParseCommandLine(java_worker_command)));
+          }
+          if (!cpp_worker_command.empty()) {
+            node_manager_config.worker_commands.emplace(
+                make_pair(ray::Language::CPP, ParseCommandLine(cpp_worker_command)));
+          }
+          node_manager_config.native_library_path = native_library_path;
+          if (python_worker_command.empty() && java_worker_command.empty() &&
+              cpp_worker_command.empty()) {
+            RAY_LOG(FATAL) << "At least one of Python/Java/CPP worker command "
+                           << "should be provided";
+          }
+          if (!agent_command.empty()) {
+            node_manager_config.agent_command = agent_command;
+          } else {
+            RAY_LOG(DEBUG) << "Agent command is empty. Not starting agent.";
+          }

-        node_manager_config.report_resources_period_ms =
-            RayConfig::instance().raylet_report_resources_period_milliseconds();
-        node_manager_config.record_metrics_period_ms =
-            RayConfig::instance().metrics_report_interval_ms() / 2;
-        node_manager_config.store_socket_name = store_socket_name;
-        node_manager_config.temp_dir = temp_dir;
-        node_manager_config.log_dir = log_dir;
-        node_manager_config.session_dir = session_dir;
-        node_manager_config.resource_dir = resource_dir;
-        node_manager_config.ray_debugger_external = ray_debugger_external;
-        node_manager_config.max_io_workers = RayConfig::instance().max_io_workers();
-        node_manager_config.min_spilling_size = RayConfig::instance().min_spilling_size();
+          node_manager_config.report_resources_period_ms =
+              RayConfig::instance().raylet_report_resources_period_milliseconds();
+          node_manager_config.record_metrics_period_ms =
+              RayConfig::instance().metrics_report_interval_ms() / 2;
+          node_manager_config.store_socket_name = store_socket_name;
+          node_manager_config.temp_dir = temp_dir;
+          node_manager_config.log_dir = log_dir;
+          node_manager_config.session_dir = session_dir;
+          node_manager_config.resource_dir = resource_dir;
+          node_manager_config.ray_debugger_external = ray_debugger_external;
+          node_manager_config.max_io_workers = RayConfig::instance().max_io_workers();
+          node_manager_config.min_spilling_size =
+              RayConfig::instance().min_spilling_size();

-        // Configuration for the object manager.
-        ray::ObjectManagerConfig object_manager_config;
-        object_manager_config.object_manager_address = node_ip_address;
-        object_manager_config.object_manager_port = object_manager_port;
-        object_manager_config.store_socket_name = store_socket_name;
+          // Configuration for the object manager.
+          ray::ObjectManagerConfig object_manager_config;
+          object_manager_config.object_manager_address = node_ip_address;
+          object_manager_config.object_manager_port = object_manager_port;
+          object_manager_config.store_socket_name = store_socket_name;

-        object_manager_config.timer_freq_ms =
-            RayConfig::instance().object_manager_timer_freq_ms();
-        object_manager_config.pull_timeout_ms =
-            RayConfig::instance().object_manager_pull_timeout_ms();
-        object_manager_config.push_timeout_ms =
-            RayConfig::instance().object_manager_push_timeout_ms();
-        if (object_store_memory <= 0) {
-          RAY_LOG(FATAL) << "Object store memory should be set.";
-        }
-        object_manager_config.object_store_memory = object_store_memory;
-        object_manager_config.max_bytes_in_flight =
-            RayConfig::instance().object_manager_max_bytes_in_flight();
-        object_manager_config.plasma_directory = plasma_directory;
-        object_manager_config.fallback_directory = temp_dir;
-        object_manager_config.huge_pages = huge_pages;
+          object_manager_config.timer_freq_ms =
+              RayConfig::instance().object_manager_timer_freq_ms();
+          object_manager_config.pull_timeout_ms =
+              RayConfig::instance().object_manager_pull_timeout_ms();
+          object_manager_config.push_timeout_ms =
+              RayConfig::instance().object_manager_push_timeout_ms();
+          if (object_store_memory <= 0) {
+            RAY_LOG(FATAL) << "Object store memory should be set.";
+          }
+          object_manager_config.object_store_memory = object_store_memory;
+          object_manager_config.max_bytes_in_flight =
+              RayConfig::instance().object_manager_max_bytes_in_flight();
+          object_manager_config.plasma_directory = plasma_directory;
+          object_manager_config.fallback_directory = temp_dir;
+          object_manager_config.huge_pages = huge_pages;

-        object_manager_config.rpc_service_threads_number =
-            std::min(std::max(2, num_cpus / 4), 8);
-        object_manager_config.object_chunk_size =
-            RayConfig::instance().object_manager_default_chunk_size();
+          object_manager_config.rpc_service_threads_number =
+              std::min(std::max(2, num_cpus / 4), 8);
+          object_manager_config.object_chunk_size =
+              RayConfig::instance().object_manager_default_chunk_size();

-        RAY_LOG(DEBUG) << "Starting object manager with configuration: \n"
-                       << "rpc_service_threads_number = "
-                       << object_manager_config.rpc_service_threads_number
-                       << ", object_chunk_size = "
-                       << object_manager_config.object_chunk_size;
-        // Initialize stats.
-        const ray::stats::TagsType global_tags = {
-            {ray::stats::ComponentKey, "raylet"},
-            {ray::stats::WorkerIdKey, ""},
-            {ray::stats::VersionKey, kRayVersion},
-            {ray::stats::NodeAddressKey, node_ip_address},
-            {ray::stats::SessionNameKey, session_name}};
-        ray::stats::Init(global_tags, metrics_agent_port, WorkerID::Nil());
+          RAY_LOG(DEBUG) << "Starting object manager with configuration: \n"
+                         << "rpc_service_threads_number = "
+                         << object_manager_config.rpc_service_threads_number
+                         << ", object_chunk_size = "
+                         << object_manager_config.object_chunk_size;
+          // Initialize stats.
+          const ray::stats::TagsType global_tags = {
+              {ray::stats::ComponentKey, "raylet"},
+              {ray::stats::WorkerIdKey, ""},
+              {ray::stats::VersionKey, kRayVersion},
+              {ray::stats::NodeAddressKey, node_ip_address},
+              {ray::stats::SessionNameKey, session_name}};
+          ray::stats::Init(global_tags, metrics_agent_port, WorkerID::Nil());

-        // Initialize the node manager.
-        raylet = std::make_unique<ray::raylet::Raylet>(main_service,
-                                                       raylet_socket_name,
-                                                       node_ip_address,
-                                                       node_name,
-                                                       node_manager_config,
-                                                       object_manager_config,
-                                                       gcs_client,
-                                                       metrics_export_port,
-                                                       is_head_node);
+          // Initialize the node manager.
+          raylet = std::make_unique<ray::raylet::Raylet>(main_service,
+                                                         raylet_socket_name,
+                                                         node_ip_address,
+                                                         node_name,
+                                                         node_manager_config,
+                                                         object_manager_config,
+                                                         gcs_client,
+                                                         metrics_export_port,
+                                                         is_head_node);

-        // Initialize event framework.
-        if (RayConfig::instance().event_log_reporter_enabled() && !log_dir.empty()) {
-          ray::RayEventInit(ray::rpc::Event_SourceType::Event_SourceType_RAYLET,
-                            {{"node_id", raylet->GetNodeId().Hex()}},
-                            log_dir,
-                            RayConfig::instance().event_level(),
-                            RayConfig::instance().emit_event_to_log_file());
-        };
+          // Initialize event framework.
+          if (RayConfig::instance().event_log_reporter_enabled() && !log_dir.empty()) {
+            ray::RayEventInit(ray::rpc::Event_SourceType::Event_SourceType_RAYLET,
+                              {{"node_id", raylet->GetNodeId().Hex()}},
+                              log_dir,
+                              RayConfig::instance().event_level(),
+                              RayConfig::instance().emit_event_to_log_file());
+          };

-        raylet->Start();
-      }));
+          raylet->Start();
+        }));
+  });

   auto shutted_down = std::make_shared<std::atomic<bool>>(false);
diff --git a/src/ray/rpc/agent_manager/agent_manager_server.h b/src/ray/rpc/agent_manager/agent_manager_server.h
index 4fb5dd02464b..6a373d4f658d 100644
--- a/src/ray/rpc/agent_manager/agent_manager_server.h
+++ b/src/ray/rpc/agent_manager/agent_manager_server.h
@@ -24,7 +24,8 @@ namespace ray {
 namespace rpc {

 #define RAY_AGENT_MANAGER_RPC_HANDLERS \
-  RPC_SERVICE_HANDLER(AgentManagerService, RegisterAgent, -1)
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH(     \
+      AgentManagerService, RegisterAgent, -1, AuthType::NO_AUTH)

 /// Implementations of the `AgentManagerGrpcService`, check interface in
 /// `src/ray/protobuf/agent_manager.proto`.
diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h
index a9c52fc3717e..ef050b6a33b2 100644
--- a/src/ray/rpc/client_call.h
+++ b/src/ray/rpc/client_call.h
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include

 #include "absl/synchronization/mutex.h"
 #include "ray/common/asio/instrumented_io_context.h"
@@ -193,6 +194,7 @@ class ClientCallManager {
   ///
   /// \param[in] main_service The main event loop, to which the callback functions will be
   /// posted.
+  ///
   explicit ClientCallManager(instrumented_io_context &main_service,
                              const ClusterID &cluster_id = ClusterID::Nil(),
                              int num_threads = 1,
@@ -238,7 +240,7 @@ class ClientCallManager {
   /// -1 means it will use the default timeout configured for the handler.
   ///
   /// \return A `ClientCall` representing the request that was just sent.
-  template <class GrpcService, class Request, class Reply>
+  template <class GrpcService, class Request, class Reply, bool Insecure = false>
   std::shared_ptr<ClientCall> CreateCall(
       typename GrpcService::Stub &stub,
       const PrepareAsyncFunction<GrpcService, Request, Reply> prepare_async_function,
@@ -251,8 +253,15 @@ class ClientCallManager {
       method_timeout_ms = call_timeout_ms_;
     }

+    ClusterID cluster_id;
+    if constexpr (Insecure) {
+      cluster_id = ClusterID::Nil();
+    } else {
+      cluster_id = cluster_id_;
+    }
+
     auto call = std::make_shared<ClientCallImpl<Reply>>(
-        callback, cluster_id_, std::move(stats_handle), method_timeout_ms);
+        callback, cluster_id, std::move(stats_handle), method_timeout_ms);
     // Send request.
     // Find the next completion queue to wait for response.
     call->response_reader_ = (stub.*prepare_async_function)(
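The `if constexpr (Insecure)` branch above is resolved at compile time: calls marked insecure are stamped with a nil cluster ID and therefore carry no token on the wire. A standalone sketch of that selection (the `ClusterID` struct below is a stand-in for illustration, not Ray's type):

```cpp
#include <iostream>
#include <string>

struct ClusterID {
  std::string hex;
  static ClusterID Nil() { return {""}; }
  bool IsNil() const { return hex.empty(); }
};

template <bool Insecure>
ClusterID TokenForCall(const ClusterID &configured) {
  if constexpr (Insecure) {
    return ClusterID::Nil();  // e.g. GetClusterId itself, before any token is known
  } else {
    return configured;        // normal calls carry the cluster ID token
  }
}

int main() {
  ClusterID configured{"abc123"};
  std::cout << TokenForCall<false>(configured).hex << "\n";                       // abc123
  std::cout << std::boolalpha << TokenForCall<true>(configured).IsNil() << "\n";  // true
  return 0;
}
```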
diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h
index e021d4287c55..d40232d56da1 100644
--- a/src/ray/rpc/gcs_server/gcs_rpc_client.h
+++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h
@@ -87,11 +87,12 @@ class Executor {
 /// The priority of timeout is each call > handler > whole service
 /// (the lower priority timeout is overwritten by the higher priority timeout).
 /// \param SPECS The cpp method spec. For example, override.
+/// \param IS_INSECURE If true, do not attach the cluster ID token to the call metadata.
 ///
 /// Currently, SyncMETHOD will copy the reply additionally.
 /// TODO(sang): Fix it.
-#define VOID_GCS_RPC_CLIENT_METHOD(                                              \
-    SERVICE, METHOD, grpc_client, method_timeout_ms, SPECS)                      \
+#define _VOID_GCS_RPC_CLIENT_METHOD(                                             \
+    SERVICE, METHOD, grpc_client, method_timeout_ms, SPECS, IS_INSECURE)         \
   void METHOD(const METHOD##Request &request,                                    \
               const ClientCallback<METHOD##Reply> &callback,                     \
               const int64_t timeout_ms = method_timeout_ms) SPECS {              \
@@ -148,12 +149,13 @@ class Executor {
     };                                                                           \
     auto operation =                                                             \
         [request, operation_callback, timeout_ms](GcsRpcClient *gcs_rpc_client) { \
-          RAY_UNUSED(INVOKE_RPC_CALL(SERVICE,                                    \
-                                     METHOD,                                     \
-                                     request,                                    \
-                                     operation_callback,                         \
-                                     gcs_rpc_client->grpc_client,                \
-                                     timeout_ms));                               \
+          RAY_UNUSED(_INVOKE_RPC_CALL(SERVICE,                                   \
+                                      METHOD,                                    \
+                                      request,                                   \
+                                      operation_callback,                        \
+                                      gcs_rpc_client->grpc_client,               \
+                                      timeout_ms,                                \
+                                      IS_INSECURE));                             \
         };                                                                       \
     executor->Execute(std::move(operation));                                     \
   }                                                                              \
@@ -171,6 +173,16 @@ class Executor {
     return promise.get_future().get();                                           \
   }

+#define VOID_GCS_RPC_CLIENT_METHOD(                                              \
+    SERVICE, METHOD, grpc_client, method_timeout_ms, SPECS)                      \
+  _VOID_GCS_RPC_CLIENT_METHOD(                                                   \
+      SERVICE, METHOD, grpc_client, method_timeout_ms, SPECS, false)
+
+#define VOID_GCS_RPC_CLIENT_METHOD_NO_AUTH(                                      \
+    SERVICE, METHOD, grpc_client, method_timeout_ms, SPECS)                      \
+  _VOID_GCS_RPC_CLIENT_METHOD(                                                   \
+      SERVICE, METHOD, grpc_client, method_timeout_ms, SPECS, true)
+
 /// Client used for communicating with gcs server.
 class GcsRpcClient {
  public:
@@ -189,7 +201,7 @@ class GcsRpcClient {
  public:
   /// Constructor. GcsRpcClient is not thread safe.
   ///
-  /// \param[in] address Address of gcs server.
+  /// \param[in] address Address of gcs server.
   /// \param[in] port Port of the gcs server.
   /// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
   /// \param[in] gcs_service_failure_detected The function is used to redo subscription
diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h
index b7b8a3f55557..933556719a0a 100644
--- a/src/ray/rpc/gcs_server/gcs_rpc_server.h
+++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h
@@ -77,7 +77,7 @@ using namespace rpc::autoscaler;

 // TODO(vitsai): Set auth for everything except GCS.
 #define INTERNAL_KV_SERVICE_RPC_HANDLER(HANDLER) \
-  RPC_SERVICE_HANDLER(InternalKVGcsService, HANDLER, -1)
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH(InternalKVGcsService, HANDLER, -1, AuthType::NO_AUTH)

 #define RUNTIME_ENV_SERVICE_RPC_HANDLER(HANDLER) \
   RPC_SERVICE_HANDLER(RuntimeEnvGcsService, HANDLER, -1)
@@ -318,7 +318,11 @@ class NodeInfoGrpcService : public GrpcService {
       const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
       std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
       const ClusterID &cluster_id) override {
-    NODE_INFO_SERVICE_RPC_HANDLER(GetClusterId);
+    RPC_SERVICE_HANDLER_CUSTOM_AUTH(
+        NodeInfoGcsService,
+        GetClusterId,
+        RayConfig::instance().gcs_max_active_rpcs_per_handler(),
+        AuthType::LAZY_AUTH);
     NODE_INFO_SERVICE_RPC_HANDLER(RegisterNode);
     NODE_INFO_SERVICE_RPC_HANDLER(DrainNode);
     NODE_INFO_SERVICE_RPC_HANDLER(GetAllNodeInfo);
diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h
index 0f0743282257..b71ce93a2529 100644
--- a/src/ray/rpc/grpc_client.h
+++ b/src/ray/rpc/grpc_client.h
@@ -29,14 +29,19 @@ namespace rpc {

 // This macro wraps the logic to call a specific RPC method of a service,
 // to make it easier to implement a new RPC client.
+#define _INVOKE_RPC_CALL(                                                           \
+    SERVICE, METHOD, request, callback, rpc_client, method_timeout_ms, IS_INSECURE) \
+  (rpc_client->CallMethod<METHOD##Request, METHOD##Reply, IS_INSECURE>(             \
+      &SERVICE::Stub::PrepareAsync##METHOD,                                         \
+      request,                                                                      \
+      callback,                                                                     \
+      #SERVICE ".grpc_client." #METHOD,                                             \
+      method_timeout_ms))
+
 #define INVOKE_RPC_CALL(                                                            \
     SERVICE, METHOD, request, callback, rpc_client, method_timeout_ms)              \
-  (rpc_client->CallMethod<METHOD##Request, METHOD##Reply>(                          \
-      &SERVICE::Stub::PrepareAsync##METHOD,                                         \
-      request,                                                                      \
-      callback,                                                                     \
-      #SERVICE ".grpc_client." #METHOD,                                             \
-      method_timeout_ms))
+  _INVOKE_RPC_CALL(                                                                 \
+      SERVICE, METHOD, request, callback, rpc_client, method_timeout_ms, false)

 // Define a void RPC client method.
 #define VOID_RPC_CLIENT_METHOD(SERVICE, METHOD, rpc_client, method_timeout_ms, SPECS) \
@@ -136,14 +141,14 @@ class GrpcClient {
   /// -1 means it will use the default timeout configured for the handler.
   ///
   /// \return Status.
-  template <class Request, class Reply>
+  template <class Request, class Reply, bool Insecure = false>
   void CallMethod(
       const PrepareAsyncFunction<GrpcService, Request, Reply> prepare_async_function,
       const Request &request,
       const ClientCallback<Reply> &callback,
       std::string call_name = "UNKNOWN_RPC",
       int64_t method_timeout_ms = -1) {
-    auto call = client_call_manager_.CreateCall<GrpcService, Request, Reply>(
+    auto call = client_call_manager_.CreateCall<GrpcService, Request, Reply, Insecure>(
         *stub_,
         prepare_async_function,
         request,
diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h
index 89ce79db734e..e316f101b12f 100644
--- a/src/ray/rpc/grpc_server.h
+++ b/src/ray/rpc/grpc_server.h
@@ -28,31 +28,42 @@ namespace ray {
 namespace rpc {

 /// \param MAX_ACTIVE_RPCS Maximum number of RPCs to handle at the same time. -1 means no
 /// limit.
-#define _RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS, RECORD_METRICS)  \
-  std::unique_ptr<ServerCallFactory> HANDLER##_call_factory(                     \
-      new ServerCallFactoryImpl<SERVICE, SERVICE##Handler, HANDLER##Request,     \
-                                HANDLER##Reply>(                                 \
-          service_,                                                              \
-          &SERVICE::AsyncService::Request##HANDLER,                              \
-          service_handler_,                                                      \
-          &SERVICE##Handler::Handle##HANDLER,                                    \
-          cq,                                                                    \
-          main_service_,                                                         \
-          #SERVICE ".grpc_server." #HANDLER,                                     \
-          cluster_id,                                                            \
-          MAX_ACTIVE_RPCS,                                                       \
-          RECORD_METRICS));                                                      \
+#define _RPC_SERVICE_HANDLER(                                                    \
+    SERVICE, HANDLER, MAX_ACTIVE_RPCS, AUTH_TYPE, RECORD_METRICS)                \
+  std::unique_ptr<ServerCallFactory> HANDLER##_call_factory(                     \
+      new ServerCallFactoryImpl<SERVICE, SERVICE##Handler, HANDLER##Request,     \
+                                HANDLER##Reply, AUTH_TYPE>(                      \
+          service_,                                                              \
+          &SERVICE::AsyncService::Request##HANDLER,                              \
+          service_handler_,                                                      \
+          &SERVICE##Handler::Handle##HANDLER,                                    \
+          cq,                                                                    \
+          main_service_,                                                         \
+          #SERVICE ".grpc_server." #HANDLER,                                     \
+          AUTH_TYPE == AuthType::NO_AUTH ? ClusterID::Nil() : cluster_id,        \
+          MAX_ACTIVE_RPCS,                                                       \
+          RECORD_METRICS));                                                      \
   server_call_factories->emplace_back(std::move(HANDLER##_call_factory));

 /// Define a RPC service handler with gRPC server metrics enabled.
 #define RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS) \
-  _RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS, true)
+  _RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS, AuthType::STRICT_AUTH, true)

 /// Define a RPC service handler with gRPC server metrics disabled.
 #define RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(SERVICE, HANDLER, MAX_ACTIVE_RPCS) \
-  _RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS, false)
+  _RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS, AuthType::STRICT_AUTH, false)
+
+/// Define a RPC service handler with a custom auth policy and gRPC server metrics
+/// enabled.
+#define RPC_SERVICE_HANDLER_CUSTOM_AUTH(SERVICE, HANDLER, MAX_ACTIVE_RPCS, AUTH_TYPE) \
+  _RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS, AUTH_TYPE, true)
+
+/// Define a RPC service handler with a custom auth policy and gRPC server metrics
+/// disabled.
+#define RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( \
+    SERVICE, HANDLER, MAX_ACTIVE_RPCS, AUTH_TYPE)                \
+  _RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS, AUTH_TYPE, false)

 // Define a void RPC client method.
 #define DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(METHOD) \
@@ -82,6 +93,7 @@ class GrpcServer {
   GrpcServer(std::string name,
              const uint32_t port,
              bool listen_to_localhost_only,
+             const ClusterID &cluster_id = ClusterID::Nil(),
              int num_threads = 1,
             int64_t keepalive_time_ms = 7200000 /*2 hours, grpc default*/)
       : name_(std::move(name)),
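Note how `_RPC_SERVICE_HANDLER` now chooses the cluster ID per handler: `NO_AUTH` handlers are wired with `ClusterID::Nil()`, so the server calls they create never check tokens. A standalone sketch of that selection logic (stand-in types, not Ray's):

```cpp
#include <cassert>
#include <string>

enum class AuthType { NO_AUTH, LAZY_AUTH, STRICT_AUTH };

struct ClusterID {
  std::string hex;
  static ClusterID Nil() { return {""}; }
  bool IsNil() const { return hex.empty(); }
};

ClusterID HandlerClusterId(AuthType auth_type, const ClusterID &server_id) {
  // Mirrors: AUTH_TYPE == AuthType::NO_AUTH ? ClusterID::Nil() : cluster_id
  return auth_type == AuthType::NO_AUTH ? ClusterID::Nil() : server_id;
}

int main() {
  ClusterID server_id{"abc123"};
  assert(HandlerClusterId(AuthType::NO_AUTH, server_id).IsNil());
  assert(!HandlerClusterId(AuthType::STRICT_AUTH, server_id).IsNil());
  assert(!HandlerClusterId(AuthType::LAZY_AUTH, server_id).IsNil());
  return 0;
}
```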
diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h
index eb8da2f17c91..b25167569bcc 100644
--- a/src/ray/rpc/node_manager/node_manager_server.h
+++ b/src/ray/rpc/node_manager/node_manager_server.h
@@ -23,31 +23,35 @@ namespace ray {

 namespace rpc {

+/// TODO(vitsai): Remove this when auth is implemented for node manager.
+#define RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(METHOD) \
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH(NodeManagerService, METHOD, -1, AuthType::NO_AUTH)
+
 /// NOTE: See src/ray/core_worker/core_worker.h on how to add a new grpc handler.
-#define RAY_NODE_MANAGER_RPC_HANDLERS                                \
-  RPC_SERVICE_HANDLER(NodeManagerService, UpdateResourceUsage, -1)   \
-  RPC_SERVICE_HANDLER(NodeManagerService, RequestResourceReport, -1) \
-  RPC_SERVICE_HANDLER(NodeManagerService, GetResourceLoad, -1)       \
-  RPC_SERVICE_HANDLER(NodeManagerService, NotifyGCSRestart, -1)      \
-  RPC_SERVICE_HANDLER(NodeManagerService, RequestWorkerLease, -1)    \
-  RPC_SERVICE_HANDLER(NodeManagerService, ReportWorkerBacklog, -1)   \
-  RPC_SERVICE_HANDLER(NodeManagerService, ReturnWorker, -1)          \
-  RPC_SERVICE_HANDLER(NodeManagerService, ReleaseUnusedWorkers, -1)  \
-  RPC_SERVICE_HANDLER(NodeManagerService, CancelWorkerLease, -1)     \
-  RPC_SERVICE_HANDLER(NodeManagerService, PinObjectIDs, -1)          \
-  RPC_SERVICE_HANDLER(NodeManagerService, GetNodeStats, -1)          \
-  RPC_SERVICE_HANDLER(NodeManagerService, GlobalGC, -1)              \
-  RPC_SERVICE_HANDLER(NodeManagerService, FormatGlobalMemoryInfo, -1) \
-  RPC_SERVICE_HANDLER(NodeManagerService, PrepareBundleResources, -1) \
-  RPC_SERVICE_HANDLER(NodeManagerService, CommitBundleResources, -1) \
-  RPC_SERVICE_HANDLER(NodeManagerService, CancelResourceReserve, -1) \
-  RPC_SERVICE_HANDLER(NodeManagerService, RequestObjectSpillage, -1) \
-  RPC_SERVICE_HANDLER(NodeManagerService, ReleaseUnusedBundles, -1)  \
-  RPC_SERVICE_HANDLER(NodeManagerService, GetSystemConfig, -1)       \
-  RPC_SERVICE_HANDLER(NodeManagerService, ShutdownRaylet, -1)        \
-  RPC_SERVICE_HANDLER(NodeManagerService, GetTasksInfo, -1)          \
-  RPC_SERVICE_HANDLER(NodeManagerService, GetObjectsInfo, -1)        \
-  RPC_SERVICE_HANDLER(NodeManagerService, GetTaskFailureCause, -1)
+#define RAY_NODE_MANAGER_RPC_HANDLERS                           \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(UpdateResourceUsage)     \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(RequestResourceReport)   \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(GetResourceLoad)         \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(NotifyGCSRestart)        \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(RequestWorkerLease)      \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(ReportWorkerBacklog)     \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(ReturnWorker)            \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(ReleaseUnusedWorkers)    \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(CancelWorkerLease)       \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(PinObjectIDs)            \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(GetNodeStats)            \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(GlobalGC)                \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(FormatGlobalMemoryInfo)  \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(PrepareBundleResources)  \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(CommitBundleResources)   \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(CancelResourceReserve)   \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(RequestObjectSpillage)   \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(ReleaseUnusedBundles)    \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(GetSystemConfig)         \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(ShutdownRaylet)          \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(GetTasksInfo)            \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(GetObjectsInfo)          \
+  RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(GetTaskFailureCause)

 /// Interface of the `NodeManagerService`, see `src/ray/protobuf/node_manager.proto`.
 class NodeManagerServiceHandler {
diff --git a/src/ray/rpc/object_manager/object_manager_server.h b/src/ray/rpc/object_manager/object_manager_server.h
index ccc7543e443d..f22911767db5 100644
--- a/src/ray/rpc/object_manager/object_manager_server.h
+++ b/src/ray/rpc/object_manager/object_manager_server.h
@@ -23,10 +23,13 @@ namespace ray {

 namespace rpc {

-#define RAY_OBJECT_MANAGER_RPC_HANDLERS                      \
-  RPC_SERVICE_HANDLER(ObjectManagerService, Push, -1)        \
-  RPC_SERVICE_HANDLER(ObjectManagerService, Pull, -1)        \
-  RPC_SERVICE_HANDLER(ObjectManagerService, FreeObjects, -1)
+#define RAY_OBJECT_MANAGER_RPC_SERVICE_HANDLER(METHOD) \
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH(ObjectManagerService, METHOD, -1, AuthType::NO_AUTH)
+
+#define RAY_OBJECT_MANAGER_RPC_HANDLERS               \
+  RAY_OBJECT_MANAGER_RPC_SERVICE_HANDLER(Push)        \
+  RAY_OBJECT_MANAGER_RPC_SERVICE_HANDLER(Pull)        \
+  RAY_OBJECT_MANAGER_RPC_SERVICE_HANDLER(FreeObjects)

 /// Implementations of the `ObjectManagerGrpcService`, check interface in
 /// `src/ray/protobuf/object_manager.proto`.
diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h
index b6f42b391acd..829fd7eecbda 100644
--- a/src/ray/rpc/server_call.h
+++ b/src/ray/rpc/server_call.h
@@ -29,6 +29,13 @@ namespace ray {

 namespace rpc {

+// Authentication type of ServerCall.
+enum class AuthType {
+  NO_AUTH,      // Do not authenticate (accept all).
+  LAZY_AUTH,    // Accept missing token, but reject wrong token.
+  STRICT_AUTH,  // Reject missing token and wrong token.
+};
+
 /// Get the thread pool for the gRPC server.
 /// This pool is shared across gRPC servers.
 boost::asio::thread_pool &GetServerCallExecutor();
@@ -137,7 +144,10 @@ using HandleRequestFunction = void (ServiceHandler::*)(Request,
 /// \tparam ServiceHandler Type of the handler that handles the request.
 /// \tparam Request Type of the request message.
 /// \tparam Reply Type of the reply message.
-template <class ServiceHandler, class Request, class Reply>
+template <class ServiceHandler,
+          class Request,
+          class Reply,
+          AuthType EnableAuth>
 class ServerCallImpl : public ServerCall {
  public:
   /// Constructor.
@@ -181,21 +191,54 @@ class ServerCallImpl : public ServerCall {
   void SetState(const ServerCallState &new_state) override { state_ = new_state; }

   void HandleRequest() override {
+    bool auth_success = true;
+    if constexpr (EnableAuth == AuthType::STRICT_AUTH) {
+      RAY_CHECK(!cluster_id_.IsNil()) << "Expected cluster ID in server call!";
+      auto &metadata = context_.client_metadata();
+      if (auto it = metadata.find(kClusterIdKey);
+          it == metadata.end() || it->second != cluster_id_.Hex()) {
+        RAY_LOG(DEBUG) << "Wrong cluster ID token in request! Expected: "
+                       << cluster_id_.Hex() << ", but got: "
+                       << (it == metadata.end() ? "No token!" : it->second);
+        auth_success = false;
+      }
+    } else if constexpr (EnableAuth == AuthType::LAZY_AUTH) {
+      RAY_CHECK(!cluster_id_.IsNil()) << "Expected cluster ID in server call!";
+      auto &metadata = context_.client_metadata();
+      if (auto it = metadata.find(kClusterIdKey);
+          it != metadata.end() && it->second != cluster_id_.Hex()) {
+        RAY_LOG(DEBUG) << "Wrong cluster ID token in request! Expected: "
+                       << cluster_id_.Hex() << ", but got: "
+                       << (it == metadata.end() ? "No token!" : it->second);
+        auth_success = false;
+      }
+    } else {
+      if (!cluster_id_.IsNil()) {
+        RAY_LOG_EVERY_N(WARNING, 100)
+            << "Unexpected cluster ID in server call! " << cluster_id_;
+      }
+    }
+
     start_time_ = absl::GetCurrentTimeNanos();
     if (record_metrics_) {
       ray::stats::STATS_grpc_server_req_handling.Record(1.0, call_name_);
     }
     if (!io_service_.stopped()) {
-      io_service_.post([this] { HandleRequestImpl(); }, call_name_);
+      io_service_.post([this, auth_success] { HandleRequestImpl(auth_success); },
+                       call_name_);
     } else {
       // Handle service for rpc call has stopped, we must handle the call here
       // to send reply and remove it from cq
       RAY_LOG(DEBUG) << "Handle service has been closed.";
-      SendReply(Status::Invalid("HandleServiceClosed"));
+      if (auth_success) {
+        SendReply(Status::Invalid("HandleServiceClosed"));
+      } else {
+        SendReply(Status::AuthError("WrongClusterToken"));
+      }
     }
   }

-  void HandleRequestImpl() {
+  void HandleRequestImpl(bool auth_success) {
     state_ = ServerCallState::PROCESSING;
     // NOTE(hchen): This `factory` local variable is needed. Because `SendReply` runs in
     // a different thread, and will cause `this` to be deleted.
@@ -207,18 +250,24 @@ class ServerCallImpl : public ServerCall {
       // a new request comes in.
       factory.CreateCall();
     }
-    (service_handler_.*handle_request_function_)(
-        std::move(request_),
-        reply_,
-        [this](
-            Status status, std::function<void()> success, std::function<void()> failure) {
-          // These two callbacks must be set before `SendReply`, because `SendReply`
-          // is async and this `ServerCall` might be deleted right after `SendReply`.
-          send_reply_success_callback_ = std::move(success);
-          send_reply_failure_callback_ = std::move(failure);
-          boost::asio::post(GetServerCallExecutor(),
-                            [this, status]() { SendReply(status); });
-        });
+    if (!auth_success) {
+      boost::asio::post(GetServerCallExecutor(),
+                        [this]() { SendReply(Status::AuthError("WrongClusterToken")); });
+    } else {
+      (service_handler_.*handle_request_function_)(
+          std::move(request_),
+          reply_,
+          [this](Status status,
+                 std::function<void()> success,
+                 std::function<void()> failure) {
+            // These two callbacks must be set before `SendReply`, because `SendReply`
+            // is async and this `ServerCall` might be deleted right after `SendReply`.
+            send_reply_success_callback_ = std::move(success);
+            send_reply_failure_callback_ = std::move(failure);
+            boost::asio::post(GetServerCallExecutor(),
+                              [this, status]() { SendReply(status); });
+          });
+    }
   }

   void OnReplySent() override {
@@ -318,7 +367,7 @@ class ServerCallImpl : public ServerCall {
   /// If true, the server call will generate gRPC server metrics.
   bool record_metrics_;

-  template <class T1, class T2, class T3, class T4>
+  template <class T1, class T2, class T3, class T4, AuthType T5>
   friend class ServerCallFactoryImpl;
 };
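The three policies differ only in how they treat a missing versus a wrong token. A standalone model of the decision implemented by the `if constexpr` chain above (assert-based, no Ray dependencies):

```cpp
#include <cassert>
#include <optional>
#include <string>

enum class AuthType { NO_AUTH, LAZY_AUTH, STRICT_AUTH };

bool AuthOk(AuthType policy,
            const std::string &expected,
            const std::optional<std::string> &presented) {
  switch (policy) {
    case AuthType::NO_AUTH:
      return true;  // accept all
    case AuthType::LAZY_AUTH:
      return !presented.has_value() || *presented == expected;
    case AuthType::STRICT_AUTH:
      return presented.has_value() && *presented == expected;
  }
  return false;
}

int main() {
  const std::string id = "abc123";
  assert(AuthOk(AuthType::NO_AUTH, id, std::nullopt));
  assert(AuthOk(AuthType::LAZY_AUTH, id, std::nullopt));     // missing token accepted
  assert(!AuthOk(AuthType::LAZY_AUTH, id, "wrong"));         // wrong token rejected
  assert(!AuthOk(AuthType::STRICT_AUTH, id, std::nullopt));  // missing token rejected
  assert(AuthOk(AuthType::STRICT_AUTH, id, id));
  return 0;
}
```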
@@ -342,7 +391,11 @@ using RequestCallFunction =
 /// \tparam ServiceHandler Type of the handler that handles the request.
 /// \tparam Request Type of the request message.
 /// \tparam Reply Type of the reply message.
-template <class GrpcService, class ServiceHandler, class Request, class Reply>
+template <class GrpcService,
+          class ServiceHandler,
+          class Request,
+          class Reply,
+          AuthType EnableAuth>
 class ServerCallFactoryImpl : public ServerCallFactory {
   using AsyncService = typename GrpcService::AsyncService;

@@ -385,14 +438,14 @@ class ServerCallFactoryImpl : public ServerCallFactory {
   void CreateCall() const override {
     // Create a new `ServerCall`. This object will eventually be deleted by
     // `GrpcServer::PollEventsFromCompletionQueue`.
-    auto call =
-        new ServerCallImpl<ServiceHandler, Request, Reply>(*this,
-                                                           service_handler_,
-                                                           handle_request_function_,
-                                                           io_service_,
-                                                           call_name_,
-                                                           cluster_id_,
-                                                           record_metrics_);
+    auto call = new ServerCallImpl<ServiceHandler, Request, Reply, EnableAuth>(
+        *this,
+        service_handler_,
+        handle_request_function_,
+        io_service_,
+        call_name_,
+        cluster_id_,
+        record_metrics_);
     /// Request gRPC runtime to starting accepting this kind of request, using the call as
     /// the tag.
     (service_.*request_call_function_)(&call->context_,
diff --git a/src/ray/rpc/test/grpc_bench/grpc_bench.cc b/src/ray/rpc/test/grpc_bench/grpc_bench.cc
index c86fd9fbb912..043884a8f6be 100644
--- a/src/ray/rpc/test/grpc_bench/grpc_bench.cc
+++ b/src/ray/rpc/test/grpc_bench/grpc_bench.cc
@@ -53,7 +53,8 @@ class GreeterGrpcService : public GrpcService {
       const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
       std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
       const ClusterID &cluster_id) override{
-      RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(Greeter, SayHello, -1)}
+      RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(
+          Greeter, SayHello, -1, AuthType::NO_AUTH)}

   /// The grpc async service object.
   Greeter::AsyncService service_;
diff --git a/src/ray/rpc/test/grpc_server_client_test.cc b/src/ray/rpc/test/grpc_server_client_test.cc
index 5670725437da..f3ff712efe78 100644
--- a/src/ray/rpc/test/grpc_server_client_test.cc
+++ b/src/ray/rpc/test/grpc_server_client_test.cc
@@ -86,8 +86,10 @@ class TestGrpcService : public GrpcService {
       const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
       std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
       const ClusterID &cluster_id) override {
-    RPC_SERVICE_HANDLER(TestService, Ping, /*max_active_rpcs=*/1);
-    RPC_SERVICE_HANDLER(TestService, PingTimeout, /*max_active_rpcs=*/1);
+    RPC_SERVICE_HANDLER_CUSTOM_AUTH(
+        TestService, Ping, /*max_active_rpcs=*/1, AuthType::NO_AUTH);
+    RPC_SERVICE_HANDLER_CUSTOM_AUTH(
+        TestService, PingTimeout, /*max_active_rpcs=*/1, AuthType::NO_AUTH);
   }

  private:
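For completeness, a hedged sketch of the client side that these handlers check: a raw gRPC caller would present the token as request metadata. The `"cluster_id"` key below is an illustrative placeholder; the real key is the `kClusterIdKey` constant defined elsewhere in the Ray tree, and `ClientCallImpl` normally attaches it automatically.

```cpp
#include <grpcpp/grpcpp.h>
#include <string>

// Attach the cluster ID token so a STRICT_AUTH/LAZY_AUTH handler accepts the call.
void AttachClusterIdToken(grpc::ClientContext &context,
                          const std::string &cluster_id_hex) {
  // ServerCallImpl::HandleRequest looks the key up in context_.client_metadata()
  // and compares the value against ClusterID::Hex().
  context.AddMetadata("cluster_id", cluster_id_hex);  // placeholder key
}

int main() {
  grpc::ClientContext context;
  AttachClusterIdToken(context, "deadbeef");
  return 0;
}
```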
diff --git a/src/ray/rpc/worker/core_worker_server.h b/src/ray/rpc/worker/core_worker_server.h
index d6fc43dd2f9f..cb32397009ed 100644
--- a/src/ray/rpc/worker/core_worker_server.h
+++ b/src/ray/rpc/worker/core_worker_server.h
@@ -25,41 +25,36 @@ namespace ray {
 class CoreWorker;

 namespace rpc {

+/// TODO(vitsai): Remove this when auth is implemented for the core worker.
+#define RAY_CORE_WORKER_RPC_SERVICE_HANDLER(METHOD)         \
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(  \
+      CoreWorkerService, METHOD, -1, AuthType::NO_AUTH)
+
 /// NOTE: See src/ray/core_worker/core_worker.h on how to add a new grpc handler.
 /// Disable gRPC server metrics since it incurs too high cardinality.
-#define RAY_CORE_WORKER_RPC_HANDLERS                                                    \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, PushTask, -1)          \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(                                          \
-      CoreWorkerService, DirectActorCallArgWaitComplete, -1)                            \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(                                          \
-      CoreWorkerService, RayletNotifyGCSRestart, -1)                                    \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, GetObjectStatus, -1)   \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(                                          \
-      CoreWorkerService, WaitForActorOutOfScope, -1)                                    \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, PubsubLongPolling, -1) \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, PubsubCommandBatch, -1) \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(                                          \
-      CoreWorkerService, UpdateObjectLocationBatch, -1)                                 \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(                                          \
-      CoreWorkerService, GetObjectLocationsOwner, -1)                                   \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(                                          \
-      CoreWorkerService, ReportGeneratorItemReturns, -1)                                \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, KillActor, -1)         \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, CancelTask, -1)        \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, RemoteCancelTask, -1)  \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, GetCoreWorkerStats, -1) \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, LocalGC, -1)           \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, DeleteObjects, -1)     \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, SpillObjects, -1)      \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(                                          \
-      CoreWorkerService, RestoreSpilledObjects, -1)                                     \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(                                          \
-      CoreWorkerService, DeleteSpilledObjects, -1)                                      \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, PlasmaObjectReady, -1) \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, Exit, -1)              \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, AssignObjectOwner, -1) \
-  RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, NumPendingTasks, -1)
+#define RAY_CORE_WORKER_RPC_HANDLERS                                 \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(PushTask)                      \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(DirectActorCallArgWaitComplete) \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(RayletNotifyGCSRestart)        \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(GetObjectStatus)               \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(WaitForActorOutOfScope)        \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(PubsubLongPolling)             \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(PubsubCommandBatch)            \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(UpdateObjectLocationBatch)     \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(GetObjectLocationsOwner)       \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(ReportGeneratorItemReturns)    \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(KillActor)                     \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(CancelTask)                    \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(RemoteCancelTask)              \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(GetCoreWorkerStats)            \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(LocalGC)                       \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(DeleteObjects)                 \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(SpillObjects)                  \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(RestoreSpilledObjects)         \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(DeleteSpilledObjects)          \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(PlasmaObjectReady)             \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(Exit)                          \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(AssignObjectOwner)             \
+  RAY_CORE_WORKER_RPC_SERVICE_HANDLER(NumPendingTasks)
+
 #define RAY_CORE_WORKER_DECLARE_RPC_HANDLERS                      \
   DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PushTask)               \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(DirectActorCallArgWaitComplete) \