[core] Introducing InstrumentedIOContextWithThread. (#47831)
Previously we had several ad-hoc implementations of a "thread and io_context"
pattern: create a thread dedicated to an asio io_context so that workloads can
post async tasks onto it. This duplicated code: everywhere we created such a
thread, we also had to implement stop and join.

Introduce InstrumentedIOContextWithThread, which does exactly this, and replace
the existing usages.

Also fixes some absl::Time computations to follow best practices.

This is a pure refactoring; there should be no runtime behavior change.

Signed-off-by: Ruiyang Wang <rywang014@gmail.com>
rynewang committed Sep 26, 2024
1 parent 3285452 commit 6b44557
Showing 12 changed files with 178 additions and 171 deletions.
46 changes: 46 additions & 0 deletions src/ray/common/asio/asio_util.h
@@ -16,8 +16,10 @@

#include <boost/asio.hpp>
#include <chrono>
#include <thread>

#include "ray/common/asio/instrumented_io_context.h"
#include "ray/util/util.h"

template <typename Duration>
std::shared_ptr<boost::asio::deadline_timer> execute_after(
@@ -37,3 +39,47 @@ std::shared_ptr<boost::asio::deadline_timer> execute_after(

return timer;
}

/**
* A class that manages an instrumented_io_context and a std::thread.
* The constructor takes a thread name and starts the thread.
* The destructor stops the io_service and joins the thread.
*/
class InstrumentedIOContextWithThread {
public:
/**
* Constructor.
* @param thread_name The name of the thread.
*/
explicit InstrumentedIOContextWithThread(const std::string &thread_name)
: io_service_(), work_(io_service_) {
io_thread_ = std::thread([this, thread_name] {
SetThreadName(thread_name);
io_service_.run();
});
}

~InstrumentedIOContextWithThread() { Stop(); }

// Non-movable and non-copyable.
InstrumentedIOContextWithThread(const InstrumentedIOContextWithThread &) = delete;
InstrumentedIOContextWithThread &operator=(const InstrumentedIOContextWithThread &) =
delete;
InstrumentedIOContextWithThread(InstrumentedIOContextWithThread &&) = delete;
InstrumentedIOContextWithThread &operator=(InstrumentedIOContextWithThread &&) = delete;

instrumented_io_context &GetIoService() { return io_service_; }

// Idempotent. Once it's stopped you can't restart it.
void Stop() {
io_service_.stop();
if (io_thread_.joinable()) {
io_thread_.join();
}
}

private:
instrumented_io_context io_service_;
boost::asio::io_service::work work_; // to keep io_service_ running
std::thread io_thread_;
};
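
For reference, a minimal usage sketch of the new helper (hypothetical caller code, not part of this diff): construct it with a thread name, post work onto its io_context, and let the destructor, or an explicit Stop(), stop the io_service and join the thread. Plain boost::asio::post is used to keep the sketch self-contained; Ray code would typically post through instrumented_io_context's named post.

// Hypothetical usage sketch (not part of this commit).
#include <future>

#include "ray/common/asio/asio_util.h"

void ExampleUsage() {
  InstrumentedIOContextWithThread io_context("example_io_context");
  std::promise<void> done;
  // The lambda runs on the dedicated "example_io_context" thread.
  boost::asio::post(io_context.GetIoService(), [&done]() { done.set_value(); });
  done.get_future().wait();
}  // The destructor calls Stop(): it stops the io_service and joins the thread.
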
34 changes: 3 additions & 31 deletions src/ray/gcs/gcs_client/gcs_client.cc
@@ -18,6 +18,7 @@
#include <thread>
#include <utility>

#include "ray/common/asio/asio_util.h"
#include "ray/common/ray_config.h"
#include "ray/gcs/gcs_client/accessor.h"
#include "ray/pubsub/subscriber.h"
@@ -717,38 +718,9 @@ std::unordered_map<std::string, std::string> PythonGetNodeLabels(
node_info.labels().end());
}

/// Creates a singleton thread that runs an io_service.
/// All ConnectToGcsStandalone calls will share this io_service.
class SingletonIoContext {
public:
static SingletonIoContext &Instance() {
static SingletonIoContext instance;
return instance;
}

instrumented_io_context &GetIoService() { return io_service_; }

private:
SingletonIoContext() : work_(io_service_) {
io_thread_ = std::thread([this] {
SetThreadName("singleton_io_context.gcs_client");
io_service_.run();
});
}
~SingletonIoContext() {
io_service_.stop();
if (io_thread_.joinable()) {
io_thread_.join();
}
}

instrumented_io_context io_service_;
boost::asio::io_service::work work_; // to keep io_service_ running
std::thread io_thread_;
};

Status ConnectOnSingletonIoContext(GcsClient &gcs_client, int64_t timeout_ms) {
instrumented_io_context &io_service = SingletonIoContext::Instance().GetIoService();
static InstrumentedIOContextWithThread io_context("gcs_client_io_service");
instrumented_io_context &io_service = io_context.GetIoService();
return gcs_client.Connect(io_service, timeout_ms);
}

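
A note on the replacement above: the hand-rolled SingletonIoContext class becomes a function-local static InstrumentedIOContextWithThread. Since C++11, a function-local static is initialized exactly once and in a thread-safe manner, so the dedicated thread is started lazily on the first call and lives for the remainder of the process. A minimal sketch of the idiom, with hypothetical names:

// Sketch of the function-local-static (Meyers singleton) idiom used above.
// GetSharedIoService is a hypothetical name; the static is constructed exactly
// once, even if multiple threads call the function concurrently.
instrumented_io_context &GetSharedIoService() {
  static InstrumentedIOContextWithThread io_context("shared_io_service");
  return io_context.GetIoService();
}
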
95 changes: 56 additions & 39 deletions src/ray/gcs/gcs_server/gcs_job_manager.cc
@@ -247,22 +247,6 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request,
// entrypoint script calls ray.init() multiple times).
std::unordered_map<std::string, std::vector<int>> job_data_key_to_indices;

// Create a shared counter for the number of jobs processed
std::shared_ptr<int> num_processed_jobs = std::make_shared<int>(0);

// Create a shared boolean flag for the internal KV callback completion
std::shared_ptr<bool> kv_callback_done = std::make_shared<bool>(false);

// Function to send the reply once all jobs have been processed and KV callback
// completed
auto try_send_reply =
[num_processed_jobs, kv_callback_done, reply, send_reply_callback]() {
if (*num_processed_jobs == reply->job_info_list_size() && *kv_callback_done) {
RAY_LOG(DEBUG) << "Finished getting all job info.";
GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
}
};

// Load the job table data into the reply.
int i = 0;
for (auto &data : result) {
@@ -286,28 +270,64 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request,
job_api_data_keys.push_back(job_data_key);
job_data_key_to_indices[job_data_key].push_back(i);
}
i++;
}

if (!request.skip_is_running_tasks_field()) {
JobID job_id = data.first;
WorkerID worker_id =
WorkerID::FromBinary(data.second.driver_address().worker_id());
  // Jobs are filtered. Now, optionally populate is_running_tasks and job_info. We
  // make async calls to:
  //
  // - N outbound RPCs, one to each job's core worker, on GcsServer::main_service_.
  // - One InternalKV MultiGet call on GcsServer::kv_service_.
  //
  // We then wait for all of them by examining an atomic num_finished_tasks counter
  // and reply once it reaches the total. The counter is written from two different
  // threads, so each thread performs an atomic read-and-increment and checks the
  // returned value to ensure try_send_reply is executed exactly once.

// Atomic counter of pending async tasks before sending the reply.
// Once it reaches total_tasks, the reply is sent.
std::shared_ptr<std::atomic<size_t>> num_finished_tasks =
std::make_shared<std::atomic<size_t>>(0);

  // N tasks for the N jobs, plus 1 task for the InternalKV MultiGet. If either part
  // is skipped, the counter is still incremented.
const size_t total_tasks = reply->job_info_list_size() + 1;
auto try_send_reply =
[reply, send_reply_callback, total_tasks](size_t finished_tasks) {
if (finished_tasks == total_tasks) {
RAY_LOG(DEBUG) << "Finished getting all job info.";
GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
}
};

// If job is not dead, get is_running_tasks from the core worker for the driver.
if (data.second.is_dead()) {
if (request.skip_is_running_tasks_field()) {
// Skipping RPCs to workers, just mark all job tasks as done.
const size_t job_count = reply->job_info_list_size();
size_t updated_finished_tasks =
num_finished_tasks->fetch_add(job_count) + job_count;
try_send_reply(updated_finished_tasks);
} else {
for (int i = 0; i < reply->job_info_list_size(); i++) {
const auto &data = reply->job_info_list(i);
auto job_id = JobID::FromBinary(data.job_id());
WorkerID worker_id = WorkerID::FromBinary(data.driver_address().worker_id());

// If job is dead, no need to get.
if (data.is_dead()) {
reply->mutable_job_info_list(i)->set_is_running_tasks(false);
core_worker_clients_.Disconnect(worker_id);
(*num_processed_jobs)++;
try_send_reply();
size_t updated_finished_tasks = num_finished_tasks->fetch_add(1) + 1;
try_send_reply(updated_finished_tasks);
} else {
// Get is_running_tasks from the core worker for the driver.
auto client = core_worker_clients_.GetOrConnect(data.second.driver_address());
auto client = core_worker_clients_.GetOrConnect(data.driver_address());
auto request = std::make_unique<rpc::NumPendingTasksRequest>();
constexpr int64_t kNumPendingTasksRequestTimeoutMs = 1000;
RAY_LOG(DEBUG) << "Send NumPendingTasksRequest to worker " << worker_id
<< ", timeout " << kNumPendingTasksRequestTimeoutMs << " ms.";
client->NumPendingTasks(
std::move(request),
[job_id, worker_id, reply, i, num_processed_jobs, try_send_reply](
[job_id, worker_id, reply, i, num_finished_tasks, try_send_reply](
const Status &status,
const rpc::NumPendingTasksReply &num_pending_tasks_reply) {
RAY_LOG(DEBUG).WithField(worker_id)
@@ -321,25 +341,25 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request,
bool is_running_tasks = num_pending_tasks_reply.num_pending_tasks() > 0;
reply->mutable_job_info_list(i)->set_is_running_tasks(is_running_tasks);
}
(*num_processed_jobs)++;
try_send_reply();
size_t updated_finished_tasks = num_finished_tasks->fetch_add(1) + 1;
try_send_reply(updated_finished_tasks);
},
kNumPendingTasksRequestTimeoutMs);
}
} else {
(*num_processed_jobs)++;
try_send_reply();
}
i++;
}

if (!request.skip_submission_job_info_field()) {
if (request.skip_submission_job_info_field()) {
// Skipping MultiKVGet, just mark the counter.
size_t updated_finished_tasks = num_finished_tasks->fetch_add(1) + 1;
try_send_reply(updated_finished_tasks);
} else {
// Load the JobInfo for jobs submitted via the Ray Job API.
auto kv_multi_get_callback =
[reply,
send_reply_callback,
job_data_key_to_indices,
kv_callback_done,
num_finished_tasks,
try_send_reply](std::unordered_map<std::string, std::string> &&result) {
for (const auto &data : result) {
const std::string &job_data_key = data.first;
Expand All @@ -362,13 +382,10 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request,
}
}
}
*kv_callback_done = true;
try_send_reply();
size_t updated_finished_tasks = num_finished_tasks->fetch_add(1) + 1;
try_send_reply(updated_finished_tasks);
};
internal_kv_.MultiGet("job", job_api_data_keys, kv_multi_get_callback);
} else {
*kv_callback_done = true;
try_send_reply();
}
};
Status status = gcs_table_storage_->JobTable().GetAll(on_done);
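
The waiting scheme described in the HandleGetAllJobInfo comment above reduces to a shared atomic counter incremented from callbacks running on two different io_contexts; whichever increment observes the final count sends the reply. A stripped-down sketch of that pattern, with hypothetical names and not part of this diff:

// Hypothetical sketch of the exactly-once completion pattern. Each of the
// total_tasks callbacks may run on a different thread; fetch_add returns the
// previous value, so exactly one callback observes the final count and replies.
#include <atomic>
#include <cstddef>
#include <functional>
#include <memory>

void WaitForAllThenReply(size_t total_tasks,
                         const std::function<void(std::function<void()>)> &post_task,
                         const std::function<void()> &send_reply) {
  auto num_finished = std::make_shared<std::atomic<size_t>>(0);
  for (size_t i = 0; i < total_tasks; i++) {
    post_task([num_finished, total_tasks, send_reply]() {
      const size_t finished = num_finished->fetch_add(1) + 1;
      if (finished == total_tasks) {
        send_reply();  // Executed exactly once, by whichever task finishes last.
      }
    });
  }
}
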
48 changes: 26 additions & 22 deletions src/ray/gcs/gcs_server/gcs_server.cc
@@ -54,6 +54,9 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config,
: config_(config),
storage_type_(GetStorageType()),
main_service_(main_service),
pubsub_io_context_("pubsub_io_context"),
task_io_context_("task_io_context"),
ray_syncer_io_context_("ray_syncer_io_context"),
rpc_server_(config.grpc_server_name,
config.grpc_server_port,
config.node_ip_address == "127.0.0.1",
@@ -65,7 +68,7 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config,
RayConfig::instance().gcs_server_rpc_client_thread_num()),
raylet_client_pool_(
std::make_shared<rpc::NodeManagerClientPool>(client_call_manager_)),
pubsub_periodical_runner_(pubsub_io_service_),
pubsub_periodical_runner_(pubsub_io_context_.GetIoService()),
periodical_runner_(main_service),
is_started_(false),
is_stopped_(false) {
@@ -264,13 +267,12 @@ void GcsServer::DoStart(const GcsInitData &gcs_init_data) {
void GcsServer::Stop() {
if (!is_stopped_) {
RAY_LOG(INFO) << "Stopping GCS server.";
ray_syncer_io_context_.stop();
ray_syncer_thread_->join();
ray_syncer_.reset();

gcs_task_manager_->Stop();
ray_syncer_io_context_.Stop();
task_io_context_.Stop();
pubsub_io_context_.Stop();

pubsub_handler_->Stop();
ray_syncer_.reset();
pubsub_handler_.reset();

// Shutdown the rpc server
@@ -531,16 +533,12 @@ GcsServer::StorageType GcsServer::GetStorageType() const {
}

void GcsServer::InitRaySyncer(const GcsInitData &gcs_init_data) {
ray_syncer_ =
std::make_unique<syncer::RaySyncer>(ray_syncer_io_context_, kGCSNodeID.Binary());
ray_syncer_ = std::make_unique<syncer::RaySyncer>(ray_syncer_io_context_.GetIoService(),
kGCSNodeID.Binary());
ray_syncer_->Register(
syncer::MessageType::RESOURCE_VIEW, nullptr, gcs_resource_manager_.get());
ray_syncer_->Register(
syncer::MessageType::COMMANDS, nullptr, gcs_resource_manager_.get());
ray_syncer_thread_ = std::make_unique<std::thread>([this]() {
boost::asio::io_service::work work(ray_syncer_io_context_);
ray_syncer_io_context_.run();
});
ray_syncer_service_ = std::make_unique<syncer::RaySyncerService>(*ray_syncer_);
rpc_server_.RegisterService(*ray_syncer_service_);
}
@@ -587,10 +585,10 @@ void GcsServer::InitKVService() {
}

void GcsServer::InitPubSubHandler() {
pubsub_handler_ =
std::make_unique<InternalPubSubHandler>(pubsub_io_service_, gcs_publisher_);
pubsub_service_ = std::make_unique<rpc::InternalPubSubGrpcService>(pubsub_io_service_,
*pubsub_handler_);
pubsub_handler_ = std::make_unique<InternalPubSubHandler>(
pubsub_io_context_.GetIoService(), gcs_publisher_);
pubsub_service_ = std::make_unique<rpc::InternalPubSubGrpcService>(
pubsub_io_context_.GetIoService(), *pubsub_handler_);
// Register service.
rpc_server_.RegisterService(*pubsub_service_);
}
@@ -684,10 +682,10 @@ void GcsServer::InitGcsAutoscalerStateManager(const GcsInitData &gcs_init_data)
}

void GcsServer::InitGcsTaskManager() {
gcs_task_manager_ = std::make_unique<GcsTaskManager>();
gcs_task_manager_ = std::make_unique<GcsTaskManager>(task_io_context_.GetIoService());
// Register service.
task_info_service_.reset(new rpc::TaskInfoGrpcService(gcs_task_manager_->GetIoContext(),
*gcs_task_manager_));
task_info_service_.reset(
new rpc::TaskInfoGrpcService(task_io_context_.GetIoService(), *gcs_task_manager_));
rpc_server_.RegisterService(*task_info_service_);
}

@@ -841,9 +839,15 @@ void GcsServer::PrintAsioStats() {
const auto event_stats_print_interval_ms =
RayConfig::instance().event_stats_print_interval_ms();
if (event_stats_print_interval_ms != -1 && RayConfig::instance().event_stats()) {
RAY_LOG(INFO) << "Event stats:\n\n" << main_service_.stats().StatsString() << "\n\n";
RAY_LOG(INFO) << "GcsTaskManager Event stats:\n\n"
<< gcs_task_manager_->GetIoContext().stats().StatsString() << "\n\n";
RAY_LOG(INFO) << "main_service_ Event stats:\n\n"
<< main_service_.stats().StatsString() << "\n\n";
RAY_LOG(INFO) << "pubsub_io_context_ Event stats:\n\n"
<< pubsub_io_context_.GetIoService().stats().StatsString() << "\n\n";
RAY_LOG(INFO) << "task_io_context_ Event stats:\n\n"
<< task_io_context_.GetIoService().stats().StatsString() << "\n\n";
RAY_LOG(INFO) << "ray_syncer_io_context_ Event stats:\n\n"
<< ray_syncer_io_context_.GetIoService().stats().StatsString()
<< "\n\n";
}
}

9 changes: 6 additions & 3 deletions src/ray/gcs/gcs_server/gcs_server.h
@@ -14,6 +14,7 @@

#pragma once

#include "ray/common/asio/asio_util.h"
#include "ray/common/asio/instrumented_io_context.h"
#include "ray/common/ray_syncer/ray_syncer.h"
#include "ray/common/runtime_env_manager.h"
@@ -212,7 +213,11 @@ class GcsServer {
/// The main io service to drive event posted from grpc threads.
instrumented_io_context &main_service_;
/// The io service used by Pubsub, for isolation from other workload.
instrumented_io_context pubsub_io_service_;
InstrumentedIOContextWithThread pubsub_io_context_;
// The io service used by task manager.
InstrumentedIOContextWithThread task_io_context_;
// The io service used by ray syncer.
InstrumentedIOContextWithThread ray_syncer_io_context_;
/// The grpc server
rpc::GrpcServer rpc_server_;
/// The `ClientCallManager` object that is shared by all `NodeManagerWorkerClient`s.
@@ -254,8 +259,6 @@ class GcsServer {
/// Ray Syncer related fields.
std::unique_ptr<syncer::RaySyncer> ray_syncer_;
std::unique_ptr<syncer::RaySyncerService> ray_syncer_service_;
std::unique_ptr<std::thread> ray_syncer_thread_;
instrumented_io_context ray_syncer_io_context_;

/// The node id of GCS.
NodeID gcs_node_id_;