From 7905b1b3703408f3e4fc7695c54e9d209a5b9aca Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Tue, 27 Aug 2024 14:51:06 -0700 Subject: [PATCH 01/17] node delta Signed-off-by: Ruiyang Wang --- python/ray/_private/gcs_pubsub.py | 86 +++++---- .../ray/dashboard/modules/node/node_consts.py | 13 +- .../ray/dashboard/modules/node/node_head.py | 170 ++++++++---------- 3 files changed, 131 insertions(+), 138 deletions(-) diff --git a/python/ray/_private/gcs_pubsub.py b/python/ray/_private/gcs_pubsub.py index fb4ea8d1bdd1..4774dab69ce4 100644 --- a/python/ray/_private/gcs_pubsub.py +++ b/python/ray/_private/gcs_pubsub.py @@ -13,7 +13,6 @@ from grpc.experimental import aio as aiogrpc import ray._private.gcs_utils as gcs_utils -import ray._private.logging_utils as logging_utils from ray.core.generated import gcs_service_pb2_grpc from ray.core.generated import gcs_service_pb2 from ray.core.generated import common_pb2 @@ -90,39 +89,6 @@ def _should_terminate_polling(e: grpc.RpcError) -> None: return True return False - @staticmethod - def _pop_error_info(queue): - if len(queue) == 0: - return None, None - msg = queue.popleft() - return msg.key_id, msg.error_info_message - - @staticmethod - def _pop_log_batch(queue): - if len(queue) == 0: - return None - msg = queue.popleft() - return logging_utils.log_batch_proto_to_dict(msg.log_batch_message) - - @staticmethod - def _pop_resource_usage(queue): - if len(queue) == 0: - return None, None - msg = queue.popleft() - return msg.key_id.decode(), msg.node_resource_usage_message.json - - @staticmethod - def _pop_actors(queue, batch_size=100): - if len(queue) == 0: - return [] - popped = 0 - msgs = [] - while len(queue) > 0 and popped < batch_size: - msg = queue.popleft() - msgs.append((msg.key_id, msg.actor_message)) - popped += 1 - return msgs - class GcsAioPublisher(_PublisherBase): """Publisher to GCS. Uses async io.""" @@ -222,7 +188,7 @@ async def _poll(self, timeout=None) -> None: self._max_processed_sequence_id = 0 for msg in poll.result().pub_messages: if msg.sequence_id <= self._max_processed_sequence_id: - logger.warn(f"Ignoring out of order message {msg}") + logger.warning(f"Ignoring out of order message {msg}") continue self._max_processed_sequence_id = msg.sequence_id self._queue.append(msg) @@ -266,6 +232,13 @@ async def poll(self, timeout=None) -> Tuple[bytes, str]: await self._poll(timeout=timeout) return self._pop_resource_usage(self._queue) + @staticmethod + def _pop_resource_usage(queue): + if len(queue) == 0: + return None, None + msg = queue.popleft() + return msg.key_id.decode(), msg.node_resource_usage_message.json + class GcsAioActorSubscriber(_AioSubscriber): def __init__( @@ -288,3 +261,46 @@ async def poll(self, timeout=None, batch_size=500) -> List[Tuple[bytes, str]]: """ await self._poll(timeout=timeout) return self._pop_actors(self._queue, batch_size=batch_size) + + @staticmethod + def _pop_actors(queue, batch_size=100): + if len(queue) == 0: + return [] + popped = 0 + msgs = [] + while len(queue) > 0 and popped < batch_size: + msg = queue.popleft() + msgs.append((msg.key_id, msg.actor_message)) + popped += 1 + return msgs + + +class GcsAioNodeInfoSubscriber(_AioSubscriber): + def __init__( + self, + worker_id: bytes = None, + address: str = None, + channel: grpc.Channel = None, + ): + super().__init__(pubsub_pb2.GCS_NODE_INFO_CHANNEL, worker_id, address, channel) + + async def poll(self, timeout=None) -> Tuple[bytes, str]: + """Polls for new resource usage message. 
+ + Returns: + A tuple of string reporter ID and resource usage json string. + """ + await self._poll(timeout=timeout) + return self._pop_node_infos(self._queue) + + @staticmethod + def _pop_node_infos(queue, batch_size=100): + if len(queue) == 0: + return [] + popped = 0 + msgs = [] + while len(queue) > 0 and popped < batch_size: + msg = queue.popleft() + msgs.append((msg.key_id, msg.node_info_message)) + popped += 1 + return msgs diff --git a/python/ray/dashboard/modules/node/node_consts.py b/python/ray/dashboard/modules/node/node_consts.py index 0e9b2465da95..f38c3f929834 100644 --- a/python/ray/dashboard/modules/node/node_consts.py +++ b/python/ray/dashboard/modules/node/node_consts.py @@ -1,16 +1,11 @@ -from ray._private.ray_constants import env_float, env_integer +from ray._private.ray_constants import env_integer NODE_STATS_UPDATE_INTERVAL_SECONDS = env_integer( "NODE_STATS_UPDATE_INTERVAL_SECONDS", 5 ) +# Deprecated, not used. UPDATE_NODES_INTERVAL_SECONDS = env_integer("UPDATE_NODES_INTERVAL_SECONDS", 5) -# Until the head node is registered, -# the API server is doing more frequent update -# with this interval. -FREQUENTY_UPDATE_NODES_INTERVAL_SECONDS = env_float( - "FREQUENTY_UPDATE_NODES_INTERVAL_SECONDS", 0.1 -) -# If the head node is not updated within -# this timeout, it will stop frequent update. +# Used to set a time range to frequently update the node stats. +# Now, it's timeout for a warning message if the head node is not registered. FREQUENT_UPDATE_TIMEOUT_SECONDS = env_integer("FREQUENT_UPDATE_TIMEOUT_SECONDS", 10) MAX_COUNT_OF_GCS_RPC_ERROR = 10 diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index cf521fa22b0d..d7e33c1b3676 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -4,7 +4,7 @@ import os import time from itertools import chain -from typing import Dict +from typing import AsyncGenerator, Dict import aiohttp.web import grpc @@ -15,6 +15,7 @@ import ray.dashboard.utils as dashboard_utils from ray import NodeID from ray._private import ray_constants +from ray._private.gcs_pubsub import GcsAioNodeInfoSubscriber from ray._private.ray_constants import DEBUG_AUTOSCALING_ERROR, DEBUG_AUTOSCALING_STATUS from ray._private.utils import get_or_create_event_loop from ray.autoscaler._private.util import ( @@ -32,10 +33,7 @@ from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS from ray.dashboard.datacenter import DataOrganizer, DataSource from ray.dashboard.modules.node import node_consts -from ray.dashboard.modules.node.node_consts import ( - FREQUENT_UPDATE_TIMEOUT_SECONDS, - FREQUENTY_UPDATE_NODES_INTERVAL_SECONDS, -) +from ray.dashboard.modules.node.node_consts import FREQUENT_UPDATE_TIMEOUT_SECONDS from ray.dashboard.utils import async_loop_forever logger = logging.getLogger(__name__) @@ -138,8 +136,6 @@ def __init__(self, dashboard_head): self.get_all_node_info = None self._collect_memory_info = False DataSource.nodes.signal.append(self._update_stubs) - # Total number of node updates happened. - self._node_update_cnt = 0 # The time where the module is started. self._module_start_time = time.time() # The time it takes until the head node is registered. 
None means @@ -170,104 +166,90 @@ def get_internal_states(self): "head_node_registration_time_s": self._head_node_registration_time_s, "registered_nodes": len(DataSource.nodes), "registered_agents": len(DataSource.agents), - "node_update_count": self._node_update_cnt, "module_lifetime_s": time.time() - self._module_start_time, } - async def _get_nodes(self): - """Read the client table. + async def _subscribe_nodes(self) -> AsyncGenerator[dict]: + """ + Yields the initial state of all nodes, then yields the updated state of nodes. - Returns: - A dict of information about the nodes in the cluster. + It makes GetAllNodeInfo call only once after the subscription is done, to get + the initial state of the nodes. """ - try: - nodes = await self.get_all_node_info(timeout=GCS_RPC_TIMEOUT_SECONDS) - return { - node_id.hex(): gcs_node_info_to_dict(node_info) - for node_id, node_info in nodes.items() - } - except Exception: - logger.exception("Failed to GetAllNodeInfo.") - raise + gcs_addr = self._gcs_address + subscriber = GcsAioNodeInfoSubscriber(address=gcs_addr) + await subscriber.subscribe() + + # Get all node info from GCS. For TOCTOU, it happens after the subscription. + all_node_info = await self.get_all_node_info(timeout=GCS_RPC_TIMEOUT_SECONDS) + for node_info in all_node_info.values(): + yield gcs_node_info_to_dict(node_info) - async def _update_nodes(self): - # TODO(fyrestone): Refactor code for updating actor / node / job. - # Subscribe actor channel. while True: try: - nodes = await self._get_nodes() - - alive_node_ids = [] - alive_node_infos = [] - for node in nodes.values(): - node_id = node["nodeId"] - if node["isHeadNode"] and not self._head_node_registration_time_s: - self._head_node_registration_time_s = ( - time.time() - self._module_start_time - ) - # Put head node ID in the internal KV to be read by JobAgent. - # TODO(architkulkarni): Remove once State API exposes which - # node is the head node. - await self._gcs_aio_client.internal_kv_put( - ray_constants.KV_HEAD_NODE_ID_KEY, - node_id.encode(), - overwrite=True, - namespace=ray_constants.KV_NAMESPACE_JOB, - timeout=GCS_RPC_TIMEOUT_SECONDS, - ) - assert node["state"] in ["ALIVE", "DEAD"] - if node["state"] == "ALIVE": - alive_node_ids.append(node_id) - alive_node_infos.append(node) - - agents = dict(DataSource.agents) - for node_id in alive_node_ids: - # Since the agent fate shares with a raylet, - # the agent port will never change once it is discovered. - if node_id not in agents: - key = ( - f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}" - f"{node_id}" - ) - agent_port = await self._gcs_aio_client.internal_kv_get( - key.encode(), - namespace=ray_constants.KV_NAMESPACE_DASHBOARD, - timeout=GCS_RPC_TIMEOUT_SECONDS, - ) - if agent_port: - agents[node_id] = json.loads(agent_port) - for node_id in agents.keys() - set(alive_node_ids): - agents.pop(node_id, None) - - DataSource.agents.reset(agents) - DataSource.nodes.reset(nodes) + published = await subscriber.poll(batch_size=200) + for node_id, node_info in published: + if node_id is not None: + yield gcs_node_info_to_dict(node_info) + # yield control to APIs + await asyncio.sleep(0) except Exception: - logger.exception("Error updating nodes.") - finally: - self._node_update_cnt += 1 - # _head_node_registration_time_s == None if head node is not - # registered. - head_node_not_registered = not self._head_node_registration_time_s - # Until the head node is registered, we update the - # node status more frequently. 
- # If the head node is not updated after 10 seconds, it just stops - # doing frequent update to avoid unexpected edge case. + logger.exception("Error updating nodes from subscriber.") + + async def _update_node(self, node: dict): + node_id = node["nodeId"] # hex + if node["isHeadNode"] and not self._head_node_registration_time_s: + self._head_node_registration_time_s = time.time() - self._module_start_time + # Put head node ID in the internal KV to be read by JobAgent. + # TODO(architkulkarni): Remove once State API exposes which + # node is the head node. + await self._gcs_aio_client.internal_kv_put( + ray_constants.KV_HEAD_NODE_ID_KEY, + node_id.encode(), + overwrite=True, + namespace=ray_constants.KV_NAMESPACE_JOB, + timeout=GCS_RPC_TIMEOUT_SECONDS, + ) + assert node["state"] in ["ALIVE", "DEAD"] + is_alive = node["state"] == "ALIVE" + # prepare agents for alive node, and pop for dead node + if is_alive: + if node_id not in DataSource.agents: + key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}" f"{node_id}" + agent_port = await self._gcs_aio_client.internal_kv_get( + key.encode(), + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + timeout=GCS_RPC_TIMEOUT_SECONDS, + ) + if agent_port: + DataSource.agents[node_id] = json.loads(agent_port) + else: + # not alive + DataSource.agents.pop(node_id, None) + + async def _update_nodes(self): + """ + Subscribe to node updates and update the internal states. If the head node is + not registered after FREQUENT_UPDATE_TIMEOUT_SECONDS, it logs a warning once. + """ + warning_shown = False + async for node in self._subscribe_nodes(): + await self._update_node(node) + if not self._head_node_registration_time_s: + # head node is not registered yet if ( - head_node_not_registered - and self._node_update_cnt * FREQUENTY_UPDATE_NODES_INTERVAL_SECONDS - < FREQUENT_UPDATE_TIMEOUT_SECONDS + not warning_shown + and (time.time() - self._module_start_time) + > FREQUENT_UPDATE_TIMEOUT_SECONDS ): - await asyncio.sleep(FREQUENTY_UPDATE_NODES_INTERVAL_SECONDS) - else: - if head_node_not_registered: - logger.warning( - "Head node is not registered even after " - f"{FREQUENT_UPDATE_TIMEOUT_SECONDS} seconds. " - "The API server might not work correctly. Please " - "report a Github issue. Internal states :" - f"{self.get_internal_states()}" - ) - await asyncio.sleep(node_consts.UPDATE_NODES_INTERVAL_SECONDS) + logger.warning( + "Head node is not registered even after " + f"{FREQUENT_UPDATE_TIMEOUT_SECONDS} seconds. " + "The API server might not work correctly. Please " + "report a Github issue. 
Internal states :" + f"{self.get_internal_states()}" + ) + warning_shown = True @routes.get("/internal/node_module") async def get_node_module_internal_state(self, req) -> aiohttp.web.Response: From 1e84eb5b5ee7aa9fe7904c990985f7895bb52ae7 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Tue, 27 Aug 2024 15:14:12 -0700 Subject: [PATCH 02/17] add dead node cache removal Signed-off-by: Ruiyang Wang --- python/ray/dashboard/modules/node/node_head.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index d7e33c1b3676..928fdbed0dc8 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -3,6 +3,7 @@ import logging import os import time +from collections import deque from itertools import chain from typing import AsyncGenerator, Dict @@ -39,6 +40,9 @@ logger = logging.getLogger(__name__) routes = dashboard_optional_utils.DashboardHeadRouteTable +# This is consistent with gcs_node_manager.cc +MAX_NODES_TO_CACHE = int(os.environ.get("RAY_maximum_gcs_dead_node_cached_count", 1000)) + def gcs_node_info_to_dict(message): return dashboard_utils.message_to_dict( @@ -141,6 +145,8 @@ def __init__(self, dashboard_head): # The time it takes until the head node is registered. None means # head node hasn't been registered. self._head_node_registration_time_s = None + # Queue of dead nodes to be removed, up to MAX_NODES_TO_CACHE + self._dead_node_queue = deque() self._gcs_aio_client = dashboard_head.gcs_aio_client self._gcs_address = dashboard_head.gcs_address @@ -226,6 +232,10 @@ async def _update_node(self, node: dict): else: # not alive DataSource.agents.pop(node_id, None) + self._dead_node_queue.append(node_id) + if len(self._dead_node_queue) > MAX_NODES_TO_CACHE: + DataSource.nodes.pop(self._dead_node_queue.popleft(), None) + DataSource.nodes[node_id] = node async def _update_nodes(self): """ From e901af884859d75c42edc1a411a17f7fb8d99edd Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Tue, 27 Aug 2024 15:30:59 -0700 Subject: [PATCH 03/17] fix agent update Signed-off-by: Ruiyang Wang --- python/ray/_private/gcs_pubsub.py | 4 +-- .../ray/dashboard/modules/node/node_head.py | 32 +++++++++++++------ 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/python/ray/_private/gcs_pubsub.py b/python/ray/_private/gcs_pubsub.py index 4774dab69ce4..edafc7905ce7 100644 --- a/python/ray/_private/gcs_pubsub.py +++ b/python/ray/_private/gcs_pubsub.py @@ -284,14 +284,14 @@ def __init__( ): super().__init__(pubsub_pb2.GCS_NODE_INFO_CHANNEL, worker_id, address, channel) - async def poll(self, timeout=None) -> Tuple[bytes, str]: + async def poll(self, timeout=None, batch_size=100) -> Tuple[bytes, str]: """Polls for new resource usage message. Returns: A tuple of string reporter ID and resource usage json string. 
""" await self._poll(timeout=timeout) - return self._pop_node_infos(self._queue) + return self._pop_node_infos(self._queue, batch_size=batch_size) @staticmethod def _pop_node_infos(queue, batch_size=100): diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index 928fdbed0dc8..9bc23f30d3bd 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -175,7 +175,7 @@ def get_internal_states(self): "module_lifetime_s": time.time() - self._module_start_time, } - async def _subscribe_nodes(self) -> AsyncGenerator[dict]: + async def _subscribe_nodes(self) -> AsyncGenerator[dict, None]: """ Yields the initial state of all nodes, then yields the updated state of nodes. @@ -221,14 +221,7 @@ async def _update_node(self, node: dict): # prepare agents for alive node, and pop for dead node if is_alive: if node_id not in DataSource.agents: - key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}" f"{node_id}" - agent_port = await self._gcs_aio_client.internal_kv_get( - key.encode(), - namespace=ray_constants.KV_NAMESPACE_DASHBOARD, - timeout=GCS_RPC_TIMEOUT_SECONDS, - ) - if agent_port: - DataSource.agents[node_id] = json.loads(agent_port) + asyncio.create_task(self._update_agent(node_id)) else: # not alive DataSource.agents.pop(node_id, None) @@ -237,6 +230,27 @@ async def _update_node(self, node: dict): DataSource.nodes.pop(self._dead_node_queue.popleft(), None) DataSource.nodes[node_id] = node + async def _update_agent(self, node_id): + """ + Given a node, update the agent_port in DataSource.agents. Problem is it's not + present until agent.py starts, so we need to loop waiting for agent.py writes + its port to internal kv. + """ + key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}".encode() + while True: + agent_port = await self._gcs_aio_client.internal_kv_get( + key, + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + timeout=GCS_RPC_TIMEOUT_SECONDS, + ) + if agent_port: + # Here we get the agent_port. But the node may be dead already. Only + # update DataSource.agents if the node is alive. + if DataSource.nodes.get(node_id, {}).get("state") == "ALIVE": + DataSource.agents[node_id] = json.loads(agent_port) + break + await asyncio.sleep(1) + async def _update_nodes(self): """ Subscribe to node updates and update the internal states. If the head node is From d4278936fc7d1ea5a3c0b5fe0030c1f12f1d383c Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Tue, 27 Aug 2024 15:34:03 -0700 Subject: [PATCH 04/17] fix agent Signed-off-by: Ruiyang Wang --- python/ray/dashboard/modules/node/node_head.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index 9bc23f30d3bd..f96da993464f 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -243,12 +243,13 @@ async def _update_agent(self, node_id): namespace=ray_constants.KV_NAMESPACE_DASHBOARD, timeout=GCS_RPC_TIMEOUT_SECONDS, ) + # The node may be dead already. Only update DataSource.agents if the node is + # still alive. + if DataSource.nodes.get(node_id, {}).get("state") != "ALIVE": + return if agent_port: - # Here we get the agent_port. But the node may be dead already. Only - # update DataSource.agents if the node is alive. 
- if DataSource.nodes.get(node_id, {}).get("state") == "ALIVE": - DataSource.agents[node_id] = json.loads(agent_port) - break + DataSource.agents[node_id] = json.loads(agent_port) + return await asyncio.sleep(1) async def _update_nodes(self): From f664d929dd1002c3cac61ebacd1016bb6a9a3836 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Tue, 27 Aug 2024 16:29:15 -0700 Subject: [PATCH 05/17] message_to_dict Signed-off-by: Ruiyang Wang --- python/ray/dashboard/datacenter.py | 1 + .../ray/dashboard/modules/actor/actor_head.py | 11 + .../ray/dashboard/modules/node/node_head.py | 38 +- .../dashboard/modules/node/tests/test_node.py | 1 + .../modules/reporter/reporter_head.py | 2 +- python/ray/dashboard/state_aggregator.py | 556 ++++++++++-------- python/ray/includes/gcs_client.pxi | 8 +- 7 files changed, 339 insertions(+), 278 deletions(-) diff --git a/python/ray/dashboard/datacenter.py b/python/ray/dashboard/datacenter.py index 2af6f853de43..c64263372242 100644 --- a/python/ray/dashboard/datacenter.py +++ b/python/ray/dashboard/datacenter.py @@ -218,6 +218,7 @@ async def _get_all_actors(actors): result = {} for index, (actor_id, actor) in enumerate(actors.items()): result[actor_id] = await DataOrganizer._get_actor(actor) + logger.error(f"{actors}, {result}") # There can be thousands of actors including dead ones. Processing # them all can take many seconds, which blocks all other requests # to the dashboard. The ideal solution might be to implement diff --git a/python/ray/dashboard/modules/actor/actor_head.py b/python/ray/dashboard/modules/actor/actor_head.py index fb0aa5ffd89d..e8083803c74d 100644 --- a/python/ray/dashboard/modules/actor/actor_head.py +++ b/python/ray/dashboard/modules/actor/actor_head.py @@ -11,6 +11,7 @@ import ray.dashboard.utils as dashboard_utils from ray import ActorID from ray._private.gcs_pubsub import GcsAioActorSubscriber +from ray._private.utils import get_or_create_event_loop from ray.core.generated import gcs_pb2, gcs_service_pb2, gcs_service_pb2_grpc from ray.dashboard.datacenter import DataOrganizer, DataSource from ray.dashboard.modules.actor import actor_consts @@ -142,6 +143,16 @@ async def _update_actors(self): logger.info("Getting all actor info from GCS.") actors = await self.get_all_actor_info(timeout=5) + + def convert(actors) -> Dict[str, dict]: + return { + actor_id.hex(): actor_table_data_to_dict(actor_table_data) + for actor_id, actor_table_data in actors.items() + } + + actor_dicts = await get_or_create_event_loop().run_in_executor( + self._dashboard_head._thread_pool_executor, convert, actors + ) actor_dicts: Dict[str, dict] = { actor_id.hex(): actor_table_data_to_dict(actor_table_data) for actor_id, actor_table_data in actors.items() diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index f96da993464f..f146995ab5d5 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -5,7 +5,7 @@ import time from collections import deque from itertools import chain -from typing import AsyncGenerator, Dict +from typing import AsyncGenerator, Dict, List import aiohttp.web import grpc @@ -50,21 +50,6 @@ def gcs_node_info_to_dict(message): ) -def gcs_stats_to_dict(message): - decode_keys = { - "actorId", - "jobId", - "taskId", - "parentTaskId", - "sourceActorId", - "callerId", - "rayletId", - "workerId", - "placementGroupId", - } - return dashboard_utils.message_to_dict(message, decode_keys) - - def node_stats_to_dict(message): decode_keys = { 
"actorId", @@ -153,7 +138,7 @@ def __init__(self, dashboard_head): async def _update_stubs(self, change): if change.old: node_id, node_info = change.old - self._stubs.pop(node_id) + self._stubs.pop(node_id, None) if change.new: # TODO(fyrestone): Handle exceptions. node_id, node_info = change.new @@ -188,15 +173,28 @@ async def _subscribe_nodes(self) -> AsyncGenerator[dict, None]: # Get all node info from GCS. For TOCTOU, it happens after the subscription. all_node_info = await self.get_all_node_info(timeout=GCS_RPC_TIMEOUT_SECONDS) - for node_info in all_node_info.values(): - yield gcs_node_info_to_dict(node_info) + + def convert(all_node_info) -> List[dict]: + return [ + gcs_node_info_to_dict(node_info) for node_info in all_node_info.items() + ] + + all_node_dict = await get_or_create_event_loop().run_in_executor( + self._dashboard_head._thread_pool_executor, convert, all_node_info + ) + for node in all_node_dict: + yield node while True: try: published = await subscriber.poll(batch_size=200) for node_id, node_info in published: if node_id is not None: - yield gcs_node_info_to_dict(node_info) + yield await get_or_create_event_loop().run_in_executor( + self._dashboard_head._thread_pool_executor, + gcs_node_info_to_dict, + node_info, + ) # yield control to APIs await asyncio.sleep(0) except Exception: diff --git a/python/ray/dashboard/modules/node/tests/test_node.py b/python/ray/dashboard/modules/node/tests/test_node.py index 18634d45ac15..b1c90645a295 100644 --- a/python/ray/dashboard/modules/node/tests/test_node.py +++ b/python/ray/dashboard/modules/node/tests/test_node.py @@ -104,6 +104,7 @@ def getpid(self): assert detail["raylet"]["isHeadNode"] is True assert "raylet" in detail["cmdline"][0] assert len(detail["workers"]) >= 2 + print(detail) assert len(detail["actors"]) == 2, detail["actors"] actor_worker_pids = set() diff --git a/python/ray/dashboard/modules/reporter/reporter_head.py b/python/ray/dashboard/modules/reporter/reporter_head.py index 3bb25e44be05..9ca0894d3c79 100644 --- a/python/ray/dashboard/modules/reporter/reporter_head.py +++ b/python/ray/dashboard/modules/reporter/reporter_head.py @@ -67,7 +67,7 @@ async def _update_stubs(self, change): if change.old: node_id, port = change.old ip = DataSource.nodes[node_id]["nodeManagerAddress"] - self._stubs.pop(ip) + self._stubs.pop(ip, None) if change.new: node_id, ports = change.new ip = DataSource.nodes[node_id]["nodeManagerAddress"] diff --git a/python/ray/dashboard/state_aggregator.py b/python/ray/dashboard/state_aggregator.py index 3027cf7d7f9d..c4ba56c954c2 100644 --- a/python/ray/dashboard/state_aggregator.py +++ b/python/ray/dashboard/state_aggregator.py @@ -241,32 +241,37 @@ async def list_actors(self, *, option: ListApiOptions) -> ListApiResponse: except DataSourceUnavailable: raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) - result = [] - for message in reply.actor_table_data: - data = protobuf_message_to_dict( - message=message, - fields_to_decode=[ - "actor_id", - "owner_id", - "job_id", - "node_id", - "placement_group_id", - ], + def transform(reply) -> ListApiResponse: + result = [] + for message in reply.actor_table_data: + data = protobuf_message_to_dict( + message=message, + fields_to_decode=[ + "actor_id", + "owner_id", + "job_id", + "node_id", + "placement_group_id", + ], + ) + result.append(data) + + num_after_truncation = len(result) + reply.num_filtered + result = self._filter(result, option.filters, ActorState, option.detail) + num_filtered = len(result) + + # Sort to make the output deterministic. 
+ result.sort(key=lambda entry: entry["actor_id"]) + result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + total=reply.total, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, ) - result.append(data) - - num_after_truncation = len(result) + reply.num_filtered - result = self._filter(result, option.filters, ActorState, option.detail) - num_filtered = len(result) - - # Sort to make the output deterministic. - result.sort(key=lambda entry: entry["actor_id"]) - result = list(islice(result, option.limit)) - return ListApiResponse( - result=result, - total=reply.total, - num_after_truncation=num_after_truncation, - num_filtered=num_filtered, + + return await get_or_create_event_loop().run_in_executor( + self._thread_pool_executor, transform, reply ) async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiResponse: @@ -283,26 +288,35 @@ async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiRespo except DataSourceUnavailable: raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) - result = [] - for message in reply.placement_group_table_data: - data = protobuf_message_to_dict( - message=message, - fields_to_decode=["placement_group_id", "creator_job_id", "node_id"], + def transform(reply) -> ListApiResponse: + result = [] + for message in reply.placement_group_table_data: + data = protobuf_message_to_dict( + message=message, + fields_to_decode=[ + "placement_group_id", + "creator_job_id", + "node_id", + ], + ) + result.append(data) + num_after_truncation = len(result) + + result = self._filter( + result, option.filters, PlacementGroupState, option.detail + ) + num_filtered = len(result) + # Sort to make the output deterministic. + result.sort(key=lambda entry: entry["placement_group_id"]) + return ListApiResponse( + result=list(islice(result, option.limit)), + total=reply.total, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, ) - result.append(data) - num_after_truncation = len(result) - result = self._filter( - result, option.filters, PlacementGroupState, option.detail - ) - num_filtered = len(result) - # Sort to make the output deterministic. 
- result.sort(key=lambda entry: entry["placement_group_id"]) - return ListApiResponse( - result=list(islice(result, option.limit)), - total=reply.total, - num_after_truncation=num_after_truncation, - num_filtered=num_filtered, + return await get_or_create_event_loop().run_in_executor( + self._thread_pool_executor, transform, reply ) async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse: @@ -319,33 +333,39 @@ async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse: except DataSourceUnavailable: raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) - result = [] - for message in reply.node_info_list: - data = protobuf_message_to_dict( - message=message, fields_to_decode=["node_id"] - ) - data["node_ip"] = data["node_manager_address"] - data["start_time_ms"] = int(data["start_time_ms"]) - data["end_time_ms"] = int(data["end_time_ms"]) - death_info = data.get("death_info", {}) - data["state_message"] = compose_state_message( - death_info.get("reason", None), death_info.get("reason_message", None) - ) + def transform(reply) -> ListApiResponse: + result = [] + for message in reply.node_info_list: + data = protobuf_message_to_dict( + message=message, fields_to_decode=["node_id"] + ) + data["node_ip"] = data["node_manager_address"] + data["start_time_ms"] = int(data["start_time_ms"]) + data["end_time_ms"] = int(data["end_time_ms"]) + death_info = data.get("death_info", {}) + data["state_message"] = compose_state_message( + death_info.get("reason", None), + death_info.get("reason_message", None), + ) - result.append(data) + result.append(data) - num_after_truncation = len(result) + reply.num_filtered - result = self._filter(result, option.filters, NodeState, option.detail) - num_filtered = len(result) + num_after_truncation = len(result) + reply.num_filtered + result = self._filter(result, option.filters, NodeState, option.detail) + num_filtered = len(result) - # Sort to make the output deterministic. - result.sort(key=lambda entry: entry["node_id"]) - result = list(islice(result, option.limit)) - return ListApiResponse( - result=result, - total=reply.total, - num_after_truncation=num_after_truncation, - num_filtered=num_filtered, + # Sort to make the output deterministic. 
+ result.sort(key=lambda entry: entry["node_id"]) + result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + total=reply.total, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, + ) + + return await get_or_create_event_loop().run_in_executor( + self._thread_pool_executor, transform, reply ) async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse: @@ -363,49 +383,61 @@ async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse: except DataSourceUnavailable: raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) - result = [] - for message in reply.worker_table_data: - data = protobuf_message_to_dict( - message=message, fields_to_decode=["worker_id", "raylet_id"] + def transform(reply) -> ListApiResponse: + + result = [] + for message in reply.worker_table_data: + data = protobuf_message_to_dict( + message=message, fields_to_decode=["worker_id", "raylet_id"] + ) + data["worker_id"] = data["worker_address"]["worker_id"] + data["node_id"] = data["worker_address"]["raylet_id"] + data["ip"] = data["worker_address"]["ip_address"] + data["start_time_ms"] = int(data["start_time_ms"]) + data["end_time_ms"] = int(data["end_time_ms"]) + data["worker_launch_time_ms"] = int(data["worker_launch_time_ms"]) + data["worker_launched_time_ms"] = int(data["worker_launched_time_ms"]) + result.append(data) + + num_after_truncation = len(result) + reply.num_filtered + result = self._filter(result, option.filters, WorkerState, option.detail) + num_filtered = len(result) + # Sort to make the output deterministic. + result.sort(key=lambda entry: entry["worker_id"]) + result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + total=reply.total, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, ) - data["worker_id"] = data["worker_address"]["worker_id"] - data["node_id"] = data["worker_address"]["raylet_id"] - data["ip"] = data["worker_address"]["ip_address"] - data["start_time_ms"] = int(data["start_time_ms"]) - data["end_time_ms"] = int(data["end_time_ms"]) - data["worker_launch_time_ms"] = int(data["worker_launch_time_ms"]) - data["worker_launched_time_ms"] = int(data["worker_launched_time_ms"]) - result.append(data) - - num_after_truncation = len(result) + reply.num_filtered - result = self._filter(result, option.filters, WorkerState, option.detail) - num_filtered = len(result) - # Sort to make the output deterministic. 
- result.sort(key=lambda entry: entry["worker_id"]) - result = list(islice(result, option.limit)) - return ListApiResponse( - result=result, - total=reply.total, - num_after_truncation=num_after_truncation, - num_filtered=num_filtered, + + return await get_or_create_event_loop().run_in_executor( + self._thread_pool_executor, transform, reply ) async def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse: try: - result = await self._client.get_job_info(timeout=option.timeout) - result = [job.dict() for job in result] + reply = await self._client.get_job_info(timeout=option.timeout) + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + + def transform(reply) -> ListApiResponse: + result = [job.dict() for job in reply] total = len(result) result = self._filter(result, option.filters, JobState, option.detail) num_filtered = len(result) result.sort(key=lambda entry: entry["job_id"] or "") result = list(islice(result, option.limit)) - except DataSourceUnavailable: - raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) - return ListApiResponse( - result=result, - total=total, - num_after_truncation=total, - num_filtered=num_filtered, + return ListApiResponse( + result=result, + total=total, + num_after_truncation=total, + num_filtered=num_filtered, + ) + + return await get_or_create_event_loop().run_in_executor( + self._thread_pool_executor, transform, reply ) async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse: @@ -424,12 +456,10 @@ async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse: except DataSourceUnavailable: raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) - def transform(reply): + def transform(reply) -> ListApiResponse: """ Transforms from proto to dict, applies filters, sorts, and truncates. This function is executed in a separate thread. - - Returns the ListApiResponse. """ result = [ protobuf_to_task_state_dict(message) for message in reply.events_by_task @@ -474,85 +504,90 @@ async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse: return_exceptions=True, ) - unresponsive_nodes = 0 - worker_stats = [] - total_objects = 0 - for reply, _ in zip(replies, raylet_ids): - if isinstance(reply, DataSourceUnavailable): - unresponsive_nodes += 1 - continue - elif isinstance(reply, Exception): - raise reply - - total_objects += reply.total - for core_worker_stat in reply.core_workers_stats: - # NOTE: Set preserving_proto_field_name=False here because - # `construct_memory_table` requires a dictionary that has - # modified protobuf name - # (e.g., workerId instead of worker_id) as a key. - worker_stats.append( - protobuf_message_to_dict( - message=core_worker_stat, - fields_to_decode=["object_id"], - preserving_proto_field_name=False, + def transform(replies) -> ListApiResponse: + unresponsive_nodes = 0 + worker_stats = [] + total_objects = 0 + for reply, _ in zip(replies, raylet_ids): + if isinstance(reply, DataSourceUnavailable): + unresponsive_nodes += 1 + continue + elif isinstance(reply, Exception): + raise reply + + total_objects += reply.total + for core_worker_stat in reply.core_workers_stats: + # NOTE: Set preserving_proto_field_name=False here because + # `construct_memory_table` requires a dictionary that has + # modified protobuf name + # (e.g., workerId instead of worker_id) as a key. 
+ worker_stats.append( + protobuf_message_to_dict( + message=core_worker_stat, + fields_to_decode=["object_id"], + preserving_proto_field_name=False, + ) ) + + partial_failure_warning = None + if len(raylet_ids) > 0 and unresponsive_nodes > 0: + warning_msg = NODE_QUERY_FAILURE_WARNING.format( + type="raylet", + total=len(raylet_ids), + network_failures=unresponsive_nodes, + log_command="raylet.out", + ) + if unresponsive_nodes == len(raylet_ids): + raise DataSourceUnavailable(warning_msg) + partial_failure_warning = ( + f"The returned data may contain incomplete result. {warning_msg}" ) - partial_failure_warning = None - if len(raylet_ids) > 0 and unresponsive_nodes > 0: - warning_msg = NODE_QUERY_FAILURE_WARNING.format( - type="raylet", - total=len(raylet_ids), - network_failures=unresponsive_nodes, - log_command="raylet.out", - ) - if unresponsive_nodes == len(raylet_ids): - raise DataSourceUnavailable(warning_msg) - partial_failure_warning = ( - f"The returned data may contain incomplete result. {warning_msg}" - ) + result = [] + memory_table = memory_utils.construct_memory_table(worker_stats) + for entry in memory_table.table: + data = entry.as_dict() + # `construct_memory_table` returns object_ref field which is indeed + # object_id. We do transformation here. + # TODO(sang): Refactor `construct_memory_table`. + data["object_id"] = data["object_ref"] + del data["object_ref"] + data["ip"] = data["node_ip_address"] + del data["node_ip_address"] + data["type"] = data["type"].upper() + data["task_status"] = ( + "NIL" if data["task_status"] == "-" else data["task_status"] + ) + result.append(data) - result = [] - memory_table = memory_utils.construct_memory_table(worker_stats) - for entry in memory_table.table: - data = entry.as_dict() - # `construct_memory_table` returns object_ref field which is indeed - # object_id. We do transformation here. - # TODO(sang): Refactor `construct_memory_table`. - data["object_id"] = data["object_ref"] - del data["object_ref"] - data["ip"] = data["node_ip_address"] - del data["node_ip_address"] - data["type"] = data["type"].upper() - data["task_status"] = ( - "NIL" if data["task_status"] == "-" else data["task_status"] - ) - result.append(data) - - # Add callsite warnings if it is not configured. - callsite_warning = [] - callsite_enabled = env_integer("RAY_record_ref_creation_sites", 0) - if not callsite_enabled: - callsite_warning.append( - "Callsite is not being recorded. " - "To record callsite information for each ObjectRef created, set " - "env variable RAY_record_ref_creation_sites=1 during `ray start` " - "and `ray.init`." + # Add callsite warnings if it is not configured. + callsite_warning = [] + callsite_enabled = env_integer("RAY_record_ref_creation_sites", 0) + if not callsite_enabled: + callsite_warning.append( + "Callsite is not being recorded. " + "To record callsite information for each ObjectRef created, set " + "env variable RAY_record_ref_creation_sites=1 during `ray start` " + "and `ray.init`." + ) + + num_after_truncation = len(result) + result = self._filter(result, option.filters, ObjectState, option.detail) + num_filtered = len(result) + # Sort to make the output deterministic. 
+ result.sort(key=lambda entry: entry["object_id"]) + result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + partial_failure_warning=partial_failure_warning, + total=total_objects, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, + warnings=callsite_warning, ) - num_after_truncation = len(result) - result = self._filter(result, option.filters, ObjectState, option.detail) - num_filtered = len(result) - # Sort to make the output deterministic. - result.sort(key=lambda entry: entry["object_id"]) - result = list(islice(result, option.limit)) - return ListApiResponse( - result=result, - partial_failure_warning=partial_failure_warning, - total=total_objects, - num_after_truncation=num_after_truncation, - num_filtered=num_filtered, - warnings=callsite_warning, + return await get_or_create_event_loop().run_in_executor( + self._thread_pool_executor, transform, replies ) async def list_runtime_envs(self, *, option: ListApiOptions) -> ListApiResponse: @@ -574,66 +609,73 @@ async def list_runtime_envs(self, *, option: ListApiOptions) -> ListApiResponse: return_exceptions=True, ) - result = [] - unresponsive_nodes = 0 - total_runtime_envs = 0 - for node_id, reply in zip( - self._client.get_all_registered_runtime_env_agent_ids(), replies - ): - if isinstance(reply, DataSourceUnavailable): - unresponsive_nodes += 1 - continue - elif isinstance(reply, Exception): - raise reply - - total_runtime_envs += reply.total - states = reply.runtime_env_states - for state in states: - data = protobuf_message_to_dict(message=state, fields_to_decode=[]) - # Need to deserialize this field. - data["runtime_env"] = RuntimeEnv.deserialize( - data["runtime_env"] - ).to_dict() - data["node_id"] = node_id - result.append(data) - - partial_failure_warning = None - if len(agent_ids) > 0 and unresponsive_nodes > 0: - warning_msg = NODE_QUERY_FAILURE_WARNING.format( - type="agent", - total=len(agent_ids), - network_failures=unresponsive_nodes, - log_command="dashboard_agent.log", + def transform(replies) -> ListApiResponse: + result = [] + unresponsive_nodes = 0 + total_runtime_envs = 0 + for node_id, reply in zip( + self._client.get_all_registered_runtime_env_agent_ids(), replies + ): + if isinstance(reply, DataSourceUnavailable): + unresponsive_nodes += 1 + continue + elif isinstance(reply, Exception): + raise reply + + total_runtime_envs += reply.total + states = reply.runtime_env_states + for state in states: + data = protobuf_message_to_dict(message=state, fields_to_decode=[]) + # Need to deserialize this field. + data["runtime_env"] = RuntimeEnv.deserialize( + data["runtime_env"] + ).to_dict() + data["node_id"] = node_id + result.append(data) + + partial_failure_warning = None + if len(agent_ids) > 0 and unresponsive_nodes > 0: + warning_msg = NODE_QUERY_FAILURE_WARNING.format( + type="agent", + total=len(agent_ids), + network_failures=unresponsive_nodes, + log_command="dashboard_agent.log", + ) + if unresponsive_nodes == len(agent_ids): + raise DataSourceUnavailable(warning_msg) + partial_failure_warning = ( + f"The returned data may contain incomplete result. {warning_msg}" + ) + num_after_truncation = len(result) + result = self._filter( + result, option.filters, RuntimeEnvState, option.detail ) - if unresponsive_nodes == len(agent_ids): - raise DataSourceUnavailable(warning_msg) - partial_failure_warning = ( - f"The returned data may contain incomplete result. {warning_msg}" + num_filtered = len(result) + + # Sort to make the output deterministic. 
+ def sort_func(entry): + # If creation time is not there yet (runtime env is failed + # to be created or not created yet, they are the highest priority. + # Otherwise, "bigger" creation time is coming first. + if "creation_time_ms" not in entry: + return float("inf") + elif entry["creation_time_ms"] is None: + return float("inf") + else: + return float(entry["creation_time_ms"]) + + result.sort(key=sort_func, reverse=True) + result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + partial_failure_warning=partial_failure_warning, + total=total_runtime_envs, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, ) - num_after_truncation = len(result) - result = self._filter(result, option.filters, RuntimeEnvState, option.detail) - num_filtered = len(result) - - # Sort to make the output deterministic. - def sort_func(entry): - # If creation time is not there yet (runtime env is failed - # to be created or not created yet, they are the highest priority. - # Otherwise, "bigger" creation time is coming first. - if "creation_time_ms" not in entry: - return float("inf") - elif entry["creation_time_ms"] is None: - return float("inf") - else: - return float(entry["creation_time_ms"]) - - result.sort(key=sort_func, reverse=True) - result = list(islice(result, option.limit)) - return ListApiResponse( - result=result, - partial_failure_warning=partial_failure_warning, - total=total_runtime_envs, - num_after_truncation=num_after_truncation, - num_filtered=num_filtered, + + return await get_or_create_event_loop().run_in_executor( + self._thread_pool_executor, transform, replies ) async def list_cluster_events(self, *, option: ListApiOptions) -> ListApiResponse: @@ -644,25 +686,33 @@ async def list_cluster_events(self, *, option: ListApiOptions) -> ListApiRespons The schema of returned "dict" is equivalent to the `ClusterEventState` protobuf message. """ - result = [] - all_events = await self._client.get_all_cluster_events() - for _, events in all_events.items(): - for _, event in events.items(): - event["time"] = str(datetime.fromtimestamp(int(event["timestamp"]))) - result.append(event) - - num_after_truncation = len(result) - result.sort(key=lambda entry: entry["timestamp"]) - total = len(result) - result = self._filter(result, option.filters, ClusterEventState, option.detail) - num_filtered = len(result) - # Sort to make the output deterministic. - result = list(islice(result, option.limit)) - return ListApiResponse( - result=result, - total=total, - num_after_truncation=num_after_truncation, - num_filtered=num_filtered, + reply = await self._client.get_all_cluster_events() + + def transform(reply) -> ListApiResponse: + result = [] + for _, events in reply.items(): + for _, event in events.items(): + event["time"] = str(datetime.fromtimestamp(int(event["timestamp"]))) + result.append(event) + + num_after_truncation = len(result) + result.sort(key=lambda entry: entry["timestamp"]) + total = len(result) + result = self._filter( + result, option.filters, ClusterEventState, option.detail + ) + num_filtered = len(result) + # Sort to make the output deterministic. 
+ result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + total=total, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, + ) + + return await get_or_create_event_loop().run_in_executor( + self._thread_pool_executor, transform, reply ) async def summarize_tasks(self, option: SummaryApiOptions) -> SummaryApiResponse: diff --git a/python/ray/includes/gcs_client.pxi b/python/ray/includes/gcs_client.pxi index d3850613ccdf..4d91a6cd0e4a 100644 --- a/python/ray/includes/gcs_client.pxi +++ b/python/ray/includes/gcs_client.pxi @@ -78,7 +78,7 @@ cdef class NewGcsClient: @property def cluster_id(self) -> ray.ClusterID: cdef CClusterID cluster_id = self.inner.get().GetClusterId() - return ray.ClusterID.from_binary(cluster_id.Binary()) + return ray.ClusterID(cluster_id.Binary()) ############################################################# # Internal KV sync methods @@ -612,7 +612,7 @@ cdef convert_get_all_node_info( for b in serialized_reply: proto = gcs_pb2.GcsNodeInfo() proto.ParseFromString(b) - node_table_data[NodeID.from_binary(proto.node_id)] = proto + node_table_data[NodeID(proto.node_id)] = proto return node_table_data, None except Exception as e: return None, e @@ -634,7 +634,7 @@ cdef convert_get_all_job_info( for b in serialized_reply: proto = gcs_pb2.JobTableData() proto.ParseFromString(b) - job_table_data[JobID.from_binary(proto.job_id)] = proto + job_table_data[JobID(proto.job_id)] = proto return job_table_data, None except Exception as e: return None, e @@ -653,7 +653,7 @@ cdef convert_get_all_actor_info( for b in serialized_reply: proto = gcs_pb2.ActorTableData() proto.ParseFromString(b) - actor_table_data[ActorID.from_binary(proto.actor_id)] = proto + actor_table_data[ActorID(proto.actor_id)] = proto return actor_table_data, None except Exception as e: return None, e From 654d183cdfedf3913a866e60e09dda6456de017a Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Tue, 27 Aug 2024 17:08:15 -0700 Subject: [PATCH 06/17] up Signed-off-by: Ruiyang Wang --- python/ray/dashboard/modules/actor/actor_head.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/ray/dashboard/modules/actor/actor_head.py b/python/ray/dashboard/modules/actor/actor_head.py index e8083803c74d..223a94cb49bf 100644 --- a/python/ray/dashboard/modules/actor/actor_head.py +++ b/python/ray/dashboard/modules/actor/actor_head.py @@ -153,10 +153,6 @@ def convert(actors) -> Dict[str, dict]: actor_dicts = await get_or_create_event_loop().run_in_executor( self._dashboard_head._thread_pool_executor, convert, actors ) - actor_dicts: Dict[str, dict] = { - actor_id.hex(): actor_table_data_to_dict(actor_table_data) - for actor_id, actor_table_data in actors.items() - } # Update actors. DataSource.actors.reset(actor_dicts) # Update node actors and job actors. 
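
For reference, a minimal standalone sketch of the run_in_executor offloading pattern that patches 05 and 06 introduce, assuming a CPU-bound convert helper and a dedicated thread pool; the names below are illustrative stand-ins, not the dashboard's actual attributes (the patches use ray's get_or_create_event_loop() helper and dashboard_head._thread_pool_executor, with real converters such as gcs_node_info_to_dict and actor_table_data_to_dict).

import asyncio
from concurrent.futures import ThreadPoolExecutor

# Illustrative stand-in for the dashboard's shared executor.
_executor = ThreadPoolExecutor(max_workers=2)

def convert(messages):
    # CPU-bound work (e.g. message_to_dict over many protos) that would
    # otherwise block the asyncio event loop.
    return [str(m) for m in messages]

async def convert_off_loop(messages):
    loop = asyncio.get_running_loop()
    # Offload the conversion to the thread pool so the event loop stays
    # responsive while the batch is processed.
    return await loop.run_in_executor(_executor, convert, messages)

The point of the pattern is that protobuf-to-dict conversion over large batches is pure CPU work, so pushing it onto a worker thread keeps the dashboard's event loop free to serve requests.
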
From e776dc112fec0c66ce7f6589c85e7b2730325331 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Tue, 27 Aug 2024 21:31:18 -0700 Subject: [PATCH 07/17] fix tpe Signed-off-by: Ruiyang Wang --- python/ray/dashboard/modules/node/node_head.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index f146995ab5d5..7e6b8d09bb9f 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -176,7 +176,7 @@ async def _subscribe_nodes(self) -> AsyncGenerator[dict, None]: def convert(all_node_info) -> List[dict]: return [ - gcs_node_info_to_dict(node_info) for node_info in all_node_info.items() + gcs_node_info_to_dict(node_info) for node_info in all_node_info.values() ] all_node_dict = await get_or_create_event_loop().run_in_executor( From f94a3dd8884bc3d0ffb210f57cad4a4d32e4b191 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Wed, 28 Aug 2024 13:42:13 -0700 Subject: [PATCH 08/17] fixes Signed-off-by: Ruiyang Wang --- python/ray/_private/gcs_pubsub.py | 15 ++-- python/ray/dashboard/datacenter.py | 1 - .../ray/dashboard/modules/node/node_consts.py | 16 ++-- .../ray/dashboard/modules/node/node_head.py | 76 ++++++++++--------- .../dashboard/modules/node/tests/test_node.py | 4 +- 5 files changed, 65 insertions(+), 47 deletions(-) diff --git a/python/ray/_private/gcs_pubsub.py b/python/ray/_private/gcs_pubsub.py index edafc7905ce7..757fb6e51fa0 100644 --- a/python/ray/_private/gcs_pubsub.py +++ b/python/ray/_private/gcs_pubsub.py @@ -15,6 +15,7 @@ import ray._private.gcs_utils as gcs_utils from ray.core.generated import gcs_service_pb2_grpc from ray.core.generated import gcs_service_pb2 +from ray.core.generated import gcs_pb2 from ray.core.generated import common_pb2 from ray.core.generated import pubsub_pb2 @@ -253,11 +254,13 @@ def __init__( def queue_size(self): return len(self._queue) - async def poll(self, timeout=None, batch_size=500) -> List[Tuple[bytes, str]]: + async def poll( + self, timeout=None, batch_size=500 + ) -> List[Tuple[bytes, gcs_pb2.ActorTableData]]: """Polls for new actor message. Returns: - A tuple of binary actor ID and actor table data. + A list of tuples of binary actor ID and actor table data. """ await self._poll(timeout=timeout) return self._pop_actors(self._queue, batch_size=batch_size) @@ -284,11 +287,13 @@ def __init__( ): super().__init__(pubsub_pb2.GCS_NODE_INFO_CHANNEL, worker_id, address, channel) - async def poll(self, timeout=None, batch_size=100) -> Tuple[bytes, str]: - """Polls for new resource usage message. + async def poll( + self, timeout=None, batch_size=100 + ) -> List[Tuple[bytes, gcs_pb2.GcsNodeInfo]]: + """Polls for new node info message. Returns: - A tuple of string reporter ID and resource usage json string. + A list of tuples of (node_id, GcsNodeInfo). """ await self._poll(timeout=timeout) return self._pop_node_infos(self._queue, batch_size=batch_size) diff --git a/python/ray/dashboard/datacenter.py b/python/ray/dashboard/datacenter.py index c64263372242..2af6f853de43 100644 --- a/python/ray/dashboard/datacenter.py +++ b/python/ray/dashboard/datacenter.py @@ -218,7 +218,6 @@ async def _get_all_actors(actors): result = {} for index, (actor_id, actor) in enumerate(actors.items()): result[actor_id] = await DataOrganizer._get_actor(actor) - logger.error(f"{actors}, {result}") # There can be thousands of actors including dead ones. 
Processing # them all can take many seconds, which blocks all other requests # to the dashboard. The ideal solution might be to implement diff --git a/python/ray/dashboard/modules/node/node_consts.py b/python/ray/dashboard/modules/node/node_consts.py index f38c3f929834..941aecf3f8ee 100644 --- a/python/ray/dashboard/modules/node/node_consts.py +++ b/python/ray/dashboard/modules/node/node_consts.py @@ -3,9 +3,15 @@ NODE_STATS_UPDATE_INTERVAL_SECONDS = env_integer( "NODE_STATS_UPDATE_INTERVAL_SECONDS", 5 ) -# Deprecated, not used. -UPDATE_NODES_INTERVAL_SECONDS = env_integer("UPDATE_NODES_INTERVAL_SECONDS", 5) -# Used to set a time range to frequently update the node stats. -# Now, it's timeout for a warning message if the head node is not registered. -FREQUENT_UPDATE_TIMEOUT_SECONDS = env_integer("FREQUENT_UPDATE_TIMEOUT_SECONDS", 10) +RAY_NODE_HEAD_HEAD_NODE_REGISTRATION_TIMEOUT = env_integer( + "RAY_NODE_HEAD_HEAD_NODE_REGISTRATION_TIMEOUT", 10 +) MAX_COUNT_OF_GCS_RPC_ERROR = 10 +# This is consistent with gcs_node_manager.cc +MAX_DEAD_NODES_TO_CACHE = env_integer("RAY_maximum_gcs_dead_node_cached_count", 1000) +RAY_NODE_HEAD_SUBSCRIBER_POLL_SIZE = env_integer( + "RAY_NODE_HEAD_SUBSCRIBER_POLL_SIZE", 200 +) +RAY_NODE_HEAD_AGENT_POLL_INTERVAL_S = env_integer( + "RAY_NODE_HEAD_AGENT_POLL_INTERVAL_S", 1 +) diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index 7e6b8d09bb9f..5d45b4714d35 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -5,7 +5,7 @@ import time from collections import deque from itertools import chain -from typing import AsyncGenerator, Dict, List +from typing import AsyncGenerator, Dict, List, Tuple import aiohttp.web import grpc @@ -34,22 +34,31 @@ from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS from ray.dashboard.datacenter import DataOrganizer, DataSource from ray.dashboard.modules.node import node_consts -from ray.dashboard.modules.node.node_consts import FREQUENT_UPDATE_TIMEOUT_SECONDS +from ray.dashboard.modules.node.node_consts import ( + RAY_NODE_HEAD_HEAD_NODE_REGISTRATION_TIMEOUT, +) from ray.dashboard.utils import async_loop_forever logger = logging.getLogger(__name__) routes = dashboard_optional_utils.DashboardHeadRouteTable -# This is consistent with gcs_node_manager.cc -MAX_NODES_TO_CACHE = int(os.environ.get("RAY_maximum_gcs_dead_node_cached_count", 1000)) - -def gcs_node_info_to_dict(message): +def gcs_node_info_to_dict(message: gcs_pb2.GcsNodeInfo) -> dict: return dashboard_utils.message_to_dict( message, {"nodeId"}, always_print_fields_with_no_presence=True ) +def batch_gcs_node_info_to_dict(messages: List[gcs_pb2.GcsNodeInfo]) -> List[dict]: + return [gcs_node_info_to_dict(message) for message in messages] + + +def batch_updated_pairs_to_dict( + messages: List[Tuple[bytes, gcs_pb2.GcsNodeInfo]] +) -> List[dict]: + return [gcs_node_info_to_dict(node_info) for node_id_bytes, node_info in messages] + + def node_stats_to_dict(message): decode_keys = { "actorId", @@ -130,7 +139,7 @@ def __init__(self, dashboard_head): # The time it takes until the head node is registered. None means # head node hasn't been registered. 
self._head_node_registration_time_s = None - # Queue of dead nodes to be removed, up to MAX_NODES_TO_CACHE + # Queue of dead nodes to be removed, up to MAX_DEAD_NODES_TO_CACHE self._dead_node_queue = deque() self._gcs_aio_client = dashboard_head.gcs_aio_client self._gcs_address = dashboard_head.gcs_address @@ -174,31 +183,26 @@ async def _subscribe_nodes(self) -> AsyncGenerator[dict, None]: # Get all node info from GCS. For TOCTOU, it happens after the subscription. all_node_info = await self.get_all_node_info(timeout=GCS_RPC_TIMEOUT_SECONDS) - def convert(all_node_info) -> List[dict]: - return [ - gcs_node_info_to_dict(node_info) for node_info in all_node_info.values() - ] - - all_node_dict = await get_or_create_event_loop().run_in_executor( - self._dashboard_head._thread_pool_executor, convert, all_node_info + all_node_dicts = await get_or_create_event_loop().run_in_executor( + self._dashboard_head._thread_pool_executor, + batch_gcs_node_info_to_dict, + all_node_info, ) - for node in all_node_dict: - yield node + yield from all_node_dicts while True: try: - published = await subscriber.poll(batch_size=200) - for node_id, node_info in published: - if node_id is not None: - yield await get_or_create_event_loop().run_in_executor( - self._dashboard_head._thread_pool_executor, - gcs_node_info_to_dict, - node_info, - ) - # yield control to APIs - await asyncio.sleep(0) + published = await subscriber.poll( + batch_size=node_consts.RAY_NODE_HEAD_SUBSCRIBER_POLL_SIZE + ) + updated_dicts = await get_or_create_event_loop().run_in_executor( + self._dashboard_head._thread_pool_executor, + batch_updated_pairs_to_dict, + (node_info for _, node_info in published), + ) + yield from updated_dicts except Exception: - logger.exception("Error updating nodes from subscriber.") + logger.exception("Failed handling updated nodes.") async def _update_node(self, node: dict): node_id = node["nodeId"] # hex @@ -216,15 +220,18 @@ async def _update_node(self, node: dict): ) assert node["state"] in ["ALIVE", "DEAD"] is_alive = node["state"] == "ALIVE" - # prepare agents for alive node, and pop for dead node + # Prepare agents for alive node, and pop agents for dead node. if is_alive: if node_id not in DataSource.agents: + # Agent port is read from internal KV. Problem is it's not present when + # we receive this update; it's only present after agent.py starts + # listening. So we make an async task that periodically polls internal + # KV. asyncio.create_task(self._update_agent(node_id)) else: - # not alive DataSource.agents.pop(node_id, None) self._dead_node_queue.append(node_id) - if len(self._dead_node_queue) > MAX_NODES_TO_CACHE: + if len(self._dead_node_queue) > node_consts.MAX_DEAD_NODES_TO_CACHE: DataSource.nodes.pop(self._dead_node_queue.popleft(), None) DataSource.nodes[node_id] = node @@ -248,12 +255,13 @@ async def _update_agent(self, node_id): if agent_port: DataSource.agents[node_id] = json.loads(agent_port) return - await asyncio.sleep(1) + await asyncio.sleep(node_consts.RAY_NODE_HEAD_AGENT_POLL_INTERVAL_S) async def _update_nodes(self): """ Subscribe to node updates and update the internal states. If the head node is - not registered after FREQUENT_UPDATE_TIMEOUT_SECONDS, it logs a warning once. + not registered after RAY_NODE_HEAD_HEAD_NODE_REGISTRATION_TIMEOUT, it logs a + warning only once. 
""" warning_shown = False async for node in self._subscribe_nodes(): @@ -263,11 +271,11 @@ async def _update_nodes(self): if ( not warning_shown and (time.time() - self._module_start_time) - > FREQUENT_UPDATE_TIMEOUT_SECONDS + > RAY_NODE_HEAD_HEAD_NODE_REGISTRATION_TIMEOUT ): logger.warning( "Head node is not registered even after " - f"{FREQUENT_UPDATE_TIMEOUT_SECONDS} seconds. " + f"{RAY_NODE_HEAD_HEAD_NODE_REGISTRATION_TIMEOUT} seconds. " "The API server might not work correctly. Please " "report a Github issue. Internal states :" f"{self.get_internal_states()}" diff --git a/python/ray/dashboard/modules/node/tests/test_node.py b/python/ray/dashboard/modules/node/tests/test_node.py index b1c90645a295..70087a3b6d87 100644 --- a/python/ray/dashboard/modules/node/tests/test_node.py +++ b/python/ray/dashboard/modules/node/tests/test_node.py @@ -17,11 +17,12 @@ wait_until_server_available, ) from ray.cluster_utils import Cluster -from ray.dashboard.modules.node.node_consts import UPDATE_NODES_INTERVAL_SECONDS from ray.dashboard.tests.conftest import * # noqa logger = logging.getLogger(__name__) +UPDATE_NODES_INTERVAL_SECONDS = 5 + def test_nodes_update(enable_test_module, ray_start_with_dashboard): assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True @@ -104,7 +105,6 @@ def getpid(self): assert detail["raylet"]["isHeadNode"] is True assert "raylet" in detail["cmdline"][0] assert len(detail["workers"]) >= 2 - print(detail) assert len(detail["actors"]) == 2, detail["actors"] actor_worker_pids = set() From e148a425b4c531c53107bc526540b3549e701406 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Wed, 28 Aug 2024 13:43:19 -0700 Subject: [PATCH 09/17] revert python/ray/dashboard/state_aggregator.py Signed-off-by: Ruiyang Wang --- python/ray/dashboard/state_aggregator.py | 556 +++++++++++------------ 1 file changed, 253 insertions(+), 303 deletions(-) diff --git a/python/ray/dashboard/state_aggregator.py b/python/ray/dashboard/state_aggregator.py index c4ba56c954c2..3027cf7d7f9d 100644 --- a/python/ray/dashboard/state_aggregator.py +++ b/python/ray/dashboard/state_aggregator.py @@ -241,37 +241,32 @@ async def list_actors(self, *, option: ListApiOptions) -> ListApiResponse: except DataSourceUnavailable: raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) - def transform(reply) -> ListApiResponse: - result = [] - for message in reply.actor_table_data: - data = protobuf_message_to_dict( - message=message, - fields_to_decode=[ - "actor_id", - "owner_id", - "job_id", - "node_id", - "placement_group_id", - ], - ) - result.append(data) - - num_after_truncation = len(result) + reply.num_filtered - result = self._filter(result, option.filters, ActorState, option.detail) - num_filtered = len(result) - - # Sort to make the output deterministic. 
- result.sort(key=lambda entry: entry["actor_id"]) - result = list(islice(result, option.limit)) - return ListApiResponse( - result=result, - total=reply.total, - num_after_truncation=num_after_truncation, - num_filtered=num_filtered, + result = [] + for message in reply.actor_table_data: + data = protobuf_message_to_dict( + message=message, + fields_to_decode=[ + "actor_id", + "owner_id", + "job_id", + "node_id", + "placement_group_id", + ], ) - - return await get_or_create_event_loop().run_in_executor( - self._thread_pool_executor, transform, reply + result.append(data) + + num_after_truncation = len(result) + reply.num_filtered + result = self._filter(result, option.filters, ActorState, option.detail) + num_filtered = len(result) + + # Sort to make the output deterministic. + result.sort(key=lambda entry: entry["actor_id"]) + result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + total=reply.total, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, ) async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiResponse: @@ -288,35 +283,26 @@ async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiRespo except DataSourceUnavailable: raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) - def transform(reply) -> ListApiResponse: - result = [] - for message in reply.placement_group_table_data: - data = protobuf_message_to_dict( - message=message, - fields_to_decode=[ - "placement_group_id", - "creator_job_id", - "node_id", - ], - ) - result.append(data) - num_after_truncation = len(result) - - result = self._filter( - result, option.filters, PlacementGroupState, option.detail - ) - num_filtered = len(result) - # Sort to make the output deterministic. - result.sort(key=lambda entry: entry["placement_group_id"]) - return ListApiResponse( - result=list(islice(result, option.limit)), - total=reply.total, - num_after_truncation=num_after_truncation, - num_filtered=num_filtered, + result = [] + for message in reply.placement_group_table_data: + data = protobuf_message_to_dict( + message=message, + fields_to_decode=["placement_group_id", "creator_job_id", "node_id"], ) + result.append(data) + num_after_truncation = len(result) - return await get_or_create_event_loop().run_in_executor( - self._thread_pool_executor, transform, reply + result = self._filter( + result, option.filters, PlacementGroupState, option.detail + ) + num_filtered = len(result) + # Sort to make the output deterministic. 
+ result.sort(key=lambda entry: entry["placement_group_id"]) + return ListApiResponse( + result=list(islice(result, option.limit)), + total=reply.total, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, ) async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse: @@ -333,39 +319,33 @@ async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse: except DataSourceUnavailable: raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) - def transform(reply) -> ListApiResponse: - result = [] - for message in reply.node_info_list: - data = protobuf_message_to_dict( - message=message, fields_to_decode=["node_id"] - ) - data["node_ip"] = data["node_manager_address"] - data["start_time_ms"] = int(data["start_time_ms"]) - data["end_time_ms"] = int(data["end_time_ms"]) - death_info = data.get("death_info", {}) - data["state_message"] = compose_state_message( - death_info.get("reason", None), - death_info.get("reason_message", None), - ) - - result.append(data) + result = [] + for message in reply.node_info_list: + data = protobuf_message_to_dict( + message=message, fields_to_decode=["node_id"] + ) + data["node_ip"] = data["node_manager_address"] + data["start_time_ms"] = int(data["start_time_ms"]) + data["end_time_ms"] = int(data["end_time_ms"]) + death_info = data.get("death_info", {}) + data["state_message"] = compose_state_message( + death_info.get("reason", None), death_info.get("reason_message", None) + ) - num_after_truncation = len(result) + reply.num_filtered - result = self._filter(result, option.filters, NodeState, option.detail) - num_filtered = len(result) + result.append(data) - # Sort to make the output deterministic. - result.sort(key=lambda entry: entry["node_id"]) - result = list(islice(result, option.limit)) - return ListApiResponse( - result=result, - total=reply.total, - num_after_truncation=num_after_truncation, - num_filtered=num_filtered, - ) + num_after_truncation = len(result) + reply.num_filtered + result = self._filter(result, option.filters, NodeState, option.detail) + num_filtered = len(result) - return await get_or_create_event_loop().run_in_executor( - self._thread_pool_executor, transform, reply + # Sort to make the output deterministic. 
+ result.sort(key=lambda entry: entry["node_id"]) + result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + total=reply.total, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, ) async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse: @@ -383,61 +363,49 @@ async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse: except DataSourceUnavailable: raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) - def transform(reply) -> ListApiResponse: - - result = [] - for message in reply.worker_table_data: - data = protobuf_message_to_dict( - message=message, fields_to_decode=["worker_id", "raylet_id"] - ) - data["worker_id"] = data["worker_address"]["worker_id"] - data["node_id"] = data["worker_address"]["raylet_id"] - data["ip"] = data["worker_address"]["ip_address"] - data["start_time_ms"] = int(data["start_time_ms"]) - data["end_time_ms"] = int(data["end_time_ms"]) - data["worker_launch_time_ms"] = int(data["worker_launch_time_ms"]) - data["worker_launched_time_ms"] = int(data["worker_launched_time_ms"]) - result.append(data) - - num_after_truncation = len(result) + reply.num_filtered - result = self._filter(result, option.filters, WorkerState, option.detail) - num_filtered = len(result) - # Sort to make the output deterministic. - result.sort(key=lambda entry: entry["worker_id"]) - result = list(islice(result, option.limit)) - return ListApiResponse( - result=result, - total=reply.total, - num_after_truncation=num_after_truncation, - num_filtered=num_filtered, + result = [] + for message in reply.worker_table_data: + data = protobuf_message_to_dict( + message=message, fields_to_decode=["worker_id", "raylet_id"] ) - - return await get_or_create_event_loop().run_in_executor( - self._thread_pool_executor, transform, reply + data["worker_id"] = data["worker_address"]["worker_id"] + data["node_id"] = data["worker_address"]["raylet_id"] + data["ip"] = data["worker_address"]["ip_address"] + data["start_time_ms"] = int(data["start_time_ms"]) + data["end_time_ms"] = int(data["end_time_ms"]) + data["worker_launch_time_ms"] = int(data["worker_launch_time_ms"]) + data["worker_launched_time_ms"] = int(data["worker_launched_time_ms"]) + result.append(data) + + num_after_truncation = len(result) + reply.num_filtered + result = self._filter(result, option.filters, WorkerState, option.detail) + num_filtered = len(result) + # Sort to make the output deterministic. 
+ result.sort(key=lambda entry: entry["worker_id"]) + result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + total=reply.total, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, ) async def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse: try: - reply = await self._client.get_job_info(timeout=option.timeout) - except DataSourceUnavailable: - raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) - - def transform(reply) -> ListApiResponse: - result = [job.dict() for job in reply] + result = await self._client.get_job_info(timeout=option.timeout) + result = [job.dict() for job in result] total = len(result) result = self._filter(result, option.filters, JobState, option.detail) num_filtered = len(result) result.sort(key=lambda entry: entry["job_id"] or "") result = list(islice(result, option.limit)) - return ListApiResponse( - result=result, - total=total, - num_after_truncation=total, - num_filtered=num_filtered, - ) - - return await get_or_create_event_loop().run_in_executor( - self._thread_pool_executor, transform, reply + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + return ListApiResponse( + result=result, + total=total, + num_after_truncation=total, + num_filtered=num_filtered, ) async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse: @@ -456,10 +424,12 @@ async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse: except DataSourceUnavailable: raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) - def transform(reply) -> ListApiResponse: + def transform(reply): """ Transforms from proto to dict, applies filters, sorts, and truncates. This function is executed in a separate thread. + + Returns the ListApiResponse. """ result = [ protobuf_to_task_state_dict(message) for message in reply.events_by_task @@ -504,90 +474,85 @@ async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse: return_exceptions=True, ) - def transform(replies) -> ListApiResponse: - unresponsive_nodes = 0 - worker_stats = [] - total_objects = 0 - for reply, _ in zip(replies, raylet_ids): - if isinstance(reply, DataSourceUnavailable): - unresponsive_nodes += 1 - continue - elif isinstance(reply, Exception): - raise reply - - total_objects += reply.total - for core_worker_stat in reply.core_workers_stats: - # NOTE: Set preserving_proto_field_name=False here because - # `construct_memory_table` requires a dictionary that has - # modified protobuf name - # (e.g., workerId instead of worker_id) as a key. - worker_stats.append( - protobuf_message_to_dict( - message=core_worker_stat, - fields_to_decode=["object_id"], - preserving_proto_field_name=False, - ) + unresponsive_nodes = 0 + worker_stats = [] + total_objects = 0 + for reply, _ in zip(replies, raylet_ids): + if isinstance(reply, DataSourceUnavailable): + unresponsive_nodes += 1 + continue + elif isinstance(reply, Exception): + raise reply + + total_objects += reply.total + for core_worker_stat in reply.core_workers_stats: + # NOTE: Set preserving_proto_field_name=False here because + # `construct_memory_table` requires a dictionary that has + # modified protobuf name + # (e.g., workerId instead of worker_id) as a key. 
+ worker_stats.append( + protobuf_message_to_dict( + message=core_worker_stat, + fields_to_decode=["object_id"], + preserving_proto_field_name=False, ) - - partial_failure_warning = None - if len(raylet_ids) > 0 and unresponsive_nodes > 0: - warning_msg = NODE_QUERY_FAILURE_WARNING.format( - type="raylet", - total=len(raylet_ids), - network_failures=unresponsive_nodes, - log_command="raylet.out", - ) - if unresponsive_nodes == len(raylet_ids): - raise DataSourceUnavailable(warning_msg) - partial_failure_warning = ( - f"The returned data may contain incomplete result. {warning_msg}" ) - result = [] - memory_table = memory_utils.construct_memory_table(worker_stats) - for entry in memory_table.table: - data = entry.as_dict() - # `construct_memory_table` returns object_ref field which is indeed - # object_id. We do transformation here. - # TODO(sang): Refactor `construct_memory_table`. - data["object_id"] = data["object_ref"] - del data["object_ref"] - data["ip"] = data["node_ip_address"] - del data["node_ip_address"] - data["type"] = data["type"].upper() - data["task_status"] = ( - "NIL" if data["task_status"] == "-" else data["task_status"] - ) - result.append(data) - - # Add callsite warnings if it is not configured. - callsite_warning = [] - callsite_enabled = env_integer("RAY_record_ref_creation_sites", 0) - if not callsite_enabled: - callsite_warning.append( - "Callsite is not being recorded. " - "To record callsite information for each ObjectRef created, set " - "env variable RAY_record_ref_creation_sites=1 during `ray start` " - "and `ray.init`." - ) + partial_failure_warning = None + if len(raylet_ids) > 0 and unresponsive_nodes > 0: + warning_msg = NODE_QUERY_FAILURE_WARNING.format( + type="raylet", + total=len(raylet_ids), + network_failures=unresponsive_nodes, + log_command="raylet.out", + ) + if unresponsive_nodes == len(raylet_ids): + raise DataSourceUnavailable(warning_msg) + partial_failure_warning = ( + f"The returned data may contain incomplete result. {warning_msg}" + ) - num_after_truncation = len(result) - result = self._filter(result, option.filters, ObjectState, option.detail) - num_filtered = len(result) - # Sort to make the output deterministic. - result.sort(key=lambda entry: entry["object_id"]) - result = list(islice(result, option.limit)) - return ListApiResponse( - result=result, - partial_failure_warning=partial_failure_warning, - total=total_objects, - num_after_truncation=num_after_truncation, - num_filtered=num_filtered, - warnings=callsite_warning, + result = [] + memory_table = memory_utils.construct_memory_table(worker_stats) + for entry in memory_table.table: + data = entry.as_dict() + # `construct_memory_table` returns object_ref field which is indeed + # object_id. We do transformation here. + # TODO(sang): Refactor `construct_memory_table`. + data["object_id"] = data["object_ref"] + del data["object_ref"] + data["ip"] = data["node_ip_address"] + del data["node_ip_address"] + data["type"] = data["type"].upper() + data["task_status"] = ( + "NIL" if data["task_status"] == "-" else data["task_status"] + ) + result.append(data) + + # Add callsite warnings if it is not configured. + callsite_warning = [] + callsite_enabled = env_integer("RAY_record_ref_creation_sites", 0) + if not callsite_enabled: + callsite_warning.append( + "Callsite is not being recorded. " + "To record callsite information for each ObjectRef created, set " + "env variable RAY_record_ref_creation_sites=1 during `ray start` " + "and `ray.init`." 
) - return await get_or_create_event_loop().run_in_executor( - self._thread_pool_executor, transform, replies + num_after_truncation = len(result) + result = self._filter(result, option.filters, ObjectState, option.detail) + num_filtered = len(result) + # Sort to make the output deterministic. + result.sort(key=lambda entry: entry["object_id"]) + result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + partial_failure_warning=partial_failure_warning, + total=total_objects, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, + warnings=callsite_warning, ) async def list_runtime_envs(self, *, option: ListApiOptions) -> ListApiResponse: @@ -609,73 +574,66 @@ async def list_runtime_envs(self, *, option: ListApiOptions) -> ListApiResponse: return_exceptions=True, ) - def transform(replies) -> ListApiResponse: - result = [] - unresponsive_nodes = 0 - total_runtime_envs = 0 - for node_id, reply in zip( - self._client.get_all_registered_runtime_env_agent_ids(), replies - ): - if isinstance(reply, DataSourceUnavailable): - unresponsive_nodes += 1 - continue - elif isinstance(reply, Exception): - raise reply - - total_runtime_envs += reply.total - states = reply.runtime_env_states - for state in states: - data = protobuf_message_to_dict(message=state, fields_to_decode=[]) - # Need to deserialize this field. - data["runtime_env"] = RuntimeEnv.deserialize( - data["runtime_env"] - ).to_dict() - data["node_id"] = node_id - result.append(data) - - partial_failure_warning = None - if len(agent_ids) > 0 and unresponsive_nodes > 0: - warning_msg = NODE_QUERY_FAILURE_WARNING.format( - type="agent", - total=len(agent_ids), - network_failures=unresponsive_nodes, - log_command="dashboard_agent.log", - ) - if unresponsive_nodes == len(agent_ids): - raise DataSourceUnavailable(warning_msg) - partial_failure_warning = ( - f"The returned data may contain incomplete result. {warning_msg}" - ) - num_after_truncation = len(result) - result = self._filter( - result, option.filters, RuntimeEnvState, option.detail - ) - num_filtered = len(result) - - # Sort to make the output deterministic. - def sort_func(entry): - # If creation time is not there yet (runtime env is failed - # to be created or not created yet, they are the highest priority. - # Otherwise, "bigger" creation time is coming first. - if "creation_time_ms" not in entry: - return float("inf") - elif entry["creation_time_ms"] is None: - return float("inf") - else: - return float(entry["creation_time_ms"]) + result = [] + unresponsive_nodes = 0 + total_runtime_envs = 0 + for node_id, reply in zip( + self._client.get_all_registered_runtime_env_agent_ids(), replies + ): + if isinstance(reply, DataSourceUnavailable): + unresponsive_nodes += 1 + continue + elif isinstance(reply, Exception): + raise reply + + total_runtime_envs += reply.total + states = reply.runtime_env_states + for state in states: + data = protobuf_message_to_dict(message=state, fields_to_decode=[]) + # Need to deserialize this field. 
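[Editor's note: both list_objects and list_runtime_envs fan out per-node RPCs and fold the replies back together, tolerating some unresponsive nodes but failing outright when every node failed. A condensed sketch of that fold follows; the exception type and warning text are simplified stand-ins for the real ones.]

from typing import List, Optional, Tuple


class DataSourceUnavailable(Exception):
    pass


def fold_replies(replies: List[object]) -> Tuple[List[dict], Optional[str]]:
    unresponsive = 0
    merged: List[dict] = []
    for reply in replies:
        if isinstance(reply, DataSourceUnavailable):
            unresponsive += 1          # tolerate individual node failures
            continue
        if isinstance(reply, Exception):
            raise reply                # anything else is a real error
        merged.extend(reply)
    warning = None
    if replies and unresponsive:
        if unresponsive == len(replies):
            raise DataSourceUnavailable("all nodes failed to respond")
        warning = f"The returned data may be incomplete: {unresponsive} node(s) unresponsive."
    return merged, warning


print(fold_replies([[{"id": 1}], DataSourceUnavailable(), [{"id": 2}]]))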
+ data["runtime_env"] = RuntimeEnv.deserialize( + data["runtime_env"] + ).to_dict() + data["node_id"] = node_id + result.append(data) - result.sort(key=sort_func, reverse=True) - result = list(islice(result, option.limit)) - return ListApiResponse( - result=result, - partial_failure_warning=partial_failure_warning, - total=total_runtime_envs, - num_after_truncation=num_after_truncation, - num_filtered=num_filtered, + partial_failure_warning = None + if len(agent_ids) > 0 and unresponsive_nodes > 0: + warning_msg = NODE_QUERY_FAILURE_WARNING.format( + type="agent", + total=len(agent_ids), + network_failures=unresponsive_nodes, + log_command="dashboard_agent.log", ) - - return await get_or_create_event_loop().run_in_executor( - self._thread_pool_executor, transform, replies + if unresponsive_nodes == len(agent_ids): + raise DataSourceUnavailable(warning_msg) + partial_failure_warning = ( + f"The returned data may contain incomplete result. {warning_msg}" + ) + num_after_truncation = len(result) + result = self._filter(result, option.filters, RuntimeEnvState, option.detail) + num_filtered = len(result) + + # Sort to make the output deterministic. + def sort_func(entry): + # If creation time is not there yet (runtime env is failed + # to be created or not created yet, they are the highest priority. + # Otherwise, "bigger" creation time is coming first. + if "creation_time_ms" not in entry: + return float("inf") + elif entry["creation_time_ms"] is None: + return float("inf") + else: + return float(entry["creation_time_ms"]) + + result.sort(key=sort_func, reverse=True) + result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + partial_failure_warning=partial_failure_warning, + total=total_runtime_envs, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, ) async def list_cluster_events(self, *, option: ListApiOptions) -> ListApiResponse: @@ -686,33 +644,25 @@ async def list_cluster_events(self, *, option: ListApiOptions) -> ListApiRespons The schema of returned "dict" is equivalent to the `ClusterEventState` protobuf message. """ - reply = await self._client.get_all_cluster_events() - - def transform(reply) -> ListApiResponse: - result = [] - for _, events in reply.items(): - for _, event in events.items(): - event["time"] = str(datetime.fromtimestamp(int(event["timestamp"]))) - result.append(event) - - num_after_truncation = len(result) - result.sort(key=lambda entry: entry["timestamp"]) - total = len(result) - result = self._filter( - result, option.filters, ClusterEventState, option.detail - ) - num_filtered = len(result) - # Sort to make the output deterministic. - result = list(islice(result, option.limit)) - return ListApiResponse( - result=result, - total=total, - num_after_truncation=num_after_truncation, - num_filtered=num_filtered, - ) - - return await get_or_create_event_loop().run_in_executor( - self._thread_pool_executor, transform, reply + result = [] + all_events = await self._client.get_all_cluster_events() + for _, events in all_events.items(): + for _, event in events.items(): + event["time"] = str(datetime.fromtimestamp(int(event["timestamp"]))) + result.append(event) + + num_after_truncation = len(result) + result.sort(key=lambda entry: entry["timestamp"]) + total = len(result) + result = self._filter(result, option.filters, ClusterEventState, option.detail) + num_filtered = len(result) + # Sort to make the output deterministic. 
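[Editor's note: the runtime-env sort_func above treats a missing or None creation_time_ms as highest priority by mapping it to infinity and sorting in reverse, so never-created envs surface first and newer envs come before older ones. A tiny equivalent sketch on plain dicts:]

def sort_key(entry: dict) -> float:
    # Entries whose runtime env was never created (no timestamp yet, or a
    # None timestamp) sort first; otherwise newer creation times come first.
    value = entry.get("creation_time_ms")
    return float("inf") if value is None else float(value)


envs = [{"creation_time_ms": 5}, {}, {"creation_time_ms": None}, {"creation_time_ms": 9}]
print(sorted(envs, key=sort_key, reverse=True))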
+ result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + total=total, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, ) async def summarize_tasks(self, option: SummaryApiOptions) -> SummaryApiResponse: From ee139fc9d44a02b34f7e2024af50eb743e0c5c81 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Wed, 28 Aug 2024 13:43:55 -0700 Subject: [PATCH 10/17] revert python/ray/includes/gcs_client.pxi Signed-off-by: Ruiyang Wang --- python/ray/includes/gcs_client.pxi | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/ray/includes/gcs_client.pxi b/python/ray/includes/gcs_client.pxi index 4d91a6cd0e4a..d3850613ccdf 100644 --- a/python/ray/includes/gcs_client.pxi +++ b/python/ray/includes/gcs_client.pxi @@ -78,7 +78,7 @@ cdef class NewGcsClient: @property def cluster_id(self) -> ray.ClusterID: cdef CClusterID cluster_id = self.inner.get().GetClusterId() - return ray.ClusterID(cluster_id.Binary()) + return ray.ClusterID.from_binary(cluster_id.Binary()) ############################################################# # Internal KV sync methods @@ -612,7 +612,7 @@ cdef convert_get_all_node_info( for b in serialized_reply: proto = gcs_pb2.GcsNodeInfo() proto.ParseFromString(b) - node_table_data[NodeID(proto.node_id)] = proto + node_table_data[NodeID.from_binary(proto.node_id)] = proto return node_table_data, None except Exception as e: return None, e @@ -634,7 +634,7 @@ cdef convert_get_all_job_info( for b in serialized_reply: proto = gcs_pb2.JobTableData() proto.ParseFromString(b) - job_table_data[JobID(proto.job_id)] = proto + job_table_data[JobID.from_binary(proto.job_id)] = proto return job_table_data, None except Exception as e: return None, e @@ -653,7 +653,7 @@ cdef convert_get_all_actor_info( for b in serialized_reply: proto = gcs_pb2.ActorTableData() proto.ParseFromString(b) - actor_table_data[ActorID(proto.actor_id)] = proto + actor_table_data[ActorID.from_binary(proto.actor_id)] = proto return actor_table_data, None except Exception as e: return None, e From a1d5b3c694211c5b815d01d4e6fad033141a845f Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Wed, 28 Aug 2024 13:45:30 -0700 Subject: [PATCH 11/17] fix Signed-off-by: Ruiyang Wang --- python/ray/dashboard/modules/node/node_head.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index 5d45b4714d35..8c7e7e949c51 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -198,7 +198,7 @@ async def _subscribe_nodes(self) -> AsyncGenerator[dict, None]: updated_dicts = await get_or_create_event_loop().run_in_executor( self._dashboard_head._thread_pool_executor, batch_updated_pairs_to_dict, - (node_info for _, node_info in published), + published, ) yield from updated_dicts except Exception: From 48032083f6e69661523c8907efb62f8b2c3278e6 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Wed, 28 Aug 2024 14:09:18 -0700 Subject: [PATCH 12/17] no yield from in async funcs Signed-off-by: Ruiyang Wang --- python/ray/dashboard/modules/node/node_head.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index 8c7e7e949c51..2a6f514a184b 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -188,7 +188,8 @@ async def 
_subscribe_nodes(self) -> AsyncGenerator[dict, None]: batch_gcs_node_info_to_dict, all_node_info, ) - yield from all_node_dicts + for node in all_node_dicts: + yield node while True: try: @@ -200,7 +201,8 @@ async def _subscribe_nodes(self) -> AsyncGenerator[dict, None]: batch_updated_pairs_to_dict, published, ) - yield from updated_dicts + for node in updated_dicts: + yield node except Exception: logger.exception("Failed handling updated nodes.") From 6de587d2442d5bd017be9305efc9af33399fbb49 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Wed, 28 Aug 2024 14:21:16 -0700 Subject: [PATCH 13/17] fixes Signed-off-by: Ruiyang Wang --- python/ray/_private/gcs_pubsub.py | 8 ++-- .../ray/dashboard/modules/actor/actor_head.py | 34 ++++++++-------- .../ray/dashboard/modules/node/node_consts.py | 4 +- .../ray/dashboard/modules/node/node_head.py | 39 ++++++++++--------- .../dashboard/modules/node/tests/test_node.py | 28 ------------- 5 files changed, 45 insertions(+), 68 deletions(-) diff --git a/python/ray/_private/gcs_pubsub.py b/python/ray/_private/gcs_pubsub.py index 757fb6e51fa0..27d53c9763b1 100644 --- a/python/ray/_private/gcs_pubsub.py +++ b/python/ray/_private/gcs_pubsub.py @@ -255,7 +255,7 @@ def queue_size(self): return len(self._queue) async def poll( - self, timeout=None, batch_size=500 + self, batch_size, timeout=None ) -> List[Tuple[bytes, gcs_pb2.ActorTableData]]: """Polls for new actor message. @@ -266,7 +266,7 @@ async def poll( return self._pop_actors(self._queue, batch_size=batch_size) @staticmethod - def _pop_actors(queue, batch_size=100): + def _pop_actors(queue, batch_size): if len(queue) == 0: return [] popped = 0 @@ -288,7 +288,7 @@ def __init__( super().__init__(pubsub_pb2.GCS_NODE_INFO_CHANNEL, worker_id, address, channel) async def poll( - self, timeout=None, batch_size=100 + self, batch_size, timeout=None ) -> List[Tuple[bytes, gcs_pb2.GcsNodeInfo]]: """Polls for new node info message. @@ -299,7 +299,7 @@ async def poll( return self._pop_node_infos(self._queue, batch_size=batch_size) @staticmethod - def _pop_node_infos(queue, batch_size=100): + def _pop_node_infos(queue, batch_size): if len(queue) == 0: return [] popped = 0 diff --git a/python/ray/dashboard/modules/actor/actor_head.py b/python/ray/dashboard/modules/actor/actor_head.py index 223a94cb49bf..a0033fb005cd 100644 --- a/python/ray/dashboard/modules/actor/actor_head.py +++ b/python/ray/dashboard/modules/actor/actor_head.py @@ -11,7 +11,6 @@ import ray.dashboard.utils as dashboard_utils from ray import ActorID from ray._private.gcs_pubsub import GcsAioActorSubscriber -from ray._private.utils import get_or_create_event_loop from ray.core.generated import gcs_pb2, gcs_service_pb2, gcs_service_pb2_grpc from ray.dashboard.datacenter import DataOrganizer, DataSource from ray.dashboard.modules.actor import actor_consts @@ -137,22 +136,30 @@ def __init__(self, dashboard_head): self.accumulative_event_processing_s = 0 async def _update_actors(self): + """ + Yields actor info. First yields all actors from GCS, then subscribes to + actor updates. + + To prevent Time-of-check to time-of-use issue [1], the get-all-actor-info + happens after the subscription. That is, an update between get-all-actor-info + and the subscription is not missed. + # [1] https://en.wikipedia.org/wiki/Time-of-check_to_time-of-use + """ + # Receive actors from channel. + gcs_addr = self._dashboard_head.gcs_address + subscriber = GcsAioActorSubscriber(address=gcs_addr) + await subscriber.subscribe() + # Get all actor info. 
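[Editor's note: the reordering above (subscribe first, then fetch the full snapshot) is what closes the time-of-check-to-time-of-use window: any update that lands between the snapshot and the first poll is already buffered by the subscription. A toy sketch of the ordering follows, with an asyncio.Queue standing in for the GCS subscriber; none of these names are the real API.]

import asyncio


async def main() -> None:
    store = {"a": 1}                 # stands in for the GCS actor table
    updates: asyncio.Queue = asyncio.Queue()

    # 1) "Subscribe" first, so later mutations are captured in the queue.
    def publish(key, value):
        store[key] = value
        updates.put_nowait((key, value))

    # 2) Take the full snapshot only after subscribing.
    snapshot = dict(store)

    # A mutation arriving between the snapshot and the first poll is not lost.
    publish("b", 2)

    state = dict(snapshot)
    while not updates.empty():
        key, value = updates.get_nowait()
        state[key] = value
    print(state)                     # {'a': 1, 'b': 2}


asyncio.run(main())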
while True: try: logger.info("Getting all actor info from GCS.") actors = await self.get_all_actor_info(timeout=5) - - def convert(actors) -> Dict[str, dict]: - return { - actor_id.hex(): actor_table_data_to_dict(actor_table_data) - for actor_id, actor_table_data in actors.items() - } - - actor_dicts = await get_or_create_event_loop().run_in_executor( - self._dashboard_head._thread_pool_executor, convert, actors - ) + actor_dicts: Dict[str, dict] = { + actor_id.hex(): actor_table_data_to_dict(actor_table_data) + for actor_id, actor_table_data in actors.items() + } # Update actors. DataSource.actors.reset(actor_dicts) # Update node actors and job actors. @@ -205,11 +212,6 @@ def process_actor_data_from_pubsub(actor_id, actor_table_data): node_actors[actor_id] = actor_table_data DataSource.node_actors[node_id] = node_actors - # Receive actors from channel. - gcs_addr = self._dashboard_head.gcs_address - subscriber = GcsAioActorSubscriber(address=gcs_addr) - await subscriber.subscribe() - while True: try: published = await subscriber.poll(batch_size=200) diff --git a/python/ray/dashboard/modules/node/node_consts.py b/python/ray/dashboard/modules/node/node_consts.py index 941aecf3f8ee..0b6c8aff9999 100644 --- a/python/ray/dashboard/modules/node/node_consts.py +++ b/python/ray/dashboard/modules/node/node_consts.py @@ -3,8 +3,8 @@ NODE_STATS_UPDATE_INTERVAL_SECONDS = env_integer( "NODE_STATS_UPDATE_INTERVAL_SECONDS", 5 ) -RAY_NODE_HEAD_HEAD_NODE_REGISTRATION_TIMEOUT = env_integer( - "RAY_NODE_HEAD_HEAD_NODE_REGISTRATION_TIMEOUT", 10 +RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT = env_integer( + "RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT", 10 ) MAX_COUNT_OF_GCS_RPC_ERROR = 10 # This is consistent with gcs_node_manager.cc diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index 2a6f514a184b..b34285198c88 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -35,7 +35,7 @@ from ray.dashboard.datacenter import DataOrganizer, DataSource from ray.dashboard.modules.node import node_consts from ray.dashboard.modules.node.node_consts import ( - RAY_NODE_HEAD_HEAD_NODE_REGISTRATION_TIMEOUT, + RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT, ) from ray.dashboard.utils import async_loop_forever @@ -43,20 +43,20 @@ routes = dashboard_optional_utils.DashboardHeadRouteTable -def gcs_node_info_to_dict(message: gcs_pb2.GcsNodeInfo) -> dict: +def _gcs_node_info_to_dict(message: gcs_pb2.GcsNodeInfo) -> dict: return dashboard_utils.message_to_dict( message, {"nodeId"}, always_print_fields_with_no_presence=True ) -def batch_gcs_node_info_to_dict(messages: List[gcs_pb2.GcsNodeInfo]) -> List[dict]: - return [gcs_node_info_to_dict(message) for message in messages] +def _batch__gcs_node_info_to_dict(messages: List[gcs_pb2.GcsNodeInfo]) -> List[dict]: + return [_gcs_node_info_to_dict(message) for message in messages] -def batch_updated_pairs_to_dict( +def _batch_updated_pairs_to_dict( messages: List[Tuple[bytes, gcs_pb2.GcsNodeInfo]] ) -> List[dict]: - return [gcs_node_info_to_dict(node_info) for node_id_bytes, node_info in messages] + return [_gcs_node_info_to_dict(node_info) for _, node_info in messages] def node_stats_to_dict(message): @@ -169,7 +169,7 @@ def get_internal_states(self): "module_lifetime_s": time.time() - self._module_start_time, } - async def _subscribe_nodes(self) -> AsyncGenerator[dict, None]: + async def _subscribe_for_node_updates(self) -> AsyncGenerator[dict, None]: """ Yields 
the initial state of all nodes, then yields the updated state of nodes. @@ -180,12 +180,15 @@ async def _subscribe_nodes(self) -> AsyncGenerator[dict, None]: subscriber = GcsAioNodeInfoSubscriber(address=gcs_addr) await subscriber.subscribe() - # Get all node info from GCS. For TOCTOU, it happens after the subscription. + # Get all node info from GCS. To prevent Time-of-check to time-of-use issue [1], + # it happens after the subscription. That is, an update between + # get-all-node-info and the subscription is not missed. + # [1] https://en.wikipedia.org/wiki/Time-of-check_to_time-of-use all_node_info = await self.get_all_node_info(timeout=GCS_RPC_TIMEOUT_SECONDS) all_node_dicts = await get_or_create_event_loop().run_in_executor( self._dashboard_head._thread_pool_executor, - batch_gcs_node_info_to_dict, + _batch__gcs_node_info_to_dict, all_node_info, ) for node in all_node_dicts: @@ -198,7 +201,7 @@ async def _subscribe_nodes(self) -> AsyncGenerator[dict, None]: ) updated_dicts = await get_or_create_event_loop().run_in_executor( self._dashboard_head._thread_pool_executor, - batch_updated_pairs_to_dict, + _batch_updated_pairs_to_dict, published, ) for node in updated_dicts: @@ -225,10 +228,10 @@ async def _update_node(self, node: dict): # Prepare agents for alive node, and pop agents for dead node. if is_alive: if node_id not in DataSource.agents: - # Agent port is read from internal KV. Problem is it's not present when - # we receive this update; it's only present after agent.py starts - # listening. So we make an async task that periodically polls internal - # KV. + # Agent port is read from internal KV, which is only populated + # upon Agent startup. In case this update received before agent + # fully started up, we schedule a task to asynchronously update + # DataSource with appropriate agent port. asyncio.create_task(self._update_agent(node_id)) else: DataSource.agents.pop(node_id, None) @@ -262,22 +265,22 @@ async def _update_agent(self, node_id): async def _update_nodes(self): """ Subscribe to node updates and update the internal states. If the head node is - not registered after RAY_NODE_HEAD_HEAD_NODE_REGISTRATION_TIMEOUT, it logs a + not registered after RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT, it logs a warning only once. """ warning_shown = False - async for node in self._subscribe_nodes(): + async for node in self._subscribe_for_node_updates(): await self._update_node(node) if not self._head_node_registration_time_s: # head node is not registered yet if ( not warning_shown and (time.time() - self._module_start_time) - > RAY_NODE_HEAD_HEAD_NODE_REGISTRATION_TIMEOUT + > RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT ): logger.warning( "Head node is not registered even after " - f"{RAY_NODE_HEAD_HEAD_NODE_REGISTRATION_TIMEOUT} seconds. " + f"{RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT} seconds. " "The API server might not work correctly. Please " "report a Github issue. 
Internal states :" f"{self.get_internal_states()}" diff --git a/python/ray/dashboard/modules/node/tests/test_node.py b/python/ray/dashboard/modules/node/tests/test_node.py index 70087a3b6d87..d3f11140e94f 100644 --- a/python/ray/dashboard/modules/node/tests/test_node.py +++ b/python/ray/dashboard/modules/node/tests/test_node.py @@ -21,8 +21,6 @@ logger = logging.getLogger(__name__) -UPDATE_NODES_INTERVAL_SECONDS = 5 - def test_nodes_update(enable_test_module, ray_start_with_dashboard): assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True @@ -229,31 +227,5 @@ def get_nodes(): time.sleep(2) -@pytest.mark.parametrize( - "ray_start_cluster_head", [{"include_dashboard": True}], indirect=True -) -def test_frequent_node_update( - enable_test_module, disable_aiohttp_cache, ray_start_cluster_head -): - cluster: Cluster = ray_start_cluster_head - assert wait_until_server_available(cluster.webui_url) - webui_url = cluster.webui_url - webui_url = format_web_url(webui_url) - - def verify(): - response = requests.get(webui_url + "/internal/node_module") - response.raise_for_status() - result = response.json() - data = result["data"] - head_node_registration_time = data["headNodeRegistrationTimeS"] - # If the head node is not registered, it is None. - assert head_node_registration_time is not None - # Head node should be registered before the node update interval - # because we do frequent until the head node is registered. - return head_node_registration_time < UPDATE_NODES_INTERVAL_SECONDS - - wait_for_condition(verify, timeout=15) - - if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) From 085eebbcb9fe8a284a14bced262ad29dec5141de Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Wed, 28 Aug 2024 14:59:24 -0700 Subject: [PATCH 14/17] nits Signed-off-by: Ruiyang Wang --- .../ray/dashboard/modules/node/node_head.py | 34 +++++++++++-------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index b34285198c88..2018fbf65c57 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -49,7 +49,7 @@ def _gcs_node_info_to_dict(message: gcs_pb2.GcsNodeInfo) -> dict: ) -def _batch__gcs_node_info_to_dict(messages: List[gcs_pb2.GcsNodeInfo]) -> List[dict]: +def _batch_gcs_node_info_to_dict(messages: List[gcs_pb2.GcsNodeInfo]) -> List[dict]: return [_gcs_node_info_to_dict(message) for message in messages] @@ -184,11 +184,11 @@ async def _subscribe_for_node_updates(self) -> AsyncGenerator[dict, None]: # it happens after the subscription. That is, an update between # get-all-node-info and the subscription is not missed. # [1] https://en.wikipedia.org/wiki/Time-of-check_to_time-of-use - all_node_info = await self.get_all_node_info(timeout=GCS_RPC_TIMEOUT_SECONDS) + all_node_info = await self.get_all_node_info(timeout=-1) all_node_dicts = await get_or_create_event_loop().run_in_executor( self._dashboard_head._thread_pool_executor, - _batch__gcs_node_info_to_dict, + _batch_gcs_node_info_to_dict, all_node_info, ) for node in all_node_dicts: @@ -248,18 +248,22 @@ async def _update_agent(self, node_id): """ key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}".encode() while True: - agent_port = await self._gcs_aio_client.internal_kv_get( - key, - namespace=ray_constants.KV_NAMESPACE_DASHBOARD, - timeout=GCS_RPC_TIMEOUT_SECONDS, - ) - # The node may be dead already. 
Only update DataSource.agents if the node is - # still alive. - if DataSource.nodes.get(node_id, {}).get("state") != "ALIVE": - return - if agent_port: - DataSource.agents[node_id] = json.loads(agent_port) - return + try: + agent_port = await self._gcs_aio_client.internal_kv_get( + key, + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + timeout=-1, + ) + # The node may be dead already. Only update DataSource.agents if the + # node is still alive. + if DataSource.nodes.get(node_id, {}).get("state") != "ALIVE": + return + if agent_port: + DataSource.agents[node_id] = json.loads(agent_port) + return + except Exception: + logger.exception(f"Error getting agent port for node {node_id}.") + await asyncio.sleep(node_consts.RAY_NODE_HEAD_AGENT_POLL_INTERVAL_S) async def _update_nodes(self): From c32cf2f86a5117fdc453d504a89971b6ef982a8d Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Wed, 28 Aug 2024 15:04:11 -0700 Subject: [PATCH 15/17] fix Signed-off-by: Ruiyang Wang --- python/ray/dashboard/modules/node/node_head.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index 2018fbf65c57..d4f6abde3688 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -49,11 +49,13 @@ def _gcs_node_info_to_dict(message: gcs_pb2.GcsNodeInfo) -> dict: ) -def _batch_gcs_node_info_to_dict(messages: List[gcs_pb2.GcsNodeInfo]) -> List[dict]: - return [_gcs_node_info_to_dict(message) for message in messages] +def _map_batch_node_info_to_dict( + messages: Dict[NodeID, gcs_pb2.GcsNodeInfo] +) -> List[dict]: + return [_gcs_node_info_to_dict(message) for message in messages.values()] -def _batch_updated_pairs_to_dict( +def _list_gcs_node_info_to_dict( messages: List[Tuple[bytes, gcs_pb2.GcsNodeInfo]] ) -> List[dict]: return [_gcs_node_info_to_dict(node_info) for _, node_info in messages] @@ -184,11 +186,11 @@ async def _subscribe_for_node_updates(self) -> AsyncGenerator[dict, None]: # it happens after the subscription. That is, an update between # get-all-node-info and the subscription is not missed. # [1] https://en.wikipedia.org/wiki/Time-of-check_to_time-of-use - all_node_info = await self.get_all_node_info(timeout=-1) + all_node_info = await self.get_all_node_info(timeout=None) all_node_dicts = await get_or_create_event_loop().run_in_executor( self._dashboard_head._thread_pool_executor, - _batch_gcs_node_info_to_dict, + _map_batch_node_info_to_dict, all_node_info, ) for node in all_node_dicts: @@ -201,7 +203,7 @@ async def _subscribe_for_node_updates(self) -> AsyncGenerator[dict, None]: ) updated_dicts = await get_or_create_event_loop().run_in_executor( self._dashboard_head._thread_pool_executor, - _batch_updated_pairs_to_dict, + _list_gcs_node_info_to_dict, published, ) for node in updated_dicts: @@ -252,7 +254,7 @@ async def _update_agent(self, node_id): agent_port = await self._gcs_aio_client.internal_kv_get( key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD, - timeout=-1, + timeout=None, ) # The node may be dead already. Only update DataSource.agents if the # node is still alive. 
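[Editor's note: the retry added in this commit wraps the internal-KV lookup so a transient GCS error does not kill the background task; it sleeps and tries again, and it gives up once the node is no longer alive. A standalone sketch of that loop follows, with a fake async KV lookup and a made-up key format in place of gcs_aio_client.internal_kv_get.]

import asyncio
import json
from typing import Dict, Optional

POLL_INTERVAL_S = 1  # illustrative; the real interval is env-configurable


async def fake_internal_kv_get(key: str) -> Optional[bytes]:
    # Stand-in for the GCS internal KV; returns the agent ports once the
    # agent has registered itself.
    return json.dumps([52365, 52366]).encode() if key.endswith("ready") else None


async def update_agent(node_id: str, nodes: Dict[str, dict], agents: Dict[str, list]) -> None:
    key = f"DASHBOARD_AGENT_PORT:{node_id}"  # hypothetical key format
    while True:
        try:
            value = await fake_internal_kv_get(key)
            if nodes.get(node_id, {}).get("state") != "ALIVE":
                return                       # node died; nothing to record
            if value:
                agents[node_id] = json.loads(value)
                return
        except Exception:
            pass                             # transient error: retry after a pause
        await asyncio.sleep(POLL_INTERVAL_S)


nodes = {"node-ready": {"state": "ALIVE"}}
agents: Dict[str, list] = {}
asyncio.run(update_agent("node-ready", nodes, agents))
print(agents)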
From f9bae64c6ad22339c1f69b050973f3bc12b3e678 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Wed, 28 Aug 2024 16:02:30 -0700 Subject: [PATCH 16/17] update naming and comments Signed-off-by: Ruiyang Wang --- python/ray/dashboard/modules/actor/actor_head.py | 5 +++-- python/ray/dashboard/modules/node/node_consts.py | 8 ++++---- python/ray/dashboard/modules/node/node_head.py | 4 ++-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/python/ray/dashboard/modules/actor/actor_head.py b/python/ray/dashboard/modules/actor/actor_head.py index a0033fb005cd..4ce372a751a3 100644 --- a/python/ray/dashboard/modules/actor/actor_head.py +++ b/python/ray/dashboard/modules/actor/actor_head.py @@ -137,8 +137,9 @@ def __init__(self, dashboard_head): async def _update_actors(self): """ - Yields actor info. First yields all actors from GCS, then subscribes to - actor updates. + Processes actor info. First gets all actors from GCS, then subscribes to + actor updates. For each actor update, updates DataSource.node_actors and + DataSource.actors. To prevent Time-of-check to time-of-use issue [1], the get-all-actor-info happens after the subscription. That is, an update between get-all-actor-info diff --git a/python/ray/dashboard/modules/node/node_consts.py b/python/ray/dashboard/modules/node/node_consts.py index 0b6c8aff9999..c3939cd66e3b 100644 --- a/python/ray/dashboard/modules/node/node_consts.py +++ b/python/ray/dashboard/modules/node/node_consts.py @@ -9,9 +9,9 @@ MAX_COUNT_OF_GCS_RPC_ERROR = 10 # This is consistent with gcs_node_manager.cc MAX_DEAD_NODES_TO_CACHE = env_integer("RAY_maximum_gcs_dead_node_cached_count", 1000) -RAY_NODE_HEAD_SUBSCRIBER_POLL_SIZE = env_integer( - "RAY_NODE_HEAD_SUBSCRIBER_POLL_SIZE", 200 +RAY_DASHBOARD_NODE_SUBSCRIBER_POLL_SIZE = env_integer( + "RAY_DASHBOARD_NODE_SUBSCRIBER_POLL_SIZE", 200 ) -RAY_NODE_HEAD_AGENT_POLL_INTERVAL_S = env_integer( - "RAY_NODE_HEAD_AGENT_POLL_INTERVAL_S", 1 +RAY_DASHBOARD_AGENT_POLL_INTERVAL_S = env_integer( + "RAY_DASHBOARD_AGENT_POLL_INTERVAL_S", 1 ) diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index d4f6abde3688..caf098a95b27 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -199,7 +199,7 @@ async def _subscribe_for_node_updates(self) -> AsyncGenerator[dict, None]: while True: try: published = await subscriber.poll( - batch_size=node_consts.RAY_NODE_HEAD_SUBSCRIBER_POLL_SIZE + batch_size=node_consts.RAY_DASHBOARD_NODE_SUBSCRIBER_POLL_SIZE ) updated_dicts = await get_or_create_event_loop().run_in_executor( self._dashboard_head._thread_pool_executor, @@ -266,7 +266,7 @@ async def _update_agent(self, node_id): except Exception: logger.exception(f"Error getting agent port for node {node_id}.") - await asyncio.sleep(node_consts.RAY_NODE_HEAD_AGENT_POLL_INTERVAL_S) + await asyncio.sleep(node_consts.RAY_DASHBOARD_AGENT_POLL_INTERVAL_S) async def _update_nodes(self): """ From d474b323d9e2602437f8ed49d1262cfa361c1608 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Wed, 28 Aug 2024 22:00:12 -0700 Subject: [PATCH 17/17] fix conftest Signed-off-by: Ruiyang Wang --- python/ray/dashboard/tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/ray/dashboard/tests/conftest.py b/python/ray/dashboard/tests/conftest.py index 29663e422e88..511276761a03 100644 --- a/python/ray/dashboard/tests/conftest.py +++ b/python/ray/dashboard/tests/conftest.py @@ -3,6 +3,7 @@ import pytest +import 
ray.dashboard.modules # noqa from ray.tests.conftest import * # noqa
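[Editor's note: for completeness, the warn-once behavior of _update_nodes earlier in the series (log a single warning if the head node has not registered within the timeout) reduces to a flag plus an elapsed-time check. A minimal sketch of that logic with made-up numbers:]

import logging
import time

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

REGISTRATION_TIMEOUT_S = 10   # illustrative; the real value comes from an env var

module_start_time = time.time() - 15   # pretend the module has been up for 15s
head_node_registration_time_s = None   # None means the head node never registered
warning_shown = False


def check_head_node_registered() -> None:
    global warning_shown
    if head_node_registration_time_s is not None:
        return
    if not warning_shown and time.time() - module_start_time > REGISTRATION_TIMEOUT_S:
        logger.warning(
            "Head node is not registered even after %s seconds.",
            REGISTRATION_TIMEOUT_S,
        )
        warning_shown = True   # only ever warn once


check_head_node_registered()
check_head_node_registered()   # second call is silent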