Skip to content

Commit

Permalink
[core] [1/N] New interface to combine client registration and announce …
Browse files Browse the repository at this point in the history
…at NodeManager (#49235)

Issue reference: #48837
This PR simply combines two existing API calls, client registration and
port announcement, into one, with no functionality change.

The plan is to
- Add new APIs at the callee side, so there is no production effect;
- Do a no-op code refactor at the caller side, so it's easier to refactor and
merge API calls;
- Connect caller and callee with the new API calls; after (1) and (2)
the code structure should be in a clean state that is friendly and easy
for the API merge.

---------

Signed-off-by: hjiang <hjiang@anyscale.com>
  • Loading branch information
dentiny authored Dec 21, 2024
1 parent 5d359e4 commit 6457898
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 59 deletions.
20 changes: 20 additions & 0 deletions src/ray/raylet/format/node_manager.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ enum MessageType:int {
ConnectClient,
// Subscribe to Plasma updates.
SubscribePlasmaReady,
// The [RegisterClientWithPort] message series combines [RegisterClient] and [AnnounceWorkerPort].
//
// Send an initial connection message to the raylet with port assigned. This is sent
// from a worker or driver to a raylet.
// The corresponding response type is [AnnounceWorkerPortReply].
RegisterClientWithPortRequest,
}

table Task {
Expand Down Expand Up @@ -144,6 +150,20 @@ table RegisterClientReply {
port: int;
}

// Combined request that carries a client registration and a worker port
// announcement in a single message.
table RegisterClientWithPortRequest {
// Request to register client.
request_client_request: RegisterClientRequest;
// Port announcement for the same client (the port the worker is listening on).
// NOTE(review): neither field is marked required, so readers may observe an
// absent (null) sub-table.
announcement_port_request: AnnounceWorkerPort;
}

// Combined response for a RegisterClientWithPortRequest.
// NOTE(review): both fields are named `*_request` and declared as vectors even
// though they hold replies; confirm whether scalar `*_reply` fields were intended.
table RegisterClientWithPortResponse {
// Response to register client.
request_client_request: [RegisterClientReply];
// Response to assign port.
announcement_port_request: [AnnounceWorkerPortReply];
}

table AnnounceWorkerPort {
// Port that this worker is listening on.
port: int;
Expand Down
190 changes: 131 additions & 59 deletions src/ray/raylet/node_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,9 @@ void NodeManager::ProcessClientMessage(const std::shared_ptr<ClientConnection> &
case protocol::MessageType::AnnounceWorkerPort: {
ProcessAnnounceWorkerPortMessage(client, message_data);
} break;
case protocol::MessageType::RegisterClientWithPortRequest: {
ProcessRegisterClientAndAnnouncePortMessage(client, message_data);
} break;
case protocol::MessageType::ActorCreationTaskDone: {
if (registered_worker) {
// Worker may send this message after it was disconnected.
Expand Down Expand Up @@ -1308,9 +1311,17 @@ void NodeManager::ProcessClientMessage(const std::shared_ptr<ClientConnection> &

void NodeManager::ProcessRegisterClientRequestMessage(
    const std::shared_ptr<ClientConnection> &client, const uint8_t *message_data) {
  // Standalone registration: the request carries no pre-announced port, so the
  // shared implementation is invoked without one.
  const std::optional<int> no_port = std::nullopt;
  const auto *request =
      flatbuffers::GetRoot<protocol::RegisterClientRequest>(message_data);
  // The returned status is intentionally discarded on this path.
  RAY_UNUSED(ProcessRegisterClientRequestMessageImpl(client, request, no_port));
}

Status NodeManager::ProcessRegisterClientRequestMessageImpl(
const std::shared_ptr<ClientConnection> &client,
const ray::protocol::RegisterClientRequest *message,
std::optional<int> port) {
client->Register();

auto message = flatbuffers::GetRoot<protocol::RegisterClientRequest>(message_data);
Language language = static_cast<Language>(message->language());
const JobID job_id = from_flatbuf<JobID>(*message->job_id());
const int runtime_env_hash = static_cast<int>(message->runtime_env_hash());
Expand All @@ -1326,6 +1337,7 @@ void NodeManager::ProcessRegisterClientRequestMessage(
worker_type == rpc::WorkerType::RESTORE_WORKER) {
RAY_CHECK(job_id.IsNil());
}

auto worker = std::static_pointer_cast<WorkerInterface>(
std::make_shared<Worker>(job_id,
runtime_env_hash,
Expand All @@ -1337,33 +1349,53 @@ void NodeManager::ProcessRegisterClientRequestMessage(
client_call_manager_,
worker_startup_token));

auto send_reply_callback = [this, client](Status status, int assigned_port) {
flatbuffers::FlatBufferBuilder fbb;
auto reply =
ray::protocol::CreateRegisterClientReply(fbb,
status.ok(),
fbb.CreateString(status.ToString()),
to_flatbuf(fbb, self_node_id_),
assigned_port);
fbb.Finish(reply);
client->WriteMessageAsync(
static_cast<int64_t>(protocol::MessageType::RegisterClientReply),
fbb.GetSize(),
fbb.GetBufferPointer(),
[this, client](const ray::Status &status) {
if (!status.ok()) {
DisconnectClient(client,
rpc::WorkerExitType::SYSTEM_ERROR,
"Worker is failed because the raylet couldn't reply the "
"registration request: " +
status.ToString());
}
});
};
std::function<void(Status, int)> send_reply_callback;
if (port.has_value()) {
worker->SetAssignedPort(*port);
} else {
send_reply_callback = [this, client](Status status, int assigned_port) {
flatbuffers::FlatBufferBuilder fbb;
auto reply =
ray::protocol::CreateRegisterClientReply(fbb,
status.ok(),
fbb.CreateString(status.ToString()),
to_flatbuf(fbb, self_node_id_),
assigned_port);
fbb.Finish(reply);
client->WriteMessageAsync(
static_cast<int64_t>(protocol::MessageType::RegisterClientReply),
fbb.GetSize(),
fbb.GetBufferPointer(),
[this, client](const ray::Status &status) {
if (!status.ok()) {
DisconnectClient(client,
rpc::WorkerExitType::SYSTEM_ERROR,
"Worker is failed because the raylet couldn't reply the "
"registration request: " +
status.ToString());
}
});
};
}

if (worker_type == rpc::WorkerType::WORKER ||
worker_type == rpc::WorkerType::SPILL_WORKER ||
worker_type == rpc::WorkerType::RESTORE_WORKER) {
// Register the new worker.
return RegisterForNewWorker(
worker, pid, worker_startup_token, std::move(send_reply_callback));
}

// Register the new driver.
return RegisterForNewDriver(
worker, pid, job_id, message, std::move(send_reply_callback));
}

Status NodeManager::RegisterForNewWorker(
std::shared_ptr<WorkerInterface> worker,
pid_t pid,
const StartupToken &worker_startup_token,
std::function<void(Status, int)> send_reply_callback) {
if (send_reply_callback) {
auto status = worker_pool_.RegisterWorker(
worker, pid, worker_startup_token, send_reply_callback);
if (!status.ok()) {
Expand All @@ -1372,23 +1404,44 @@ void NodeManager::ProcessRegisterClientRequestMessage(
// maximum_startup_concurrency).
cluster_task_manager_->ScheduleAndDispatchTasks();
}
} else {
// Register the new driver.
RAY_CHECK(pid >= 0);
worker->SetProcess(Process::FromPid(pid));
// Compute a dummy driver task id from a given driver.
// The task id set in the worker here should be consistent with the task
// id set in the core worker.
const TaskID driver_task_id = TaskID::ForDriverTask(job_id);
worker->AssignTaskId(driver_task_id);
rpc::JobConfig job_config;
job_config.ParseFromString(message->serialized_job_config()->str());
RAY_UNUSED(worker_pool_.RegisterDriver(worker, job_config, send_reply_callback));
return status;
}

return worker_pool_.RegisterWorker(worker, pid, worker_startup_token);
}

/// Register a driver with the worker pool.
///
/// \param worker The worker object representing the driver.
/// \param pid Process id of the driver; must be non-negative.
/// \param job_id Job the driver belongs to; used to derive the driver task id.
/// \param message The registration request carrying the serialized job config.
/// \param send_reply_callback Optional reply callback. When empty, the
/// worker-pool overload that takes no callback is used.
/// \return The status returned by the worker pool.
Status NodeManager::RegisterForNewDriver(
    std::shared_ptr<WorkerInterface> worker,
    pid_t pid,
    const JobID &job_id,
    const ray::protocol::RegisterClientRequest *message,
    std::function<void(Status, int)> send_reply_callback) {
  RAY_CHECK_GE(pid, 0);
  worker->SetProcess(Process::FromPid(pid));
  // Compute a dummy driver task id from a given driver.
  // The task id set in the worker here should be consistent with the task
  // id set in the core worker.
  const TaskID driver_task_id = TaskID::ForDriverTask(job_id);
  worker->AssignTaskId(driver_task_id);

  rpc::JobConfig job_config;
  // Surface a malformed job config instead of silently dropping the parse result.
  // Registration still proceeds, as before.
  if (!job_config.ParseFromString(message->serialized_job_config()->str())) {
    RAY_LOG(WARNING) << "Failed to parse the serialized job config of a driver "
                        "registration request.";
  }

  if (send_reply_callback) {
    // The callback is not used again here, so hand it over instead of copying.
    return worker_pool_.RegisterDriver(
        worker, job_config, std::move(send_reply_callback));
  }

  return worker_pool_.RegisterDriver(worker, job_config);
}

void NodeManager::ProcessAnnounceWorkerPortMessage(
    const std::shared_ptr<ClientConnection> &client, const uint8_t *message_data) {
  // Decode the raw buffer, then hand the typed message to the shared
  // implementation.
  const auto *announcement =
      flatbuffers::GetRoot<protocol::AnnounceWorkerPort>(message_data);
  ProcessAnnounceWorkerPortMessageImpl(client, announcement);
}

void NodeManager::ProcessAnnounceWorkerPortMessageImpl(
const std::shared_ptr<ClientConnection> &client,
const ray::protocol::AnnounceWorkerPort *message) {
bool is_worker = true;
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
if (worker == nullptr) {
Expand All @@ -1398,7 +1451,6 @@ void NodeManager::ProcessAnnounceWorkerPortMessage(
RAY_CHECK(worker != nullptr) << "No worker exists for CoreWorker with client: "
<< client->DebugString();

auto message = flatbuffers::GetRoot<protocol::AnnounceWorkerPort>(message_data);
int port = message->port();
worker->Connect(port);
if (is_worker) {
Expand Down Expand Up @@ -1427,31 +1479,51 @@ void NodeManager::ProcessAnnounceWorkerPortMessage(

RAY_CHECK_OK(
gcs_client_->Jobs().AsyncAdd(job_data_ptr, [this, client](Status status) {
if (!status.ok()) {
RAY_LOG(ERROR) << "Failed to add job to GCS: " << status.ToString();
}
// Write the reply back.
flatbuffers::FlatBufferBuilder fbb;
auto message = protocol::CreateAnnounceWorkerPortReply(
fbb, status.ok(), fbb.CreateString(status.ToString()));
fbb.Finish(message);

client->WriteMessageAsync(
static_cast<int64_t>(protocol::MessageType::AnnounceWorkerPortReply),
fbb.GetSize(),
fbb.GetBufferPointer(),
[this, client](const ray::Status &status) {
if (!status.ok()) {
DisconnectClient(client,
rpc::WorkerExitType::SYSTEM_ERROR,
"Failed to send AnnounceWorkerPortReply to client: " +
status.ToString());
}
});
SendPortAnnouncementResponse(client, std::move(status));
}));
}
}

void NodeManager::SendPortAnnouncementResponse(
const std::shared_ptr<ClientConnection> &client, Status status) {
if (!status.ok()) {
RAY_LOG(ERROR) << "Failed to add job to GCS: " << status.ToString();
}
// Write the reply back.
flatbuffers::FlatBufferBuilder fbb;
auto message = protocol::CreateAnnounceWorkerPortReply(
fbb, status.ok(), fbb.CreateString(status.ToString()));
fbb.Finish(message);

client->WriteMessageAsync(
static_cast<int64_t>(protocol::MessageType::AnnounceWorkerPortReply),
fbb.GetSize(),
fbb.GetBufferPointer(),
[this, client](const ray::Status &status) {
if (!status.ok()) {
DisconnectClient(
client,
rpc::WorkerExitType::SYSTEM_ERROR,
"Failed to send AnnounceWorkerPortReply to client: " + status.ToString());
}
});
}

void NodeManager::ProcessRegisterClientAndAnnouncePortMessage(
const std::shared_ptr<ClientConnection> &client, const uint8_t *message_data) {
auto *message =
flatbuffers::GetRoot<protocol::RegisterClientWithPortRequest>(message_data);
const ray::protocol::AnnounceWorkerPort *announce_port_msg =
message->announcement_port_request();
auto status = ProcessRegisterClientRequestMessageImpl(
client, message->request_client_request(), announce_port_msg->port());
if (!status.ok()) {
SendPortAnnouncementResponse(client, std::move(status));
return;
}
RAY_UNUSED(ProcessAnnounceWorkerPortMessageImpl(client, announce_port_msg));
}

void NodeManager::HandleWorkerAvailable(const std::shared_ptr<WorkerInterface> &worker) {
RAY_CHECK(worker);

Expand Down
27 changes: 27 additions & 0 deletions src/ray/raylet/node_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,22 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
/// \return Void.
void ProcessRegisterClientRequestMessage(
const std::shared_ptr<ClientConnection> &client, const uint8_t *message_data);
/// Shared implementation of client registration, operating on an
/// already-parsed request.
///
/// \param client The client that sent the message.
/// \param message The parsed registration request.
/// \param port If set, the port is assigned to the worker directly and no
/// RegisterClientReply callback is created. If std::nullopt, a callback that
/// writes a RegisterClientReply to the client is handed to the registration.
/// \return Status returned by RegisterForNewWorker or RegisterForNewDriver.
Status ProcessRegisterClientRequestMessageImpl(
const std::shared_ptr<ClientConnection> &client,
const ray::protocol::RegisterClientRequest *message,
std::optional<int> port);

/// Register a new worker into worker pool.
///
/// \param worker The worker to register.
/// \param pid Process id of the worker.
/// \param worker_startup_token Startup token identifying the worker process.
/// \param send_reply_callback Optional reply callback. When empty (the
/// default), the worker-pool overload that takes no callback is used.
/// \return Status returned by the worker pool.
Status RegisterForNewWorker(std::shared_ptr<WorkerInterface> worker,
pid_t pid,
const StartupToken &worker_startup_token,
std::function<void(Status, int)> send_reply_callback = {});
/// Register a new driver into worker pool.
///
/// \param worker The worker object representing the driver.
/// \param pid Process id of the driver; checked to be non-negative.
/// \param job_id Job id used to derive the driver task id.
/// \param message The registration request carrying the serialized job config.
/// \param send_reply_callback Optional reply callback. When empty (the
/// default), the worker-pool overload that takes no callback is used.
/// \return Status returned by the worker pool.
Status RegisterForNewDriver(std::shared_ptr<WorkerInterface> worker,
pid_t pid,
const JobID &job_id,
const ray::protocol::RegisterClientRequest *message,
std::function<void(Status, int)> send_reply_callback = {});

/// Process client message of AnnounceWorkerPort
///
Expand All @@ -440,6 +456,17 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
/// \return Void.
void ProcessAnnounceWorkerPortMessage(const std::shared_ptr<ClientConnection> &client,
const uint8_t *message_data);
/// Shared implementation of AnnounceWorkerPort handling, operating on an
/// already-parsed message.
///
/// \param client The client that sent the message.
/// \param message The parsed port announcement.
void ProcessAnnounceWorkerPortMessageImpl(
const std::shared_ptr<ClientConnection> &client,
const ray::protocol::AnnounceWorkerPort *message);

/// Send status of port announcement to client side.
///
/// Writes an AnnounceWorkerPortReply built from `status`; if the write fails,
/// the client is disconnected.
///
/// \param client The client to reply to.
/// \param status The outcome to report to the client.
void SendPortAnnouncementResponse(const std::shared_ptr<ClientConnection> &client,
Status status);

/// Process client registration and port announcement.
///
/// Handles the combined RegisterClientWithPortRequest message: registers the
/// client with the announced port and, if that succeeds, runs the port
/// announcement handling. A failed registration is reported to the client via
/// SendPortAnnouncementResponse.
///
/// \param client The client that sent the message.
/// \param message_data A pointer to the message data.
void ProcessRegisterClientAndAnnouncePortMessage(
const std::shared_ptr<ClientConnection> &client, const uint8_t *message_data);

/// Handle the case that a worker is available.
///
Expand Down
51 changes: 51 additions & 0 deletions src/ray/raylet/worker_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,37 @@ Status WorkerPool::RegisterWorker(const std::shared_ptr<WorkerInterface> &worker
return Status::OK();
}

/// Register a worker whose startup token matches a tracked worker process.
/// This overload takes no reply callback.
///
/// \param worker The worker to register; must be non-null.
/// \param pid Process id recorded on the worker.
/// \param worker_startup_token Token used to look up the starting process.
/// \return Status::Invalid if the token is unknown, Status::OK otherwise.
Status WorkerPool::RegisterWorker(const std::shared_ptr<WorkerInterface> &worker,
                                  pid_t pid,
                                  StartupToken worker_startup_token) {
  RAY_CHECK(worker);
  auto &lang_state = GetStateForLanguage(worker->GetLanguage());

  // Reject registrations from processes this pool is not tracking.
  const auto process_entry = lang_state.worker_processes.find(worker_startup_token);
  if (process_entry == lang_state.worker_processes.end()) {
    RAY_LOG(WARNING) << "Received a register request from an unknown token: "
                     << worker_startup_token;
    return Status::Invalid("Unknown worker");
  }

  auto worker_process = Process::FromPid(pid);
  worker->SetProcess(worker_process);

  // Time elapsed between the recorded process start and this registration.
  const auto registered_at = std::chrono::high_resolution_clock::now();
  const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
      registered_at - process_entry->second.start_time);
  const auto elapsed_ms = elapsed.count();

  // TODO(hjiang): Add tag to indicate whether port has been assigned beforehand.
  STATS_worker_register_time_ms.Record(elapsed_ms);
  RAY_LOG(DEBUG) << "Registering worker " << worker->WorkerId() << " with pid " << pid
                 << ", register cost: " << elapsed_ms
                 << ", worker_type: " << rpc::WorkerType_Name(worker->GetWorkerType())
                 << ", startup token: " << worker_startup_token;

  lang_state.registered_workers.insert(worker);
  return Status::OK();
}

void WorkerPool::OnWorkerStarted(const std::shared_ptr<WorkerInterface> &worker) {
auto &state = GetStateForLanguage(worker->GetLanguage());
const StartupToken worker_startup_token = worker->GetStartupToken();
Expand Down Expand Up @@ -879,6 +910,26 @@ Status WorkerPool::RegisterDriver(const std::shared_ptr<WorkerInterface> &driver
return Status::OK();
}

/// Register a driver and start its job. This overload takes no reply callback.
///
/// \param driver The driver to register.
/// \param job_config Job configuration passed to HandleJobStarted.
/// \return Always Status::OK().
Status WorkerPool::RegisterDriver(const std::shared_ptr<WorkerInterface> &driver,
                                  const rpc::JobConfig &job_config) {
  auto &state = GetStateForLanguage(driver->GetLanguage());
  // `driver` is a const reference, so std::move on it could only ever copy.
  // Insert it plainly; `driver` is still used below.
  state.registered_drivers.insert(driver);
  const auto job_id = driver->GetAssignedJobId();
  HandleJobStarted(job_id, job_config);

  // Java drivers return before the prestart logic and before
  // `first_job_registered_` is updated.
  if (driver->GetLanguage() == Language::JAVA) {
    return Status::OK();
  }

  if (!first_job_registered_ && RayConfig::instance().prestart_worker_first_driver() &&
      !RayConfig::instance().enable_worker_prestart()) {
    RAY_LOG(DEBUG) << "PrestartDefaultCpuWorkers " << num_prestart_python_workers;
    PrestartDefaultCpuWorkers(Language::PYTHON, num_prestart_python_workers);
  }
  first_job_registered_ = true;
  return Status::OK();
}

std::shared_ptr<WorkerInterface> WorkerPool::GetRegisteredWorker(
const WorkerID &worker_id) const {
for (const auto &[_, state] : states_by_lang_) {
Expand Down
Loading

0 comments on commit 6457898

Please sign in to comment.