Skip to content

Commit

Permalink
atomics
Browse files Browse the repository at this point in the history
Signed-off-by: Ruiyang Wang <rywang014@gmail.com>
  • Loading branch information
rynewang committed Sep 24, 2024
1 parent a87c39d commit 7cd7705
Showing 1 changed file with 51 additions and 42 deletions.
93 changes: 51 additions & 42 deletions src/ray/gcs/gcs_server/gcs_job_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,25 +247,6 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request,
// entrypoint script calls ray.init() multiple times).
std::unordered_map<std::string, std::vector<int>> job_data_key_to_indices;

// Create a shared counter for the number of jobs processed.
// This is written in internal_kv_'s thread and read in the main thread.
std::shared_ptr<std::atomic<size_t>> num_processed_jobs =
std::make_shared<std::atomic<size_t>>(0);

// Create a shared boolean flag for the internal KV callback completion
std::shared_ptr<std::atomic<bool>> kv_callback_done =
std::make_shared<std::atomic<bool>>(false);

// Function to send the reply once all jobs have been processed and KV callback
// completed
auto try_send_reply =
[num_processed_jobs, kv_callback_done, reply, send_reply_callback]() {
if (*num_processed_jobs == reply->job_info_list_size() && *kv_callback_done) {
RAY_LOG(DEBUG) << "Finished getting all job info.";
GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
}
};

// Load the job table data into the reply.
int i = 0;
for (auto &data : result) {
Expand All @@ -289,28 +270,59 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request,
job_api_data_keys.push_back(job_data_key);
job_data_key_to_indices[job_data_key].push_back(i);
}
i++;
}

// Jobs are filtered. Now, optionally populate is_running_tasks and job_info. A
// `asyncio.gather` is needed but we are in C++; so we use atomic counters.

if (!request.skip_is_running_tasks_field()) {
JobID job_id = data.first;
WorkerID worker_id =
WorkerID::FromBinary(data.second.driver_address().worker_id());
// Atomic counter of pending async tasks before sending the reply.
// Once it reaches total_tasks, the reply is sent.
std::shared_ptr<std::atomic<size_t>> num_finished_tasks =
std::make_shared<std::atomic<size_t>>(0);

// If job is not dead, get is_running_tasks from the core worker for the driver.
if (data.second.is_dead()) {
// N tasks for N jobs; and 1 task for the MultiKVGet. If either is skipped the counter
// still increments.
const size_t total_tasks = reply->job_info_list_size() + 1;

// Those async tasks need to atomically read-and-increment the counter, so this
// callback can't capture the atomic variable directly. Instead, it asks for an
// regular variable argument coming from the read-and-increment caller.
auto try_send_reply =
[reply, send_reply_callback, total_tasks](size_t finished_tasks) {
if (finished_tasks == total_tasks) {
RAY_LOG(DEBUG) << "Finished getting all job info.";
GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
}
};

if (request.skip_is_running_tasks_field()) {
// Skipping RPCs to workers, just mark all job tasks as done.
const size_t job_count = reply->job_info_list_size();
size_t updated_finished_tasks =
num_finished_tasks->fetch_add(job_count) + job_count;
try_send_reply(updated_finished_tasks);
} else {
for (const auto &data : reply->job_info_list()) {
auto job_id = JobID::FromBinary(data.job_id());
WorkerID worker_id = WorkerID::FromBinary(data.driver_address().worker_id());

// If job is dead, no need to get.
if (data.is_dead()) {
reply->mutable_job_info_list(i)->set_is_running_tasks(false);
core_worker_clients_.Disconnect(worker_id);
(*num_processed_jobs)++;
try_send_reply();
size_t updated_finished_tasks = num_finished_tasks->fetch_add(1) + 1;
try_send_reply(updated_finished_tasks);
} else {
// Get is_running_tasks from the core worker for the driver.
auto client = core_worker_clients_.GetOrConnect(data.second.driver_address());
auto client = core_worker_clients_.GetOrConnect(data.driver_address());
auto request = std::make_unique<rpc::NumPendingTasksRequest>();
constexpr int64_t kNumPendingTasksRequestTimeoutMs = 1000;
RAY_LOG(DEBUG) << "Send NumPendingTasksRequest to worker " << worker_id
<< ", timeout " << kNumPendingTasksRequestTimeoutMs << " ms.";
client->NumPendingTasks(
std::move(request),
[job_id, worker_id, reply, i, num_processed_jobs, try_send_reply](
[job_id, worker_id, reply, i, num_finished_tasks, try_send_reply](
const Status &status,
const rpc::NumPendingTasksReply &num_pending_tasks_reply) {
RAY_LOG(DEBUG).WithField(worker_id)
Expand All @@ -324,25 +336,25 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request,
bool is_running_tasks = num_pending_tasks_reply.num_pending_tasks() > 0;
reply->mutable_job_info_list(i)->set_is_running_tasks(is_running_tasks);
}
(*num_processed_jobs)++;
try_send_reply();
size_t updated_finished_tasks = num_finished_tasks->fetch_add(1) + 1;
try_send_reply(updated_finished_tasks);
},
kNumPendingTasksRequestTimeoutMs);
}
} else {
(*num_processed_jobs)++;
try_send_reply();
}
i++;
}

if (!request.skip_submission_job_info_field()) {
if (request.skip_submission_job_info_field()) {
// Skipping MultiKVGet, just mark the counter.
size_t updated_finished_tasks = num_finished_tasks->fetch_add(1) + 1;
try_send_reply(updated_finished_tasks);
} else {
// Load the JobInfo for jobs submitted via the Ray Job API.
auto kv_multi_get_callback =
[reply,
send_reply_callback,
job_data_key_to_indices,
kv_callback_done,
num_finished_tasks,
try_send_reply](std::unordered_map<std::string, std::string> &&result) {
for (const auto &data : result) {
const std::string &job_data_key = data.first;
Expand All @@ -365,13 +377,10 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request,
}
}
}
*kv_callback_done = true;
try_send_reply();
size_t updated_finished_tasks = num_finished_tasks->fetch_add(1) + 1;
try_send_reply(updated_finished_tasks);
};
internal_kv_.MultiGet("job", job_api_data_keys, kv_multi_get_callback);
} else {
*kv_callback_done = true;
try_send_reply();
}
};
Status status = gcs_table_storage_->JobTable().GetAll(on_done);
Expand Down

0 comments on commit 7cd7705

Please sign in to comment.