Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Replace GCS stubs in dashboard to use NewGcsAioClient. #46846

Merged
merged 4 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/ray/_private/gcs_aio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def __init__(
# Forwarded Properties.
self.address = self.inner.address
self.cluster_id = self.inner.cluster_id
# Note: these only exist in the new client.
self.get_all_actor_info = self.inner.async_get_all_actor_info
self.get_all_node_info = self.inner.async_get_all_node_info
self.kill_actor = self.inner.async_kill_actor


class AsyncProxy:
Expand Down
1 change: 1 addition & 0 deletions python/ray/dashboard/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ async def run(self):
self.gcs_aio_client = GcsAioClient(
address=gcs_address, nums_reconnect_retry=0
)
# TODO(ryw): once we removed the old gcs client, also remove this.
gcs_channel = GcsChannel(gcs_address=gcs_address, aio=True)
gcs_channel.connect()
self.aiogrpc_gcs_channel = gcs_channel.channel()
Expand Down
101 changes: 67 additions & 34 deletions python/ray/dashboard/modules/actor/actor_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import os
import time
from collections import deque
from typing import Dict

import aiohttp.web

import ray.dashboard.optional_utils as dashboard_optional_utils
import ray.dashboard.utils as dashboard_utils
from ray import ActorID
from ray._private.gcs_pubsub import GcsAioActorSubscriber
from ray.core.generated import gcs_service_pb2, gcs_service_pb2_grpc
from ray.core.generated import gcs_pb2, gcs_service_pb2, gcs_service_pb2_grpc
from ray.dashboard.datacenter import DataOrganizer, DataSource
from ray.dashboard.modules.actor import actor_consts

Expand Down Expand Up @@ -77,11 +79,54 @@ def actor_table_data_to_dict(message):
return light_message


class GetAllActorInfo:
    """
    Factory for the callable that gets all actor info from GCS via gRPC
    ActorInfoGcsService.GetAllActorInfo.

    Constructing this class never yields a GetAllActorInfo instance: __new__
    returns a GetAllActorInfoFromGrpc (direct gRPC stub) when the env var
    RAY_USE_OLD_GCS_CLIENT equals "1", and a GetAllActorInfoFromNewGcsClient
    (GcsAioClient) otherwise. All constructor arguments are forwarded unchanged.
    """

    def __new__(cls, *args, **kwargs):
        # The env var is consulted on every construction; any value other than
        # the exact string "1" (including unset) selects the new client.
        if os.getenv("RAY_USE_OLD_GCS_CLIENT") == "1":
            impl = GetAllActorInfoFromGrpc
        else:
            impl = GetAllActorInfoFromNewGcsClient
        return impl(*args, **kwargs)


class GetAllActorInfoFromNewGcsClient:
    """Gets all actor info through the dashboard head's `gcs_aio_client`."""

    def __init__(self, dashboard_head):
        # Only the client is kept; the dashboard head itself is not retained.
        self.gcs_aio_client = dashboard_head.gcs_aio_client

    async def __call__(self, timeout) -> Dict[ActorID, gcs_pb2.ActorTableData]:
        """Await `gcs_aio_client.get_all_actor_info` and return its result as-is."""
        actors = await self.gcs_aio_client.get_all_actor_info(timeout=timeout)
        return actors


class GetAllActorInfoFromGrpc:
    """
    Gets all actor info by calling ActorInfoGcsService.GetAllActorInfo on a
    gRPC stub built from the dashboard head's `aiogrpc_gcs_channel`.
    """

    def __init__(self, dashboard_head):
        gcs_channel = dashboard_head.aiogrpc_gcs_channel
        self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
            gcs_channel
        )

    async def __call__(self, timeout) -> Dict[ActorID, gcs_pb2.ActorTableData]:
        """
        Issue the GetAllActorInfo RPC and index the reply by actor ID.

        Args:
            timeout: Forwarded to the stub call as its `timeout` argument.

        Returns:
            A dict mapping each actor's ActorID to its ActorTableData message.

        Raises:
            Exception: If the reply's status code is non-zero.
        """
        request = gcs_service_pb2.GetAllActorInfoRequest()
        reply = await self._gcs_actor_info_stub.GetAllActorInfo(
            request, timeout=timeout
        )
        if reply.status.code != 0:
            raise Exception(f"Failed to GetAllActorInfo: {reply.status.message}")
        # Protobuf messages expose fields under their snake_case proto names
        # (`actor_id`); the camelCase `actorId` spelling only exists in the
        # JSON/dict form. Likewise the Python ID classes use snake_case
        # constructors (`from_binary`, cf. `ActorID.from_hex`), not `FromBinary`.
        return {
            ActorID.from_binary(message.actor_id): message
            for message in reply.actor_table_data
        }


class ActorHead(dashboard_utils.DashboardHeadModule):
def __init__(self, dashboard_head):
super().__init__(dashboard_head)
# ActorInfoGcsService
self._gcs_actor_info_stub = None

self.get_all_actor_info = None
# A queue of dead actors in order of when they died
self.dead_actors_queue = deque()

Expand All @@ -95,33 +140,24 @@ async def _update_actors(self):
while True:
try:
logger.info("Getting all actor info from GCS.")
request = gcs_service_pb2.GetAllActorInfoRequest()
reply = await self._gcs_actor_info_stub.GetAllActorInfo(
request, timeout=5
)
if reply.status.code == 0:
actors = {}
for message in reply.actor_table_data:
actor_table_data = actor_table_data_to_dict(message)
actors[actor_table_data["actorId"]] = actor_table_data
# Update actors.
DataSource.actors.reset(actors)
# Update node actors and job actors.
node_actors = {}
for actor_id, actor_table_data in actors.items():
node_id = actor_table_data["address"]["rayletId"]
# Update only when node_id is not Nil.
if node_id != actor_consts.NIL_NODE_ID:
node_actors.setdefault(node_id, {})[
actor_id
] = actor_table_data
DataSource.node_actors.reset(node_actors)
logger.info("Received %d actor info from GCS.", len(actors))
break
else:
raise Exception(
f"Failed to GetAllActorInfo: {reply.status.message}"
)

actors = await self.get_all_actor_info(timeout=5)
actor_dicts: Dict[str, dict] = {
actor_id.hex(): actor_table_data_to_dict(actor_table_data)
for actor_id, actor_table_data in actors.items()
}
# Update actors.
DataSource.actors.reset(actor_dicts)
# Update node actors and job actors.
node_actors = {}
for actor_id, actor_table_data in actor_dicts.items():
node_id = actor_table_data["address"]["rayletId"]
# Update only when node_id is not Nil.
if node_id != actor_consts.NIL_NODE_ID:
node_actors.setdefault(node_id, {})[actor_id] = actor_table_data
DataSource.node_actors.reset(node_actors)
logger.info("Received %d actor info from GCS.", len(actors))
break # breaks the while True.
except Exception:
logger.exception("Error Getting all actor info from GCS.")
await asyncio.sleep(
Expand Down Expand Up @@ -258,10 +294,7 @@ async def get_actor(self, req) -> aiohttp.web.Response:
)

async def run(self, server):
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
gcs_channel
)
self.get_all_actor_info = GetAllActorInfo(self._dashboard_head)
await asyncio.gather(self._update_actors(), self._cleanup_actors())

@staticmethod
Expand Down
80 changes: 56 additions & 24 deletions python/ray/dashboard/modules/node/node_head.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
import json
import logging
import os
import time
from itertools import chain
from typing import Dict

import aiohttp.web
import grpc
Expand All @@ -11,6 +13,7 @@
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.optional_utils as dashboard_optional_utils
import ray.dashboard.utils as dashboard_utils
from ray import NodeID
from ray._private import ray_constants
from ray._private.ray_constants import DEBUG_AUTOSCALING_ERROR, DEBUG_AUTOSCALING_STATUS
from ray.autoscaler._private.util import (
Expand All @@ -19,6 +22,7 @@
parse_usage,
)
from ray.core.generated import (
gcs_pb2,
gcs_service_pb2,
gcs_service_pb2_grpc,
node_manager_pb2,
Expand Down Expand Up @@ -85,14 +89,52 @@ def node_stats_to_dict(message):
message.core_workers_stats.extend(core_workers_stats)


class GetAllNodeInfo:
    """
    Factory for the callable that gets all node info from GCS via gRPC
    NodeInfoGcsService.GetAllNodeInfo.

    Constructing this class never yields a GetAllNodeInfo instance: __new__
    returns a GetAllNodeInfoFromGrpc (direct gRPC stub) when the env var
    RAY_USE_OLD_GCS_CLIENT equals "1", and a GetAllNodeInfoFromNewGcsClient
    (GcsAioClient) otherwise. All constructor arguments are forwarded unchanged.
    """

    def __new__(cls, *args, **kwargs):
        # Any value other than the exact string "1" (including unset) selects
        # the new client.
        if os.getenv("RAY_USE_OLD_GCS_CLIENT") == "1":
            impl = GetAllNodeInfoFromGrpc
        else:
            impl = GetAllNodeInfoFromNewGcsClient
        return impl(*args, **kwargs)


class GetAllNodeInfoFromNewGcsClient:
    """Gets all node info through the dashboard head's `gcs_aio_client`."""

    def __init__(self, dashboard_head):
        # Only the client is kept; the dashboard head itself is not retained.
        self.gcs_aio_client = dashboard_head.gcs_aio_client

    async def __call__(self, timeout) -> Dict[NodeID, gcs_pb2.GcsNodeInfo]:
        """Await `gcs_aio_client.get_all_node_info` and return its result as-is."""
        nodes = await self.gcs_aio_client.get_all_node_info(timeout=timeout)
        return nodes


class GetAllNodeInfoFromGrpc:
    """
    Gets all node info by calling NodeInfoGcsService.GetAllNodeInfo on a gRPC
    stub built from the dashboard head's `aiogrpc_gcs_channel`.
    """

    def __init__(self, dashboard_head):
        gcs_channel = dashboard_head.aiogrpc_gcs_channel
        self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
            gcs_channel
        )

    async def __call__(self, timeout) -> Dict[NodeID, gcs_pb2.GcsNodeInfo]:
        """
        Issue the GetAllNodeInfo RPC and index the reply by node ID.

        Args:
            timeout: Forwarded to the stub call as its `timeout` argument.

        Returns:
            A dict mapping each node's NodeID to its GcsNodeInfo message.

        Raises:
            Exception: If the reply's status code is non-zero.
        """
        request = gcs_service_pb2.GetAllNodeInfoRequest()
        reply = await self._gcs_node_info_stub.GetAllNodeInfo(request, timeout=timeout)
        if reply.status.code != 0:
            raise Exception(f"Failed to GetAllNodeInfo: {reply.status.message}")
        # Protobuf messages expose fields under their snake_case proto names
        # (`node_id`); the camelCase `nodeId` spelling only exists in the
        # JSON/dict form. The Python ID classes use snake_case constructors
        # (`from_binary`), not the C++-style `FromBinary`.
        return {
            NodeID.from_binary(message.node_id): message
            for message in reply.node_info_list
        }


class NodeHead(dashboard_utils.DashboardHeadModule):
def __init__(self, dashboard_head):
super().__init__(dashboard_head)
self._stubs = {}
# NodeInfoGcsService
self._gcs_node_info_stub = None
# NodeResourceInfoGcsService
self._gcs_node_resource_info_sub = None
self.get_all_node_info = None
self._collect_memory_info = False
DataSource.nodes.signal.append(self._update_stubs)
# Total number of node updates happened.
Expand Down Expand Up @@ -137,18 +179,15 @@ async def _get_nodes(self):
Returns:
A dict of information about the nodes in the cluster.
"""
request = gcs_service_pb2.GetAllNodeInfoRequest()
reply = await self._gcs_node_info_stub.GetAllNodeInfo(
request, timeout=node_consts.GCS_RPC_TIMEOUT_SECONDS
)
if reply.status.code == 0:
result = {}
for node_info in reply.node_info_list:
node_info_dict = gcs_node_info_to_dict(node_info)
result[node_info_dict["nodeId"]] = node_info_dict
return result
else:
logger.error("Failed to GetAllNodeInfo: %s", reply.status.message)
try:
nodes = await self.get_all_node_info(timeout=5)
return {
node_id.hex(): gcs_node_info_to_dict(node_info)
for node_id, node_info in nodes.items()
}
except Exception:
logger.exception("Failed to GetAllNodeInfo.")
raise

async def _update_nodes(self):
# TODO(fyrestone): Refactor code for updating actor / node / job.
Expand Down Expand Up @@ -394,14 +433,7 @@ async def _update_node_stats(self):
DataSource.node_stats[node_id] = reply_dict

async def run(self, server):
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
gcs_channel
)
self._gcs_node_resource_info_stub = (
gcs_service_pb2_grpc.NodeResourceInfoGcsServiceStub(gcs_channel)
)

self.get_all_node_info = GetAllNodeInfo(self._dashboard_head)
await asyncio.gather(
self._update_nodes(),
self._update_node_stats(),
Expand Down
71 changes: 60 additions & 11 deletions python/ray/dashboard/modules/snapshot/snapshot_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import ray.dashboard.optional_utils as dashboard_optional_utils
import ray.dashboard.utils as dashboard_utils
from ray import ActorID
from ray._private.gcs_aio_client import GcsAioClient
from ray._private.pydantic_compat import BaseModel, Extra, Field, validator
from ray._private.storage import _load_class
Expand Down Expand Up @@ -73,11 +74,62 @@ def reason_required(cls, v, values, **kwargs):
return v


class KillActor:
    """
    Factory for the object that kills an actor via GCS using gRPC
    ActorInfoGcsService.KillActorViaGcs.

    Constructing this class never yields a KillActor instance: __new__ returns a
    KillActorViaGcsFromGrpc (direct gRPC stub) when the env var
    RAY_USE_OLD_GCS_CLIENT equals "1", and a KillActorViaGcsFromNewGcsClient
    (GcsAioClient) otherwise. All constructor arguments are forwarded unchanged.
    """

    def __new__(cls, *args, **kwargs):
        # Any value other than the exact string "1" (including unset) selects
        # the new client.
        if os.getenv("RAY_USE_OLD_GCS_CLIENT") == "1":
            impl = KillActorViaGcsFromGrpc
        else:
            impl = KillActorViaGcsFromNewGcsClient
        return impl(*args, **kwargs)


class KillActorViaGcsFromNewGcsClient:
    """Kills an actor through the dashboard head's `gcs_aio_client`."""

    def __init__(self, dashboard_head):
        # Only the client is kept; the dashboard head itself is not retained.
        self.gcs_aio_client = dashboard_head.gcs_aio_client

    async def async_kill_actor(
        self,
        actor_id: ActorID,
        force_kill: bool,
        no_restart: bool,
        timeout: Optional[float] = None,
    ):
        """
        Forward the kill request to `gcs_aio_client.kill_actor`.

        All four arguments are passed positionally, in the same order, and the
        awaited result is returned unchanged.
        """
        result = await self.gcs_aio_client.kill_actor(
            actor_id, force_kill, no_restart, timeout
        )
        return result


class KillActorViaGcsFromGrpc:
    """
    Kills an actor by calling ActorInfoGcsService.KillActorViaGcs on a gRPC
    stub built from the dashboard head's `aiogrpc_gcs_channel`.
    """

    def __init__(self, dashboard_head):
        self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
            dashboard_head.aiogrpc_gcs_channel
        )

    async def async_kill_actor(
        self,
        actor_id: ActorID,
        force_kill: bool,
        no_restart: bool,
        timeout: Optional[float] = None,
    ):
        """
        Build a KillActorViaGcsRequest and send it through the stub.

        The reply is awaited but neither inspected nor returned, so this
        coroutine always resolves to None unless the RPC itself raises.
        """
        # The actor ID is sent as raw bytes, obtained by round-tripping through
        # its hex representation.
        actor_id_bytes = bytes.fromhex(actor_id.hex())
        request = gcs_service_pb2.KillActorViaGcsRequest(
            actor_id=actor_id_bytes,
            force_kill=force_kill,
            no_restart=no_restart,
        )
        await self._gcs_actor_info_stub.KillActorViaGcs(request, timeout=timeout)


class APIHead(dashboard_utils.DashboardHeadModule):
def __init__(self, dashboard_head):
super().__init__(dashboard_head)
self._gcs_actor_info_stub = None
self._dashboard_head = dashboard_head
self._kill_actor = None
self._gcs_aio_client: GcsAioClient = dashboard_head.gcs_aio_client
self._job_info_client = None
# For offloading CPU intensive work.
Expand All @@ -95,12 +147,11 @@ async def kill_actor_gcs(self, req) -> aiohttp.web.Response:
success=False, message="actor_id is required."
)

request = gcs_service_pb2.KillActorViaGcsRequest()
request.actor_id = bytes.fromhex(actor_id)
request.force_kill = force_kill
request.no_restart = no_restart
await self._gcs_actor_info_stub.KillActorViaGcs(
request, timeout=SNAPSHOT_API_TIMEOUT_SECONDS
await self._kill_actor.async_kill_actor(
ActorID.from_hex(actor_id),
force_kill,
no_restart,
timeout=SNAPSHOT_API_TIMEOUT_SECONDS,
)

message = (
Expand Down Expand Up @@ -231,9 +282,7 @@ async def _get_job_activity_info(self, timeout: int) -> RayActivityResponse:
)

async def run(self, server):
self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
self._dashboard_head.aiogrpc_gcs_channel
)
self._kill_actor = KillActor(self._dashboard_head)
# Lazily constructed because dashboard_head's gcs_aio_client
# is lazily constructed
if not self._job_info_client:
Expand Down
Loading