diff --git a/python/ray/_private/gcs_pubsub.py b/python/ray/_private/gcs_pubsub.py
index fb4ea8d1bdd1e..27d53c9763b12 100644
--- a/python/ray/_private/gcs_pubsub.py
+++ b/python/ray/_private/gcs_pubsub.py
@@ -13,9 +13,9 @@ from grpc.experimental import aio as aiogrpc
 
 import ray._private.gcs_utils as gcs_utils
-import ray._private.logging_utils as logging_utils
 from ray.core.generated import gcs_service_pb2_grpc
 from ray.core.generated import gcs_service_pb2
+from ray.core.generated import gcs_pb2
 from ray.core.generated import common_pb2
 from ray.core.generated import pubsub_pb2
 
@@ -90,39 +90,6 @@ def _should_terminate_polling(e: grpc.RpcError) -> None:
             return True
         return False
 
-    @staticmethod
-    def _pop_error_info(queue):
-        if len(queue) == 0:
-            return None, None
-        msg = queue.popleft()
-        return msg.key_id, msg.error_info_message
-
-    @staticmethod
-    def _pop_log_batch(queue):
-        if len(queue) == 0:
-            return None
-        msg = queue.popleft()
-        return logging_utils.log_batch_proto_to_dict(msg.log_batch_message)
-
-    @staticmethod
-    def _pop_resource_usage(queue):
-        if len(queue) == 0:
-            return None, None
-        msg = queue.popleft()
-        return msg.key_id.decode(), msg.node_resource_usage_message.json
-
-    @staticmethod
-    def _pop_actors(queue, batch_size=100):
-        if len(queue) == 0:
-            return []
-        popped = 0
-        msgs = []
-        while len(queue) > 0 and popped < batch_size:
-            msg = queue.popleft()
-            msgs.append((msg.key_id, msg.actor_message))
-            popped += 1
-        return msgs
-
 
 class GcsAioPublisher(_PublisherBase):
     """Publisher to GCS. Uses async io."""
@@ -222,7 +189,7 @@ async def _poll(self, timeout=None) -> None:
                     self._max_processed_sequence_id = 0
                 for msg in poll.result().pub_messages:
                     if msg.sequence_id <= self._max_processed_sequence_id:
-                        logger.warn(f"Ignoring out of order message {msg}")
+                        logger.warning(f"Ignoring out of order message {msg}")
                         continue
                     self._max_processed_sequence_id = msg.sequence_id
                     self._queue.append(msg)
@@ -266,6 +233,13 @@ async def poll(self, timeout=None) -> Tuple[bytes, str]:
         await self._poll(timeout=timeout)
         return self._pop_resource_usage(self._queue)
 
+    @staticmethod
+    def _pop_resource_usage(queue):
+        if len(queue) == 0:
+            return None, None
+        msg = queue.popleft()
+        return msg.key_id.decode(), msg.node_resource_usage_message.json
+
 
 class GcsAioActorSubscriber(_AioSubscriber):
     def __init__(
@@ -280,11 +254,58 @@ def __init__(
     def queue_size(self):
         return len(self._queue)
 
-    async def poll(self, timeout=None, batch_size=500) -> List[Tuple[bytes, str]]:
+    async def poll(
+        self, batch_size, timeout=None
+    ) -> List[Tuple[bytes, gcs_pb2.ActorTableData]]:
         """Polls for new actor message.
 
         Returns:
-            A tuple of binary actor ID and actor table data.
+            A list of tuples of binary actor ID and actor table data.
         """
         await self._poll(timeout=timeout)
         return self._pop_actors(self._queue, batch_size=batch_size)
+
+    @staticmethod
+    def _pop_actors(queue, batch_size):
+        if len(queue) == 0:
+            return []
+        popped = 0
+        msgs = []
+        while len(queue) > 0 and popped < batch_size:
+            msg = queue.popleft()
+            msgs.append((msg.key_id, msg.actor_message))
+            popped += 1
+        return msgs
+
+
+class GcsAioNodeInfoSubscriber(_AioSubscriber):
+    def __init__(
+        self,
+        worker_id: bytes = None,
+        address: str = None,
+        channel: grpc.Channel = None,
+    ):
+        super().__init__(pubsub_pb2.GCS_NODE_INFO_CHANNEL, worker_id, address, channel)
+
+    async def poll(
+        self, batch_size, timeout=None
+    ) -> List[Tuple[bytes, gcs_pb2.GcsNodeInfo]]:
+        """Polls for new node info messages.
+
+        Returns:
+            A list of tuples of (node_id, GcsNodeInfo).
+        """
+        await self._poll(timeout=timeout)
+        return self._pop_node_infos(self._queue, batch_size=batch_size)
+
+    @staticmethod
+    def _pop_node_infos(queue, batch_size):
+        if len(queue) == 0:
+            return []
+        popped = 0
+        msgs = []
+        while len(queue) > 0 and popped < batch_size:
+            msg = queue.popleft()
+            msgs.append((msg.key_id, msg.node_info_message))
+            popped += 1
+        return msgs
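The new GcsAioNodeInfoSubscriber is driven the same way as the existing GcsAioActorSubscriber. A minimal usage sketch (illustrative only, assuming a reachable GCS address; not part of the patch):

    import asyncio

    from ray._private.gcs_pubsub import GcsAioNodeInfoSubscriber

    async def watch_nodes(gcs_address: str) -> None:
        subscriber = GcsAioNodeInfoSubscriber(address=gcs_address)
        await subscriber.subscribe()
        while True:
            # batch_size is now a required argument; with timeout=None the
            # call waits until at least one message is available.
            batch = await subscriber.poll(batch_size=200, timeout=None)
            for node_id, node_info in batch:
                print(node_id.hex(), node_info.state)

    # asyncio.run(watch_nodes("127.0.0.1:6379"))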
diff --git a/python/ray/dashboard/modules/actor/actor_head.py b/python/ray/dashboard/modules/actor/actor_head.py
index fb0aa5ffd89d2..4ce372a751a38 100644
--- a/python/ray/dashboard/modules/actor/actor_head.py
+++ b/python/ray/dashboard/modules/actor/actor_head.py
@@ -136,6 +136,21 @@ def __init__(self, dashboard_head):
         self.accumulative_event_processing_s = 0
 
     async def _update_actors(self):
+        """
+        Processes actor info: first fetches all actors from GCS, then subscribes
+        to actor updates. For each actor update, it updates DataSource.node_actors
+        and DataSource.actors.
+
+        To avoid a time-of-check to time-of-use race [1], the get-all-actor-info
+        call happens after the subscription, so an update published between the
+        subscription and the get-all-actor-info call is not missed.
+
+        [1] https://en.wikipedia.org/wiki/Time-of-check_to_time-of-use
+        """
+        # Receive actors from channel.
+        gcs_addr = self._dashboard_head.gcs_address
+        subscriber = GcsAioActorSubscriber(address=gcs_addr)
+        await subscriber.subscribe()
+
         # Get all actor info.
         while True:
             try:
@@ -198,11 +213,6 @@ def process_actor_data_from_pubsub(actor_id, actor_table_data):
                 node_actors[actor_id] = actor_table_data
             DataSource.node_actors[node_id] = node_actors
 
-        # Receive actors from channel.
-        gcs_addr = self._dashboard_head.gcs_address
-        subscriber = GcsAioActorSubscriber(address=gcs_addr)
-        await subscriber.subscribe()
-
         while True:
             try:
                 published = await subscriber.poll(batch_size=200)
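The ordering in _update_actors is the load-bearing part: subscribe first, then snapshot. Reduced to its essentials (a sketch with hypothetical helper names, not the dashboard's actual code):

    async def sync_actor_state(subscriber, get_all_actor_info):
        # 1) Subscribe first: an update racing with the snapshot below is
        #    buffered in the subscriber's queue instead of being dropped.
        await subscriber.subscribe()
        # 2) Then take the full snapshot.
        actors = await get_all_actor_info()
        # 3) Replay buffered and live updates. An update already reflected in
        #    the snapshot is harmlessly applied again, since each message
        #    carries the actor's latest state.
        while True:
            for actor_id, actor_data in await subscriber.poll(batch_size=200):
                actors[actor_id] = actor_data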
diff --git a/python/ray/dashboard/modules/node/node_consts.py b/python/ray/dashboard/modules/node/node_consts.py
index 0e9b2465da952..c3939cd66e3b3 100644
--- a/python/ray/dashboard/modules/node/node_consts.py
+++ b/python/ray/dashboard/modules/node/node_consts.py
@@ -1,16 +1,17 @@
-from ray._private.ray_constants import env_float, env_integer
+from ray._private.ray_constants import env_integer
 
 NODE_STATS_UPDATE_INTERVAL_SECONDS = env_integer(
     "NODE_STATS_UPDATE_INTERVAL_SECONDS", 5
 )
-UPDATE_NODES_INTERVAL_SECONDS = env_integer("UPDATE_NODES_INTERVAL_SECONDS", 5)
-# Until the head node is registered,
-# the API server is doing more frequent update
-# with this interval.
-FREQUENTY_UPDATE_NODES_INTERVAL_SECONDS = env_float(
-    "FREQUENTY_UPDATE_NODES_INTERVAL_SECONDS", 0.1
+RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT = env_integer(
+    "RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT", 10
 )
-# If the head node is not updated within
-# this timeout, it will stop frequent update.
-FREQUENT_UPDATE_TIMEOUT_SECONDS = env_integer("FREQUENT_UPDATE_TIMEOUT_SECONDS", 10)
 MAX_COUNT_OF_GCS_RPC_ERROR = 10
+# Consistent with the dead node cache cap in gcs_node_manager.cc.
+MAX_DEAD_NODES_TO_CACHE = env_integer("RAY_maximum_gcs_dead_node_cached_count", 1000)
+RAY_DASHBOARD_NODE_SUBSCRIBER_POLL_SIZE = env_integer(
+    "RAY_DASHBOARD_NODE_SUBSCRIBER_POLL_SIZE", 200
+)
+RAY_DASHBOARD_AGENT_POLL_INTERVAL_S = env_integer(
+    "RAY_DASHBOARD_AGENT_POLL_INTERVAL_S", 1
+)
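All of these constants go through env_integer, so they are read from the environment once, when the module is imported. To override one (illustrative; each new constant works the same way):

    import os

    # Must be set before node_consts is imported, since env_integer() reads
    # the environment at import time.
    os.environ["RAY_DASHBOARD_AGENT_POLL_INTERVAL_S"] = "2"

    from ray.dashboard.modules.node import node_consts

    assert node_consts.RAY_DASHBOARD_AGENT_POLL_INTERVAL_S == 2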
             node_id, node_info = change.new
diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py
index cf521fa22b0d6..caf098a95b27d 100644
--- a/python/ray/dashboard/modules/node/node_head.py
+++ b/python/ray/dashboard/modules/node/node_head.py
@@ -3,8 +3,9 @@
 import logging
 import os
 import time
+from collections import deque
 from itertools import chain
-from typing import Dict
+from typing import AsyncGenerator, Dict, List, Tuple
 
 import aiohttp.web
 import grpc
@@ -15,6 +16,7 @@
 import ray.dashboard.utils as dashboard_utils
 from ray import NodeID
 from ray._private import ray_constants
+from ray._private.gcs_pubsub import GcsAioNodeInfoSubscriber
 from ray._private.ray_constants import DEBUG_AUTOSCALING_ERROR, DEBUG_AUTOSCALING_STATUS
 from ray._private.utils import get_or_create_event_loop
 from ray.autoscaler._private.util import (
@@ -33,8 +35,7 @@
 from ray.dashboard.datacenter import DataOrganizer, DataSource
 from ray.dashboard.modules.node import node_consts
 from ray.dashboard.modules.node.node_consts import (
-    FREQUENT_UPDATE_TIMEOUT_SECONDS,
-    FREQUENTY_UPDATE_NODES_INTERVAL_SECONDS,
+    RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT,
 )
 from ray.dashboard.utils import async_loop_forever
 
@@ -42,25 +43,22 @@
 routes = dashboard_optional_utils.DashboardHeadRouteTable
 
 
-def gcs_node_info_to_dict(message):
+def _gcs_node_info_to_dict(message: gcs_pb2.GcsNodeInfo) -> dict:
     return dashboard_utils.message_to_dict(
         message, {"nodeId"}, always_print_fields_with_no_presence=True
     )
 
 
-def gcs_stats_to_dict(message):
-    decode_keys = {
-        "actorId",
-        "jobId",
-        "taskId",
-        "parentTaskId",
-        "sourceActorId",
-        "callerId",
-        "rayletId",
-        "workerId",
-        "placementGroupId",
-    }
-    return dashboard_utils.message_to_dict(message, decode_keys)
+def _map_batch_node_info_to_dict(
+    messages: Dict[NodeID, gcs_pb2.GcsNodeInfo]
+) -> List[dict]:
+    return [_gcs_node_info_to_dict(message) for message in messages.values()]
+
+
+def _list_gcs_node_info_to_dict(
+    messages: List[Tuple[bytes, gcs_pb2.GcsNodeInfo]]
+) -> List[dict]:
+    return [_gcs_node_info_to_dict(node_info) for _, node_info in messages]
 
 
 def node_stats_to_dict(message):
@@ -138,20 +136,20 @@ def __init__(self, dashboard_head):
         self.get_all_node_info = None
         self._collect_memory_info = False
         DataSource.nodes.signal.append(self._update_stubs)
-        # Total number of node updates happened.
-        self._node_update_cnt = 0
         # The time where the module is started.
         self._module_start_time = time.time()
         # The time it takes until the head node is registered. None means
         # head node hasn't been registered.
         self._head_node_registration_time_s = None
+        # Queue of dead nodes to be evicted, capped at MAX_DEAD_NODES_TO_CACHE.
+        self._dead_node_queue = deque()
         self._gcs_aio_client = dashboard_head.gcs_aio_client
         self._gcs_address = dashboard_head.gcs_address
 
     async def _update_stubs(self, change):
         if change.old:
             node_id, node_info = change.old
-            self._stubs.pop(node_id)
+            self._stubs.pop(node_id, None)
         if change.new:
             # TODO(fyrestone): Handle exceptions.
             node_id, node_info = change.new
@@ -170,104 +168,130 @@ def get_internal_states(self):
             "head_node_registration_time_s": self._head_node_registration_time_s,
             "registered_nodes": len(DataSource.nodes),
             "registered_agents": len(DataSource.agents),
-            "node_update_count": self._node_update_cnt,
             "module_lifetime_s": time.time() - self._module_start_time,
         }
 
-    async def _get_nodes(self):
-        """Read the client table.
+    async def _subscribe_for_node_updates(self) -> AsyncGenerator[dict, None]:
+        """
+        Yields the initial state of all nodes, then yields node updates as they arrive.
 
-        Returns:
-            A dict of information about the nodes in the cluster.
+        GetAllNodeInfo is called only once, after the subscription is established,
+        to fetch the initial state of the nodes.
         """
-        try:
-            nodes = await self.get_all_node_info(timeout=GCS_RPC_TIMEOUT_SECONDS)
-            return {
-                node_id.hex(): gcs_node_info_to_dict(node_info)
-                for node_id, node_info in nodes.items()
-            }
-        except Exception:
-            logger.exception("Failed to GetAllNodeInfo.")
-            raise
+        gcs_addr = self._gcs_address
+        subscriber = GcsAioNodeInfoSubscriber(address=gcs_addr)
+        await subscriber.subscribe()
+
+        # Get all node info from GCS. To avoid a time-of-check to time-of-use
+        # race [1], this happens after the subscription; an update published
+        # between the subscription and get-all-node-info is not missed.
+        # [1] https://en.wikipedia.org/wiki/Time-of-check_to_time-of-use
+        all_node_info = await self.get_all_node_info(timeout=None)
+
+        all_node_dicts = await get_or_create_event_loop().run_in_executor(
+            self._dashboard_head._thread_pool_executor,
+            _map_batch_node_info_to_dict,
+            all_node_info,
+        )
+        for node in all_node_dicts:
+            yield node
 
-    async def _update_nodes(self):
-        # TODO(fyrestone): Refactor code for updating actor / node / job.
-        # Subscribe actor channel.
         while True:
             try:
-                nodes = await self._get_nodes()
-
-                alive_node_ids = []
-                alive_node_infos = []
-                for node in nodes.values():
-                    node_id = node["nodeId"]
-                    if node["isHeadNode"] and not self._head_node_registration_time_s:
-                        self._head_node_registration_time_s = (
-                            time.time() - self._module_start_time
-                        )
-                        # Put head node ID in the internal KV to be read by JobAgent.
-                        # TODO(architkulkarni): Remove once State API exposes which
-                        # node is the head node.
-                        await self._gcs_aio_client.internal_kv_put(
-                            ray_constants.KV_HEAD_NODE_ID_KEY,
-                            node_id.encode(),
-                            overwrite=True,
-                            namespace=ray_constants.KV_NAMESPACE_JOB,
-                            timeout=GCS_RPC_TIMEOUT_SECONDS,
-                        )
-                    assert node["state"] in ["ALIVE", "DEAD"]
-                    if node["state"] == "ALIVE":
-                        alive_node_ids.append(node_id)
-                        alive_node_infos.append(node)
-
-                agents = dict(DataSource.agents)
-                for node_id in alive_node_ids:
-                    # Since the agent fate shares with a raylet,
-                    # the agent port will never change once it is discovered.
-                    if node_id not in agents:
-                        key = (
-                            f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}"
-                            f"{node_id}"
-                        )
-                        agent_port = await self._gcs_aio_client.internal_kv_get(
-                            key.encode(),
-                            namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
-                            timeout=GCS_RPC_TIMEOUT_SECONDS,
-                        )
-                        if agent_port:
-                            agents[node_id] = json.loads(agent_port)
-                for node_id in agents.keys() - set(alive_node_ids):
-                    agents.pop(node_id, None)
+                published = await subscriber.poll(
+                    batch_size=node_consts.RAY_DASHBOARD_NODE_SUBSCRIBER_POLL_SIZE
+                )
+                updated_dicts = await get_or_create_event_loop().run_in_executor(
+                    self._dashboard_head._thread_pool_executor,
+                    _list_gcs_node_info_to_dict,
+                    published,
+                )
+                for node in updated_dicts:
+                    yield node
+            except Exception:
+                logger.exception("Failed handling updated nodes.")
+
+    async def _update_node(self, node: dict):
+        node_id = node["nodeId"]  # hex
+        if node["isHeadNode"] and not self._head_node_registration_time_s:
+            self._head_node_registration_time_s = time.time() - self._module_start_time
+            # Put head node ID in the internal KV to be read by JobAgent.
+            # TODO(architkulkarni): Remove once State API exposes which
+            # node is the head node.
+            await self._gcs_aio_client.internal_kv_put(
+                ray_constants.KV_HEAD_NODE_ID_KEY,
+                node_id.encode(),
+                overwrite=True,
+                namespace=ray_constants.KV_NAMESPACE_JOB,
+                timeout=GCS_RPC_TIMEOUT_SECONDS,
+            )
+        assert node["state"] in ["ALIVE", "DEAD"]
+        is_alive = node["state"] == "ALIVE"
+        # Track agents for alive nodes; drop agents for dead nodes.
+        if is_alive:
+            if node_id not in DataSource.agents:
+                # The agent port is read from the internal KV, which is only
+                # populated on agent startup. If this update arrives before the
+                # agent has fully started, schedule a task to asynchronously
+                # update DataSource with the agent port.
+                asyncio.create_task(self._update_agent(node_id))
+        else:
+            DataSource.agents.pop(node_id, None)
+            self._dead_node_queue.append(node_id)
+            if len(self._dead_node_queue) > node_consts.MAX_DEAD_NODES_TO_CACHE:
+                DataSource.nodes.pop(self._dead_node_queue.popleft(), None)
+        DataSource.nodes[node_id] = node
 
-                DataSource.agents.reset(agents)
-                DataSource.nodes.reset(nodes)
+    async def _update_agent(self, node_id):
+        """
+        Update the agent port for the given node in DataSource.agents. The port is
+        not available until agent.py starts, so we poll in a loop until agent.py
+        has written its port to the internal KV.
+        """
+        key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}".encode()
+        while True:
+            try:
+                agent_port = await self._gcs_aio_client.internal_kv_get(
+                    key,
+                    namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
+                    timeout=None,
+                )
+                # The node may be dead already. Only update DataSource.agents if
+                # the node is still alive.
+                if DataSource.nodes.get(node_id, {}).get("state") != "ALIVE":
+                    return
+                if agent_port:
+                    DataSource.agents[node_id] = json.loads(agent_port)
+                    return
             except Exception:
-                logger.exception("Error updating nodes.")
-            finally:
-                self._node_update_cnt += 1
-                # _head_node_registration_time_s == None if head node is not
-                # registered.
-                head_node_not_registered = not self._head_node_registration_time_s
-                # Until the head node is registered, we update the
-                # node status more frequently.
-                # If the head node is not updated after 10 seconds, it just stops
-                # doing frequent update to avoid unexpected edge case.
+                logger.exception(f"Error getting agent port for node {node_id}.")
+
+            await asyncio.sleep(node_consts.RAY_DASHBOARD_AGENT_POLL_INTERVAL_S)
+
+    async def _update_nodes(self):
+        """
+        Subscribes to node updates and updates the internal state. If the head node
+        is not registered within RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT
+        seconds, logs a warning, but only once.
+        """
+        warning_shown = False
+        async for node in self._subscribe_for_node_updates():
+            await self._update_node(node)
+            if not self._head_node_registration_time_s:
+                # The head node is not registered yet.
                 if (
-                    head_node_not_registered
-                    and self._node_update_cnt * FREQUENTY_UPDATE_NODES_INTERVAL_SECONDS
-                    < FREQUENT_UPDATE_TIMEOUT_SECONDS
+                    not warning_shown
+                    and (time.time() - self._module_start_time)
+                    > RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT
                 ):
-                    await asyncio.sleep(FREQUENTY_UPDATE_NODES_INTERVAL_SECONDS)
-                else:
-                    if head_node_not_registered:
-                        logger.warning(
-                            "Head node is not registered even after "
-                            f"{FREQUENT_UPDATE_TIMEOUT_SECONDS} seconds. "
-                            "The API server might not work correctly. Please "
-                            "report a Github issue. Internal states :"
-                            f"{self.get_internal_states()}"
-                        )
-                    await asyncio.sleep(node_consts.UPDATE_NODES_INTERVAL_SECONDS)
+                    logger.warning(
+                        "Head node is not registered even after "
+                        f"{RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT} seconds. "
+                        "The API server might not work correctly. Please "
+                        "report a GitHub issue. Internal states: "
+                        f"{self.get_internal_states()}"
+                    )
+                    warning_shown = True
 
     @routes.get("/internal/node_module")
     async def get_node_module_internal_state(self, req) -> aiohttp.web.Response:
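The dead-node bookkeeping in _update_node above is a bounded FIFO: dead nodes stay in DataSource.nodes until more than MAX_DEAD_NODES_TO_CACHE of them have accumulated, after which the oldest is evicted. The same pattern in isolation (toy data, not dashboard code):

    from collections import deque

    MAX_DEAD_NODES_TO_CACHE = 3
    nodes = {}
    dead_node_queue = deque()

    def mark_dead(node_id):
        dead_node_queue.append(node_id)
        if len(dead_node_queue) > MAX_DEAD_NODES_TO_CACHE:
            # Evict the oldest dead node so the table stays bounded.
            nodes.pop(dead_node_queue.popleft(), None)

    for i in range(5):
        nodes[i] = {"state": "DEAD"}
        mark_dead(i)

    assert sorted(nodes) == [2, 3, 4]  # only the 3 newest dead nodes remain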
diff --git a/python/ray/dashboard/modules/node/tests/test_node.py b/python/ray/dashboard/modules/node/tests/test_node.py
index 18634d45ac15e..d3f11140e94f1 100644
--- a/python/ray/dashboard/modules/node/tests/test_node.py
+++ b/python/ray/dashboard/modules/node/tests/test_node.py
@@ -17,7 +17,6 @@
     wait_until_server_available,
 )
 from ray.cluster_utils import Cluster
-from ray.dashboard.modules.node.node_consts import UPDATE_NODES_INTERVAL_SECONDS
 from ray.dashboard.tests.conftest import *  # noqa
 
 logger = logging.getLogger(__name__)
@@ -228,31 +227,5 @@ def get_nodes():
         time.sleep(2)
 
 
-@pytest.mark.parametrize(
-    "ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
-)
-def test_frequent_node_update(
-    enable_test_module, disable_aiohttp_cache, ray_start_cluster_head
-):
-    cluster: Cluster = ray_start_cluster_head
-    assert wait_until_server_available(cluster.webui_url)
-    webui_url = cluster.webui_url
-    webui_url = format_web_url(webui_url)
-
-    def verify():
-        response = requests.get(webui_url + "/internal/node_module")
-        response.raise_for_status()
-        result = response.json()
-        data = result["data"]
-        head_node_registration_time = data["headNodeRegistrationTimeS"]
-        # If the head node is not registered, it is None.
-        assert head_node_registration_time is not None
-        # Head node should be registered before the node update interval
-        # because we do frequent until the head node is registered.
-        return head_node_registration_time < UPDATE_NODES_INTERVAL_SECONDS
-
-    wait_for_condition(verify, timeout=15)
-
-
 if __name__ == "__main__":
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/python/ray/dashboard/modules/reporter/reporter_head.py b/python/ray/dashboard/modules/reporter/reporter_head.py
index 3bb25e44be05e..9ca0894d3c791 100644
--- a/python/ray/dashboard/modules/reporter/reporter_head.py
+++ b/python/ray/dashboard/modules/reporter/reporter_head.py
@@ -67,7 +67,7 @@ async def _update_stubs(self, change):
         if change.old:
             node_id, port = change.old
             ip = DataSource.nodes[node_id]["nodeManagerAddress"]
-            self._stubs.pop(ip)
+            self._stubs.pop(ip, None)
         if change.new:
             node_id, ports = change.new
             ip = DataSource.nodes[node_id]["nodeManagerAddress"]
diff --git a/python/ray/dashboard/tests/conftest.py b/python/ray/dashboard/tests/conftest.py
index 29663e422e88c..511276761a037 100644
--- a/python/ray/dashboard/tests/conftest.py
+++ b/python/ray/dashboard/tests/conftest.py
@@ -3,6 +3,7 @@
 
 import pytest
 
+import ray.dashboard.modules  # noqa
 from ray.tests.conftest import *  # noqa
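On the two pop(..., None) changes (node_head.py and reporter_head.py): dict.pop without a default raises KeyError when the key is absent, which presumably can happen here when a removal event arrives for a stub that was never created or was already removed. A quick illustration:

    stubs = {"10.0.0.1": "stub"}
    stubs.pop("10.0.0.1", None)  # removes the entry
    stubs.pop("10.0.0.1", None)  # repeated removal: a no-op
    try:
        stubs.pop("10.0.0.1")  # without a default, a missing key raises
    except KeyError:
        pass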