Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "[core] Retrieve the token from GCS server [4/n] (#37003)" #37399

Merged
merged 1 commit into from
Jul 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Revert "[core] Retrieve the token from GCS server [4/n] (#37003) (#37294
)"

This reverts commit 456f532.
vitsai authored Jul 13, 2023
commit c16c6792beaab22a45daa60aef30893913ff7f5c
5 changes: 1 addition & 4 deletions src/mock/ray/gcs/gcs_client/gcs_client.h
Original file line number Diff line number Diff line change
@@ -31,10 +31,7 @@ namespace gcs {

class MockGcsClient : public GcsClient {
public:
MOCK_METHOD(Status,
Connect,
(instrumented_io_context & io_service, const ClusterID &cluster_id),
(override));
MOCK_METHOD(Status, Connect, (instrumented_io_context & io_service), (override));
MOCK_METHOD(void, Disconnect, (), (override));
MOCK_METHOD((std::pair<std::string, int>), GetGcsServerAddress, (), (const, override));
MOCK_METHOD(std::string, DebugString, (), (const, override));
42 changes: 1 addition & 41 deletions src/ray/common/asio/instrumented_io_context.h
Original file line number Diff line number Diff line change
@@ -28,45 +28,7 @@ class instrumented_io_context : public boost::asio::io_context {
public:
/// Initializes the global stats struct after calling the base constructor.
/// TODO(ekl) allow taking an externally defined event tracker.
instrumented_io_context()
: is_running_{false}, event_stats_(std::make_shared<EventTracker>()) {}

/// Run the io_context if and only if no other thread is running it. Blocks
/// other threads from running once started. Noop if there is a thread
/// running it already. Only used in GcsClient::Connect, to be deprecated
/// after the introduction of executors.
///
/// \param prerun_fn Callback invoked (with mu_ held) immediately before this
/// call starts running the io_context. It is not invoked when the context is
/// already marked as running.
void run_if_not_running(std::function<void()> prerun_fn) {
// mu_ stays locked until this function returns, i.e. for the whole duration
// of the io_context::run() call below. run() and stop() lock the same mutex,
// so callers of those methods wait until this function has returned.
absl::MutexLock l(&mu_);
// Note: is_running_ is set to true below, but this function never resets it
// after io_context::run() returns; in the code shown here only stop() and
// stop_without_lock() set it back to false.
if (!is_running_) {
prerun_fn();
is_running_ = true;
boost::asio::io_context::run();
}
}

/// Marks the context as not running and stops the underlying io_context
/// without acquiring mu_.
///
/// Assumes the mutex is held. Undefined behavior if not.
/// GcsClient::Connect calls this from a handler that executes while
/// run_if_not_running() is running the context and therefore holding mu_;
/// stop() would try to lock mu_ again in that situation.
void stop_without_lock() {
is_running_ = false;
boost::asio::io_context::stop();
}

/// Marks the context as running and then runs the underlying
/// boost::asio::io_context.
void run() {
{
// The inner scope releases mu_ before the blocking run() call below, so the
// lock is only held while the flag is updated.
absl::MutexLock l(&mu_);
is_running_ = true;
}
boost::asio::io_context::run();
}

/// Marks the context as not running and then stops the underlying
/// boost::asio::io_context.
void stop() {
{
// Only the flag update happens under mu_; the lock is released before the
// base-class stop() is called.
absl::MutexLock l(&mu_);
is_running_ = false;
}
boost::asio::io_context::stop();
}
instrumented_io_context() : event_stats_(std::make_shared<EventTracker>()) {}

/// A proxy post function that collects count, queueing, and execution statistics for
/// the given handler.
@@ -94,8 +56,6 @@ class instrumented_io_context : public boost::asio::io_context {
EventTracker &stats() const { return *event_stats_; };

private:
absl::Mutex mu_;
bool is_running_;
/// The event stats tracker to use to record asio handler stats to.
std::shared_ptr<EventTracker> event_stats_;
};
21 changes: 21 additions & 0 deletions src/ray/common/id.h
Original file line number Diff line number Diff line change
@@ -414,6 +414,27 @@ std::ostream &operator<<(std::ostream &os, const PlacementGroupID &id);
// Restore the compiler alignment to default (8 bytes).
#pragma pack(pop)

/// Holds a ClusterID whose reads and writes go through an internal mutex.
struct SafeClusterID {
 public:
  /// Stores \p id as the initial value.
  SafeClusterID(const ClusterID &id) : id_(id) {}

  /// Returns a copy of the currently stored ID, taken while holding the lock.
  const ClusterID load() const {
    absl::MutexLock guard(&m_);
    return id_;
  }

  /// Replaces the stored ID with \p newId and returns the value that was
  /// stored before the replacement. Both steps happen under the same lock.
  ClusterID exchange(const ClusterID &newId) {
    absl::MutexLock guard(&m_);
    ClusterID previous = id_;
    id_ = newId;
    return previous;
  }

 private:
  /// Guards id_. Declared mutable so that the const load() can lock it.
  mutable absl::Mutex m_;
  /// The wrapped cluster ID.
  ClusterID id_ GUARDED_BY(m_);
};

template <typename T>
BaseID<T>::BaseID() {
// Using const_cast to directly change data is dangerous. The cached
34 changes: 1 addition & 33 deletions src/ray/gcs/gcs_client/gcs_client.cc
Original file line number Diff line number Diff line change
@@ -81,44 +81,12 @@ void GcsSubscriberClient::PubsubCommandBatch(
/// Constructs a client by storing the given options and client ID. The
/// constructor body is empty; the RPC client is created later, in Connect().
GcsClient::GcsClient(const GcsClientOptions &options, UniqueID gcs_client_id)
: options_(options), gcs_client_id_(gcs_client_id) {}

Status GcsClient::Connect(instrumented_io_context &io_service,
const ClusterID &cluster_id) {
Status GcsClient::Connect(instrumented_io_context &io_service) {
// Connect to gcs service.
client_call_manager_ = std::make_unique<rpc::ClientCallManager>(io_service);
gcs_rpc_client_ = std::make_shared<rpc::GcsRpcClient>(
options_.gcs_address_, options_.gcs_port_, *client_call_manager_);

if (cluster_id.IsNil()) {
rpc::GetClusterIdReply reply;
std::promise<void> cluster_known;
std::atomic<bool> stop_io_service{false};
gcs_rpc_client_->GetClusterId(
rpc::GetClusterIdRequest(),
[this, &cluster_known, &stop_io_service, &io_service](
const Status &status, const rpc::GetClusterIdReply &reply) {
RAY_CHECK(status.ok()) << "Failed to get Cluster ID! Status: " << status;
auto cluster_id = ClusterID::FromBinary(reply.cluster_id());
RAY_LOG(DEBUG) << "Setting cluster ID to " << cluster_id;
client_call_manager_->SetClusterId(cluster_id);
if (stop_io_service.load()) {
io_service.stop_without_lock();
} else {
cluster_known.set_value();
}
});
// Run the IO service here to make the above call synchronous.
// If it is already running, then wait for our particular callback
// to be processed.
io_service.run_if_not_running([&stop_io_service]() { stop_io_service.store(true); });
if (stop_io_service.load()) {
io_service.restart();
} else {
cluster_known.get_future().get();
}
} else {
client_call_manager_->SetClusterId(cluster_id);
}

resubscribe_func_ = [this]() {
job_accessor_->AsyncResubscribe();
actor_accessor_->AsyncResubscribe();
11 changes: 1 addition & 10 deletions src/ray/gcs/gcs_client/gcs_client.h
Original file line number Diff line number Diff line change
@@ -14,7 +14,6 @@

#pragma once

#include <gtest/gtest.h>
#include <gtest/gtest_prod.h>

#include <boost/asio.hpp>
@@ -36,9 +35,6 @@

namespace ray {

class GcsClientTest;
class GcsClientTest_TestCheckAlive_Test;

namespace gcs {

/// \class GcsClientOptions
@@ -86,11 +82,9 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this<GcsClient> {
/// Connect to GCS Service. Non-thread safe.
/// This function must be called before calling other functions.
/// \param instrumented_io_context IO execution service.
/// \param cluster_id Optional cluster ID to provide to the client.
///
/// \return Status
virtual Status Connect(instrumented_io_context &io_service,
const ClusterID &cluster_id = ClusterID::Nil());
virtual Status Connect(instrumented_io_context &io_service);

/// Disconnect with GCS Service. Non-thread safe.
virtual void Disconnect();
@@ -182,9 +176,6 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this<GcsClient> {
std::unique_ptr<InternalKVAccessor> internal_kv_accessor_;
std::unique_ptr<TaskInfoAccessor> task_accessor_;

friend class ray::GcsClientTest;
FRIEND_TEST(ray::GcsClientTest, TestCheckAlive);

private:
const UniqueID gcs_client_id_ = UniqueID::FromRandom();

33 changes: 2 additions & 31 deletions src/ray/gcs/gcs_client/test/gcs_client_test.cc
Original file line number Diff line number Diff line change
@@ -115,13 +115,6 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
rpc::ResetServerCallExecutor();
}

/// Adds the cluster ID held by the client's call manager, hex-encoded, to
/// \p context as metadata under kClusterIdKey. Aborts via RAY_CHECK if the
/// client call manager has not been created yet.
void StampContext(grpc::ClientContext &context) {
  RAY_CHECK(gcs_client_->client_call_manager_)
      << "Cannot stamp context before initializing client call manager.";
  auto &call_manager = *gcs_client_->client_call_manager_;
  const auto cluster_id_hex = call_manager.cluster_id_.Hex();
  context.AddMetadata(kClusterIdKey, cluster_id_hex);
}

void RestartGcsServer() {
RAY_LOG(INFO) << "Stopping GCS service, port = " << gcs_server_->GetPort();
gcs_server_->Stop();
@@ -148,17 +141,11 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
grpc::CreateChannel(absl::StrCat("127.0.0.1:", gcs_server_->GetPort()),
grpc::InsecureChannelCredentials());
auto stub = rpc::NodeInfoGcsService::NewStub(std::move(channel));
bool in_memory =
RayConfig::instance().gcs_storage() == gcs::GcsServer::kInMemoryStorage;
grpc::ClientContext context;
if (!in_memory) {
StampContext(context);
}
context.set_deadline(std::chrono::system_clock::now() + 1s);
const rpc::CheckAliveRequest request;
rpc::CheckAliveReply reply;
auto status = stub->CheckAlive(&context, request, &reply);
// If it is in memory, we don't have the new token until we connect again.
if (!status.ok()) {
RAY_LOG(WARNING) << "Unable to reach GCS: " << status.error_code() << " "
<< status.error_message();
@@ -328,10 +315,8 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {

bool RegisterNode(const rpc::GcsNodeInfo &node_info) {
std::promise<bool> promise;
RAY_CHECK_OK(gcs_client_->Nodes().AsyncRegister(node_info, [&promise](Status status) {
RAY_LOG(INFO) << status;
promise.set_value(status.ok());
}));
RAY_CHECK_OK(gcs_client_->Nodes().AsyncRegister(
node_info, [&promise](Status status) { promise.set_value(status.ok()); }));
return WaitReady(promise.get_future(), timeout_ms_);
}

@@ -478,7 +463,6 @@ TEST_P(GcsClientTest, TestCheckAlive) {
*(request.mutable_raylet_address()->Add()) = "172.1.2.4:31293";
{
grpc::ClientContext context;
StampContext(context);
context.set_deadline(std::chrono::system_clock::now() + 1s);
rpc::CheckAliveReply reply;
ASSERT_TRUE(stub->CheckAlive(&context, request, &reply).ok());
@@ -490,7 +474,6 @@ TEST_P(GcsClientTest, TestCheckAlive) {
ASSERT_TRUE(RegisterNode(*node_info1));
{
grpc::ClientContext context;
StampContext(context);
context.set_deadline(std::chrono::system_clock::now() + 1s);
rpc::CheckAliveReply reply;
ASSERT_TRUE(stub->CheckAlive(&context, request, &reply).ok());
@@ -1004,21 +987,9 @@ TEST_P(GcsClientTest, TestEvictExpiredDestroyedActors) {
}
}

// Checks that, after the GCS server has been restarted, calling Connect() on
// the existing client again succeeds and a node can then be registered.
TEST_P(GcsClientTest, TestGcsAuth) {
// Restart GCS.
RestartGcsServer();
auto node_info = Mocker::GenNodeInfo();

// Connect is called without a cluster ID argument here.
RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_));
EXPECT_TRUE(RegisterNode(*node_info));
}

TEST_P(GcsClientTest, TestEvictExpiredDeadNodes) {
// Restart GCS.
RestartGcsServer();
if (RayConfig::instance().gcs_storage() == gcs::GcsServer::kInMemoryStorage) {
RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_));
}

// Simulate the scenario of node dead.
int node_count = RayConfig::instance().maximum_gcs_dead_node_cached_count();
4 changes: 1 addition & 3 deletions src/ray/gcs/gcs_client/test/usage_stats_client_test.cc
Original file line number Diff line number Diff line change
@@ -82,9 +82,7 @@ class UsageStatsClientTest : public ::testing::Test {

TEST_F(UsageStatsClientTest, TestRecordExtraUsageTag) {
gcs::UsageStatsClient usage_stats_client(
"127.0.0.1:" + std::to_string(gcs_server_->GetPort()),
*client_io_service_,
ClusterID::Nil());
"127.0.0.1:" + std::to_string(gcs_server_->GetPort()), *client_io_service_);
usage_stats_client.RecordExtraUsageTag(usage::TagKey::_TEST1, "value1");
ASSERT_TRUE(WaitForCondition(
[this]() {
5 changes: 2 additions & 3 deletions src/ray/gcs/gcs_client/usage_stats_client.cc
Original file line number Diff line number Diff line change
@@ -17,11 +17,10 @@
namespace ray {
namespace gcs {
UsageStatsClient::UsageStatsClient(const std::string &gcs_address,
instrumented_io_context &io_service,
const ClusterID &cluster_id) {
instrumented_io_context &io_service) {
GcsClientOptions options(gcs_address);
gcs_client_ = std::make_unique<GcsClient>(options);
RAY_CHECK_OK(gcs_client_->Connect(io_service, cluster_id));
RAY_CHECK_OK(gcs_client_->Connect(io_service));
}

void UsageStatsClient::RecordExtraUsageTag(usage::TagKey key, const std::string &value) {
3 changes: 1 addition & 2 deletions src/ray/gcs/gcs_client/usage_stats_client.h
Original file line number Diff line number Diff line change
@@ -24,8 +24,7 @@ namespace gcs {
class UsageStatsClient {
public:
explicit UsageStatsClient(const std::string &gcs_address,
instrumented_io_context &io_service,
const ClusterID &cluster_id);
instrumented_io_context &io_service);

/// C++ version of record_extra_usage_tag in usage_lib.py
///
6 changes: 2 additions & 4 deletions src/ray/gcs/gcs_server/gcs_server.cc
Original file line number Diff line number Diff line change
@@ -566,10 +566,8 @@ void GcsServer::InitFunctionManager() {
}

void GcsServer::InitUsageStatsClient() {
usage_stats_client_ =
std::make_unique<UsageStatsClient>("127.0.0.1:" + std::to_string(GetPort()),
main_service_,
rpc_server_.GetClusterId());
usage_stats_client_ = std::make_unique<UsageStatsClient>(
"127.0.0.1:" + std::to_string(GetPort()), main_service_);
}

void GcsServer::InitKVManager() {
Original file line number Diff line number Diff line change
@@ -101,8 +101,7 @@ class MockGcsClient : public gcs::GcsClient {
return *node_accessor_;
}

MOCK_METHOD2(Connect,
Status(instrumented_io_context &io_service, const ClusterID &cluster_id));
MOCK_METHOD1(Connect, Status(instrumented_io_context &io_service));

MOCK_METHOD0(Disconnect, void());
};
22 changes: 8 additions & 14 deletions src/ray/rpc/client_call.h
Original file line number Diff line number Diff line change
@@ -28,9 +28,6 @@
#include "ray/util/util.h"

namespace ray {

class GcsClientTest;
class GcsClientTest_TestCheckAlive_Test;
namespace rpc {

/// Represents an outgoing gRPC request.
@@ -148,10 +145,10 @@ class ClientCallImpl : public ClientCall {
/// The lifecycle of a `ClientCallTag` is as follows.
///
/// When a client submits a new gRPC request, a new `ClientCallTag` object will be created
/// by `ClientCallManager::CreateCall`. Then the object will be used as the tag of
/// by `ClientCallMangager::CreateCall`. Then the object will be used as the tag of
/// `CompletionQueue`.
///
/// When the reply is received, `ClientCallManager` will get the address of this object
/// When the reply is received, `ClientCallMangager` will get the address of this object
/// via `CompletionQueue`'s tag. And the manager should call
/// `GetCall()->OnReplyReceived()` and then delete this object.
class ClientCallTag {
@@ -197,7 +194,7 @@ class ClientCallManager {
const ClusterID &cluster_id = ClusterID::Nil(),
int num_threads = 1,
int64_t call_timeout_ms = -1)
: cluster_id_(cluster_id),
: cluster_id_(ClusterID::Nil()),
main_service_(main_service),
num_threads_(num_threads),
shutdown_(false),
@@ -252,7 +249,7 @@ class ClientCallManager {
}

auto call = std::make_shared<ClientCallImpl<Reply>>(
callback, cluster_id_, std::move(stats_handle), method_timeout_ms);
callback, cluster_id_.load(), std::move(stats_handle), method_timeout_ms);
// Send request.
// Find the next completion queue to wait for response.
call->response_reader_ = (stub.*prepare_async_function)(
@@ -271,19 +268,16 @@ class ClientCallManager {
}

void SetClusterId(const ClusterID &cluster_id) {
if (!cluster_id_.IsNil() && (cluster_id_ != cluster_id)) {
auto old_id = cluster_id_.exchange(ClusterID::Nil());
if (!old_id.IsNil() && (old_id != cluster_id)) {
RAY_LOG(FATAL) << "Expected cluster ID to be Nil or " << cluster_id << ", but got"
<< cluster_id_;
<< old_id;
}
cluster_id_ = cluster_id;
}

/// Get the main service of this rpc.
instrumented_io_context &GetMainService() { return main_service_; }

friend class ray::GcsClientTest;
FRIEND_TEST(ray::GcsClientTest, TestCheckAlive);

private:
/// This function runs in a background thread. It keeps polling events from the
/// `CompletionQueue`, and dispatches the event to the callbacks via the `ClientCall`
@@ -334,7 +328,7 @@ class ClientCallManager {

/// UUID of the cluster. Potential race between creating a ClientCall object
/// and setting the cluster ID.
ClusterID cluster_id_;
SafeClusterID cluster_id_;

/// The main event loop, to which the callback functions will be posted.
instrumented_io_context &main_service_;