Skip to content

Commit

Permalink
Add client-side stuff
Browse files Browse the repository at this point in the history
Signed-off-by: vitsai <victoria@anyscale.com>
  • Loading branch information
vitsai committed Jun 22, 2023
1 parent dda3b29 commit a1abc5c
Show file tree
Hide file tree
Showing 11 changed files with 215 additions and 146 deletions.
5 changes: 4 additions & 1 deletion src/mock/ray/gcs/gcs_client/gcs_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ namespace gcs {

class MockGcsClient : public GcsClient {
public:
MOCK_METHOD(Status, Connect, (instrumented_io_context & io_service), (override));
MOCK_METHOD(Status,
Connect,
(instrumented_io_context & io_service, const ClusterID &cluster_id),
(override));
MOCK_METHOD(void, Disconnect, (), (override));
MOCK_METHOD((std::pair<std::string, int>), GetGcsServerAddress, (), (const, override));
MOCK_METHOD(std::string, DebugString, (), (const, override));
Expand Down
14 changes: 13 additions & 1 deletion src/ray/gcs/gcs_client/gcs_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,24 @@ void GcsSubscriberClient::PubsubCommandBatch(
GcsClient::GcsClient(const GcsClientOptions &options, UniqueID gcs_client_id)
: options_(options), gcs_client_id_(gcs_client_id) {}

Status GcsClient::Connect(instrumented_io_context &io_service) {
Status GcsClient::Connect(instrumented_io_context &io_service,
const ClusterID &cluster_id) {
// Connect to gcs service.
client_call_manager_ = std::make_unique<rpc::ClientCallManager>(io_service);
gcs_rpc_client_ = std::make_shared<rpc::GcsRpcClient>(
options_.gcs_address_, options_.gcs_port_, *client_call_manager_);

if (cluster_id.IsNil()) {
rpc::GetClusterIdReply reply;
RAY_CHECK(gcs_rpc_client_->SyncGetClusterId(rpc::GetClusterIdRequest(), &reply).ok())
<< "Failed to get Cluster ID!";
auto cluster_id = ClusterID::FromBinary(reply.cluster_id());
RAY_LOG(DEBUG) << "Setting cluster ID to " << cluster_id;
client_call_manager_->SetClusterId(cluster_id);
} else {
client_call_manager_->SetClusterId(cluster_id);
}

resubscribe_func_ = [this]() {
job_accessor_->AsyncResubscribe();
actor_accessor_->AsyncResubscribe();
Expand Down
11 changes: 10 additions & 1 deletion src/ray/gcs/gcs_client/gcs_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include <gtest/gtest.h>
#include <gtest/gtest_prod.h>

#include <boost/asio.hpp>
Expand All @@ -34,6 +35,9 @@

namespace ray {

class GcsClientTest;
class GcsClientTest_TestCheckAlive_Test;

namespace gcs {

/// \class GcsClientOptions
Expand Down Expand Up @@ -81,9 +85,11 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this<GcsClient> {
/// Connect to GCS Service. Non-thread safe.
/// This function must be called before calling other functions.
/// \param instrumented_io_context IO execution service.
/// \param cluster_id Optional cluster ID to provide to the client.
///
/// \return Status
virtual Status Connect(instrumented_io_context &io_service);
virtual Status Connect(instrumented_io_context &io_service,
const ClusterID &cluster_id = ClusterID::Nil());

/// Disconnect with GCS Service. Non-thread safe.
virtual void Disconnect();
Expand Down Expand Up @@ -175,6 +181,9 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this<GcsClient> {
std::unique_ptr<InternalKVAccessor> internal_kv_accessor_;
std::unique_ptr<TaskInfoAccessor> task_accessor_;

friend class ray::GcsClientTest;
FRIEND_TEST(ray::GcsClientTest, TestCheckAlive);

private:
const UniqueID gcs_client_id_ = UniqueID::FromRandom();

Expand Down
33 changes: 31 additions & 2 deletions src/ray/gcs/gcs_client/test/gcs_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
rpc::ResetServerCallExecutor();
}

void StampContext(grpc::ClientContext &context) {
RAY_CHECK(gcs_client_->client_call_manager_)
<< "Cannot stamp context before initializing client call manager.";
context.AddMetadata(kClusterIdKey,
gcs_client_->client_call_manager_->cluster_id_.load().Hex());
}

void RestartGcsServer() {
RAY_LOG(INFO) << "Stopping GCS service, port = " << gcs_server_->GetPort();
gcs_server_->Stop();
Expand All @@ -141,11 +148,17 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
grpc::CreateChannel(absl::StrCat("127.0.0.1:", gcs_server_->GetPort()),
grpc::InsecureChannelCredentials());
auto stub = rpc::NodeInfoGcsService::NewStub(std::move(channel));
bool in_memory =
RayConfig::instance().gcs_storage() == gcs::GcsServer::kInMemoryStorage;
grpc::ClientContext context;
if (!in_memory) {
StampContext(context);
}
context.set_deadline(std::chrono::system_clock::now() + 1s);
const rpc::CheckAliveRequest request;
rpc::CheckAliveReply reply;
auto status = stub->CheckAlive(&context, request, &reply);
// If it is in memory, we don't have the new token until we connect again.
if (!status.ok()) {
RAY_LOG(WARNING) << "Unable to reach GCS: " << status.error_code() << " "
<< status.error_message();
Expand Down Expand Up @@ -315,8 +328,10 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {

bool RegisterNode(const rpc::GcsNodeInfo &node_info) {
std::promise<bool> promise;
RAY_CHECK_OK(gcs_client_->Nodes().AsyncRegister(
node_info, [&promise](Status status) { promise.set_value(status.ok()); }));
RAY_CHECK_OK(gcs_client_->Nodes().AsyncRegister(node_info, [&promise](Status status) {
RAY_LOG(INFO) << status;
promise.set_value(status.ok());
}));
return WaitReady(promise.get_future(), timeout_ms_);
}

Expand Down Expand Up @@ -463,6 +478,7 @@ TEST_P(GcsClientTest, TestCheckAlive) {
*(request.mutable_raylet_address()->Add()) = "172.1.2.4:31293";
{
grpc::ClientContext context;
StampContext(context);
context.set_deadline(std::chrono::system_clock::now() + 1s);
rpc::CheckAliveReply reply;
ASSERT_TRUE(stub->CheckAlive(&context, request, &reply).ok());
Expand All @@ -474,6 +490,7 @@ TEST_P(GcsClientTest, TestCheckAlive) {
ASSERT_TRUE(RegisterNode(*node_info1));
{
grpc::ClientContext context;
StampContext(context);
context.set_deadline(std::chrono::system_clock::now() + 1s);
rpc::CheckAliveReply reply;
ASSERT_TRUE(stub->CheckAlive(&context, request, &reply).ok());
Expand Down Expand Up @@ -987,9 +1004,21 @@ TEST_P(GcsClientTest, TestEvictExpiredDestroyedActors) {
}
}

TEST_P(GcsClientTest, TestGcsAuth) {
// Restart GCS.
RestartGcsServer();
auto node_info = Mocker::GenNodeInfo();

RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_));
EXPECT_TRUE(RegisterNode(*node_info));
}

TEST_P(GcsClientTest, TestEvictExpiredDeadNodes) {
// Restart GCS.
RestartGcsServer();
if (RayConfig::instance().gcs_storage() == gcs::GcsServer::kInMemoryStorage) {
RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_));
}

// Simulate the scenario of node dead.
int node_count = RayConfig::instance().maximum_gcs_dead_node_cached_count();
Expand Down
4 changes: 3 additions & 1 deletion src/ray/gcs/gcs_client/test/usage_stats_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ class UsageStatsClientTest : public ::testing::Test {

TEST_F(UsageStatsClientTest, TestRecordExtraUsageTag) {
gcs::UsageStatsClient usage_stats_client(
"127.0.0.1:" + std::to_string(gcs_server_->GetPort()), *client_io_service_);
"127.0.0.1:" + std::to_string(gcs_server_->GetPort()),
*client_io_service_,
ClusterID::Nil());
usage_stats_client.RecordExtraUsageTag(usage::TagKey::_TEST1, "value1");
ASSERT_TRUE(WaitForCondition(
[this]() {
Expand Down
5 changes: 3 additions & 2 deletions src/ray/gcs/gcs_client/usage_stats_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
namespace ray {
namespace gcs {
UsageStatsClient::UsageStatsClient(const std::string &gcs_address,
instrumented_io_context &io_service) {
instrumented_io_context &io_service,
const ClusterID &cluster_id) {
GcsClientOptions options(gcs_address);
gcs_client_ = std::make_unique<GcsClient>(options);
RAY_CHECK_OK(gcs_client_->Connect(io_service));
RAY_CHECK_OK(gcs_client_->Connect(io_service, cluster_id));
}

void UsageStatsClient::RecordExtraUsageTag(usage::TagKey key, const std::string &value) {
Expand Down
3 changes: 2 additions & 1 deletion src/ray/gcs/gcs_client/usage_stats_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ namespace gcs {
class UsageStatsClient {
public:
explicit UsageStatsClient(const std::string &gcs_address,
instrumented_io_context &io_service);
instrumented_io_context &io_service,
const ClusterID &cluster_id);

/// C++ version of record_extra_usage_tag in usage_lib.py
///
Expand Down
6 changes: 4 additions & 2 deletions src/ray/gcs/gcs_server/gcs_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -553,8 +553,10 @@ void GcsServer::InitUsageStatsClient() {
// Note: We pass in cluster_id here to avoid deadlock during server init.
// This can occur since main_service_ is not started, and so the GetClusterId RPC from
// GCS client inside the UsageStatsClient will not be answered by GCS server.
usage_stats_client_ = std::make_unique<UsageStatsClient>(
"127.0.0.1:" + std::to_string(GetPort()), main_service_);
usage_stats_client_ =
std::make_unique<UsageStatsClient>("127.0.0.1:" + std::to_string(GetPort()),
main_service_,
rpc_server_.GetClusterId());
}

void GcsServer::InitKVManager() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ class MockGcsClient : public gcs::GcsClient {
return *node_accessor_;
}

MOCK_METHOD1(Connect, Status(instrumented_io_context &io_service));
MOCK_METHOD2(Connect,
Status(instrumented_io_context &io_service, const ClusterID &cluster_id));

MOCK_METHOD0(Disconnect, void());
};
Expand Down
Loading

0 comments on commit a1abc5c

Please sign in to comment.