From a967a3550399929098361bb397feb22a414ec628 Mon Sep 17 00:00:00 2001
From: Ruiyang Wang <56065503+rynewang@users.noreply.github.com>
Date: Fri, 30 Aug 2024 11:09:38 -0700
Subject: [PATCH] [core][dashboard] TPE for state_aggregator.py (#47392)

Uses a ThreadPoolExecutor (TPE) to handle all proto->JSON conversions.
Semantically unchanged. Mostly this just indents the existing code into
a nested function and runs that function in the TPE. A minimal sketch
of the pattern follows the diff.

Signed-off-by: Ruiyang Wang
---
 python/ray/dashboard/state_aggregator.py | 556 ++++++++++++-----------
 1 file changed, 303 insertions(+), 253 deletions(-)

diff --git a/python/ray/dashboard/state_aggregator.py b/python/ray/dashboard/state_aggregator.py
index 3027cf7d7f9d..c4ba56c954c2 100644
--- a/python/ray/dashboard/state_aggregator.py
+++ b/python/ray/dashboard/state_aggregator.py
@@ -241,32 +241,37 @@ async def list_actors(self, *, option: ListApiOptions) -> ListApiResponse:
         except DataSourceUnavailable:
             raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
 
-        result = []
-        for message in reply.actor_table_data:
-            data = protobuf_message_to_dict(
-                message=message,
-                fields_to_decode=[
-                    "actor_id",
-                    "owner_id",
-                    "job_id",
-                    "node_id",
-                    "placement_group_id",
-                ],
+        def transform(reply) -> ListApiResponse:
+            result = []
+            for message in reply.actor_table_data:
+                data = protobuf_message_to_dict(
+                    message=message,
+                    fields_to_decode=[
+                        "actor_id",
+                        "owner_id",
+                        "job_id",
+                        "node_id",
+                        "placement_group_id",
+                    ],
+                )
+                result.append(data)
+
+            num_after_truncation = len(result) + reply.num_filtered
+            result = self._filter(result, option.filters, ActorState, option.detail)
+            num_filtered = len(result)
+
+            # Sort to make the output deterministic.
+            result.sort(key=lambda entry: entry["actor_id"])
+            result = list(islice(result, option.limit))
+            return ListApiResponse(
+                result=result,
+                total=reply.total,
+                num_after_truncation=num_after_truncation,
+                num_filtered=num_filtered,
             )
-            result.append(data)
-
-        num_after_truncation = len(result) + reply.num_filtered
-        result = self._filter(result, option.filters, ActorState, option.detail)
-        num_filtered = len(result)
-
-        # Sort to make the output deterministic.
-        result.sort(key=lambda entry: entry["actor_id"])
-        result = list(islice(result, option.limit))
-        return ListApiResponse(
-            result=result,
-            total=reply.total,
-            num_after_truncation=num_after_truncation,
-            num_filtered=num_filtered,
+
+        return await get_or_create_event_loop().run_in_executor(
+            self._thread_pool_executor, transform, reply
         )
 
     async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiResponse:
@@ -283,26 +288,35 @@ async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiRespo
         except DataSourceUnavailable:
             raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
 
-        result = []
-        for message in reply.placement_group_table_data:
-            data = protobuf_message_to_dict(
-                message=message,
-                fields_to_decode=["placement_group_id", "creator_job_id", "node_id"],
+        def transform(reply) -> ListApiResponse:
+            result = []
+            for message in reply.placement_group_table_data:
+                data = protobuf_message_to_dict(
+                    message=message,
+                    fields_to_decode=[
+                        "placement_group_id",
+                        "creator_job_id",
+                        "node_id",
+                    ],
+                )
+                result.append(data)
+            num_after_truncation = len(result)
+
+            result = self._filter(
+                result, option.filters, PlacementGroupState, option.detail
+            )
+            num_filtered = len(result)
+            # Sort to make the output deterministic.
+            result.sort(key=lambda entry: entry["placement_group_id"])
+            return ListApiResponse(
+                result=list(islice(result, option.limit)),
+                total=reply.total,
+                num_after_truncation=num_after_truncation,
+                num_filtered=num_filtered,
             )
-            result.append(data)
-        num_after_truncation = len(result)
-        result = self._filter(
-            result, option.filters, PlacementGroupState, option.detail
-        )
-        num_filtered = len(result)
-        # Sort to make the output deterministic.
-        result.sort(key=lambda entry: entry["placement_group_id"])
-        return ListApiResponse(
-            result=list(islice(result, option.limit)),
-            total=reply.total,
-            num_after_truncation=num_after_truncation,
-            num_filtered=num_filtered,
+
+        return await get_or_create_event_loop().run_in_executor(
+            self._thread_pool_executor, transform, reply
         )
 
     async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse:
@@ -319,33 +333,39 @@ async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse:
         except DataSourceUnavailable:
             raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
 
-        result = []
-        for message in reply.node_info_list:
-            data = protobuf_message_to_dict(
-                message=message, fields_to_decode=["node_id"]
-            )
-            data["node_ip"] = data["node_manager_address"]
-            data["start_time_ms"] = int(data["start_time_ms"])
-            data["end_time_ms"] = int(data["end_time_ms"])
-            death_info = data.get("death_info", {})
-            data["state_message"] = compose_state_message(
-                death_info.get("reason", None), death_info.get("reason_message", None)
-            )
+        def transform(reply) -> ListApiResponse:
+            result = []
+            for message in reply.node_info_list:
+                data = protobuf_message_to_dict(
+                    message=message, fields_to_decode=["node_id"]
+                )
+                data["node_ip"] = data["node_manager_address"]
+                data["start_time_ms"] = int(data["start_time_ms"])
+                data["end_time_ms"] = int(data["end_time_ms"])
+                death_info = data.get("death_info", {})
+                data["state_message"] = compose_state_message(
+                    death_info.get("reason", None),
+                    death_info.get("reason_message", None),
+                )
 
-            result.append(data)
+                result.append(data)
 
-        num_after_truncation = len(result) + reply.num_filtered
-        result = self._filter(result, option.filters, NodeState, option.detail)
-        num_filtered = len(result)
+            num_after_truncation = len(result) + reply.num_filtered
+            result = self._filter(result, option.filters, NodeState, option.detail)
+            num_filtered = len(result)
 
-        # Sort to make the output deterministic.
-        result.sort(key=lambda entry: entry["node_id"])
-        result = list(islice(result, option.limit))
-        return ListApiResponse(
-            result=result,
-            total=reply.total,
-            num_after_truncation=num_after_truncation,
-            num_filtered=num_filtered,
+            # Sort to make the output deterministic.
+            result.sort(key=lambda entry: entry["node_id"])
+            result = list(islice(result, option.limit))
+            return ListApiResponse(
+                result=result,
+                total=reply.total,
+                num_after_truncation=num_after_truncation,
+                num_filtered=num_filtered,
+            )
+
+        return await get_or_create_event_loop().run_in_executor(
+            self._thread_pool_executor, transform, reply
         )
 
     async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse:
@@ -363,49 +383,61 @@ async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse:
         except DataSourceUnavailable:
             raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
 
-        result = []
-        for message in reply.worker_table_data:
-            data = protobuf_message_to_dict(
-                message=message, fields_to_decode=["worker_id", "raylet_id"]
+        def transform(reply) -> ListApiResponse:
+
+            result = []
+            for message in reply.worker_table_data:
+                data = protobuf_message_to_dict(
+                    message=message, fields_to_decode=["worker_id", "raylet_id"]
+                )
+                data["worker_id"] = data["worker_address"]["worker_id"]
+                data["node_id"] = data["worker_address"]["raylet_id"]
+                data["ip"] = data["worker_address"]["ip_address"]
+                data["start_time_ms"] = int(data["start_time_ms"])
+                data["end_time_ms"] = int(data["end_time_ms"])
+                data["worker_launch_time_ms"] = int(data["worker_launch_time_ms"])
+                data["worker_launched_time_ms"] = int(data["worker_launched_time_ms"])
+                result.append(data)
+
+            num_after_truncation = len(result) + reply.num_filtered
+            result = self._filter(result, option.filters, WorkerState, option.detail)
+            num_filtered = len(result)
+            # Sort to make the output deterministic.
+            result.sort(key=lambda entry: entry["worker_id"])
+            result = list(islice(result, option.limit))
+            return ListApiResponse(
+                result=result,
+                total=reply.total,
+                num_after_truncation=num_after_truncation,
+                num_filtered=num_filtered,
             )
-            data["worker_id"] = data["worker_address"]["worker_id"]
-            data["node_id"] = data["worker_address"]["raylet_id"]
-            data["ip"] = data["worker_address"]["ip_address"]
-            data["start_time_ms"] = int(data["start_time_ms"])
-            data["end_time_ms"] = int(data["end_time_ms"])
-            data["worker_launch_time_ms"] = int(data["worker_launch_time_ms"])
-            data["worker_launched_time_ms"] = int(data["worker_launched_time_ms"])
-            result.append(data)
-
-        num_after_truncation = len(result) + reply.num_filtered
-        result = self._filter(result, option.filters, WorkerState, option.detail)
-        num_filtered = len(result)
-        # Sort to make the output deterministic.
-        result.sort(key=lambda entry: entry["worker_id"])
-        result = list(islice(result, option.limit))
-        return ListApiResponse(
-            result=result,
-            total=reply.total,
-            num_after_truncation=num_after_truncation,
-            num_filtered=num_filtered,
+
+        return await get_or_create_event_loop().run_in_executor(
+            self._thread_pool_executor, transform, reply
         )
 
     async def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse:
         try:
-            result = await self._client.get_job_info(timeout=option.timeout)
-            result = [job.dict() for job in result]
+            reply = await self._client.get_job_info(timeout=option.timeout)
+        except DataSourceUnavailable:
+            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
+
+        def transform(reply) -> ListApiResponse:
+            result = [job.dict() for job in reply]
             total = len(result)
             result = self._filter(result, option.filters, JobState, option.detail)
             num_filtered = len(result)
             result.sort(key=lambda entry: entry["job_id"] or "")
             result = list(islice(result, option.limit))
-        except DataSourceUnavailable:
-            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
-        return ListApiResponse(
-            result=result,
-            total=total,
-            num_after_truncation=total,
-            num_filtered=num_filtered,
+            return ListApiResponse(
+                result=result,
+                total=total,
+                num_after_truncation=total,
+                num_filtered=num_filtered,
+            )
+
+        return await get_or_create_event_loop().run_in_executor(
+            self._thread_pool_executor, transform, reply
         )
 
     async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse:
@@ -424,12 +456,10 @@ async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse:
         except DataSourceUnavailable:
             raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
 
-        def transform(reply):
+        def transform(reply) -> ListApiResponse:
             """
             Transforms from proto to dict, applies filters, sorts, and truncates.
             This function is executed in a separate thread.
-
-            Returns the ListApiResponse.
             """
             result = [
                 protobuf_to_task_state_dict(message) for message in reply.events_by_task
@@ -474,85 +504,90 @@ async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse:
             return_exceptions=True,
         )
 
-        unresponsive_nodes = 0
-        worker_stats = []
-        total_objects = 0
-        for reply, _ in zip(replies, raylet_ids):
-            if isinstance(reply, DataSourceUnavailable):
-                unresponsive_nodes += 1
-                continue
-            elif isinstance(reply, Exception):
-                raise reply
-
-            total_objects += reply.total
-            for core_worker_stat in reply.core_workers_stats:
-                # NOTE: Set preserving_proto_field_name=False here because
-                # `construct_memory_table` requires a dictionary that has
-                # modified protobuf name
-                # (e.g., workerId instead of worker_id) as a key.
-                worker_stats.append(
-                    protobuf_message_to_dict(
-                        message=core_worker_stat,
-                        fields_to_decode=["object_id"],
-                        preserving_proto_field_name=False,
+        def transform(replies) -> ListApiResponse:
+            unresponsive_nodes = 0
+            worker_stats = []
+            total_objects = 0
+            for reply, _ in zip(replies, raylet_ids):
+                if isinstance(reply, DataSourceUnavailable):
+                    unresponsive_nodes += 1
+                    continue
+                elif isinstance(reply, Exception):
+                    raise reply
+
+                total_objects += reply.total
+                for core_worker_stat in reply.core_workers_stats:
+                    # NOTE: Set preserving_proto_field_name=False here because
+                    # `construct_memory_table` requires a dictionary that has
+                    # modified protobuf name
+                    # (e.g., workerId instead of worker_id) as a key.
+                    worker_stats.append(
+                        protobuf_message_to_dict(
+                            message=core_worker_stat,
+                            fields_to_decode=["object_id"],
+                            preserving_proto_field_name=False,
+                        )
                     )
+
+            partial_failure_warning = None
+            if len(raylet_ids) > 0 and unresponsive_nodes > 0:
+                warning_msg = NODE_QUERY_FAILURE_WARNING.format(
+                    type="raylet",
+                    total=len(raylet_ids),
+                    network_failures=unresponsive_nodes,
+                    log_command="raylet.out",
+                )
+                if unresponsive_nodes == len(raylet_ids):
+                    raise DataSourceUnavailable(warning_msg)
+                partial_failure_warning = (
+                    f"The returned data may contain incomplete result. {warning_msg}"
                 )
 
-        partial_failure_warning = None
-        if len(raylet_ids) > 0 and unresponsive_nodes > 0:
-            warning_msg = NODE_QUERY_FAILURE_WARNING.format(
-                type="raylet",
-                total=len(raylet_ids),
-                network_failures=unresponsive_nodes,
-                log_command="raylet.out",
-            )
-            if unresponsive_nodes == len(raylet_ids):
-                raise DataSourceUnavailable(warning_msg)
-            partial_failure_warning = (
-                f"The returned data may contain incomplete result. {warning_msg}"
-            )
+            result = []
+            memory_table = memory_utils.construct_memory_table(worker_stats)
+            for entry in memory_table.table:
+                data = entry.as_dict()
+                # `construct_memory_table` returns object_ref field which is indeed
+                # object_id. We do transformation here.
+                # TODO(sang): Refactor `construct_memory_table`.
+                data["object_id"] = data["object_ref"]
+                del data["object_ref"]
+                data["ip"] = data["node_ip_address"]
+                del data["node_ip_address"]
+                data["type"] = data["type"].upper()
+                data["task_status"] = (
+                    "NIL" if data["task_status"] == "-" else data["task_status"]
+                )
+                result.append(data)
 
-        result = []
-        memory_table = memory_utils.construct_memory_table(worker_stats)
-        for entry in memory_table.table:
-            data = entry.as_dict()
-            # `construct_memory_table` returns object_ref field which is indeed
-            # object_id. We do transformation here.
-            # TODO(sang): Refactor `construct_memory_table`.
-            data["object_id"] = data["object_ref"]
-            del data["object_ref"]
-            data["ip"] = data["node_ip_address"]
-            del data["node_ip_address"]
-            data["type"] = data["type"].upper()
-            data["task_status"] = (
-                "NIL" if data["task_status"] == "-" else data["task_status"]
-            )
-            result.append(data)
-
-        # Add callsite warnings if it is not configured.
-        callsite_warning = []
-        callsite_enabled = env_integer("RAY_record_ref_creation_sites", 0)
-        if not callsite_enabled:
-            callsite_warning.append(
-                "Callsite is not being recorded. "
-                "To record callsite information for each ObjectRef created, set "
-                "env variable RAY_record_ref_creation_sites=1 during `ray start` "
-                "and `ray.init`."
+            # Add callsite warnings if it is not configured.
+            callsite_warning = []
+            callsite_enabled = env_integer("RAY_record_ref_creation_sites", 0)
+            if not callsite_enabled:
+                callsite_warning.append(
+                    "Callsite is not being recorded. "
+                    "To record callsite information for each ObjectRef created, set "
+                    "env variable RAY_record_ref_creation_sites=1 during `ray start` "
+                    "and `ray.init`."
+                )
+
+            num_after_truncation = len(result)
+            result = self._filter(result, option.filters, ObjectState, option.detail)
+            num_filtered = len(result)
+            # Sort to make the output deterministic.
+            result.sort(key=lambda entry: entry["object_id"])
+            result = list(islice(result, option.limit))
+            return ListApiResponse(
+                result=result,
+                partial_failure_warning=partial_failure_warning,
+                total=total_objects,
+                num_after_truncation=num_after_truncation,
+                num_filtered=num_filtered,
+                warnings=callsite_warning,
             )
 
-        num_after_truncation = len(result)
-        result = self._filter(result, option.filters, ObjectState, option.detail)
-        num_filtered = len(result)
-        # Sort to make the output deterministic.
-        result.sort(key=lambda entry: entry["object_id"])
-        result = list(islice(result, option.limit))
-        return ListApiResponse(
-            result=result,
-            partial_failure_warning=partial_failure_warning,
-            total=total_objects,
-            num_after_truncation=num_after_truncation,
-            num_filtered=num_filtered,
-            warnings=callsite_warning,
+        return await get_or_create_event_loop().run_in_executor(
+            self._thread_pool_executor, transform, replies
        )
 
     async def list_runtime_envs(self, *, option: ListApiOptions) -> ListApiResponse:
@@ -574,66 +609,73 @@ async def list_runtime_envs(self, *, option: ListApiOptions) -> ListApiResponse
             return_exceptions=True,
         )
 
-        result = []
-        unresponsive_nodes = 0
-        total_runtime_envs = 0
-        for node_id, reply in zip(
-            self._client.get_all_registered_runtime_env_agent_ids(), replies
-        ):
-            if isinstance(reply, DataSourceUnavailable):
-                unresponsive_nodes += 1
-                continue
-            elif isinstance(reply, Exception):
-                raise reply
-
-            total_runtime_envs += reply.total
-            states = reply.runtime_env_states
-            for state in states:
-                data = protobuf_message_to_dict(message=state, fields_to_decode=[])
-                # Need to deserialize this field.
-                data["runtime_env"] = RuntimeEnv.deserialize(
-                    data["runtime_env"]
-                ).to_dict()
-                data["node_id"] = node_id
-                result.append(data)
-
-        partial_failure_warning = None
-        if len(agent_ids) > 0 and unresponsive_nodes > 0:
-            warning_msg = NODE_QUERY_FAILURE_WARNING.format(
-                type="agent",
-                total=len(agent_ids),
-                network_failures=unresponsive_nodes,
-                log_command="dashboard_agent.log",
+        def transform(replies) -> ListApiResponse:
+            result = []
+            unresponsive_nodes = 0
+            total_runtime_envs = 0
+            for node_id, reply in zip(
+                self._client.get_all_registered_runtime_env_agent_ids(), replies
+            ):
+                if isinstance(reply, DataSourceUnavailable):
+                    unresponsive_nodes += 1
+                    continue
+                elif isinstance(reply, Exception):
+                    raise reply
+
+                total_runtime_envs += reply.total
+                states = reply.runtime_env_states
+                for state in states:
+                    data = protobuf_message_to_dict(message=state, fields_to_decode=[])
+                    # Need to deserialize this field.
+                    data["runtime_env"] = RuntimeEnv.deserialize(
+                        data["runtime_env"]
+                    ).to_dict()
+                    data["node_id"] = node_id
+                    result.append(data)
+
+            partial_failure_warning = None
+            if len(agent_ids) > 0 and unresponsive_nodes > 0:
+                warning_msg = NODE_QUERY_FAILURE_WARNING.format(
+                    type="agent",
+                    total=len(agent_ids),
+                    network_failures=unresponsive_nodes,
+                    log_command="dashboard_agent.log",
+                )
+                if unresponsive_nodes == len(agent_ids):
+                    raise DataSourceUnavailable(warning_msg)
+                partial_failure_warning = (
+                    f"The returned data may contain incomplete result. {warning_msg}"
+                )
+            num_after_truncation = len(result)
+            result = self._filter(
+                result, option.filters, RuntimeEnvState, option.detail
             )
-            if unresponsive_nodes == len(agent_ids):
-                raise DataSourceUnavailable(warning_msg)
-            partial_failure_warning = (
-                f"The returned data may contain incomplete result. {warning_msg}"
+            num_filtered = len(result)
+
+            # Sort to make the output deterministic.
+            def sort_func(entry):
+                # If the creation time is not there yet (the runtime env failed or
+                # has not been created yet), it gets the highest priority.
+                # Otherwise, larger (more recent) creation times come first.
+                if "creation_time_ms" not in entry:
+                    return float("inf")
+                elif entry["creation_time_ms"] is None:
+                    return float("inf")
+                else:
+                    return float(entry["creation_time_ms"])
+
+            result.sort(key=sort_func, reverse=True)
+            result = list(islice(result, option.limit))
+            return ListApiResponse(
+                result=result,
+                partial_failure_warning=partial_failure_warning,
+                total=total_runtime_envs,
+                num_after_truncation=num_after_truncation,
+                num_filtered=num_filtered,
             )
-        num_after_truncation = len(result)
-        result = self._filter(result, option.filters, RuntimeEnvState, option.detail)
-        num_filtered = len(result)
-
-        # Sort to make the output deterministic.
-        def sort_func(entry):
-            # If creation time is not there yet (runtime env is failed
-            # to be created or not created yet, they are the highest priority.
-            # Otherwise, "bigger" creation time is coming first.
-            if "creation_time_ms" not in entry:
-                return float("inf")
-            elif entry["creation_time_ms"] is None:
-                return float("inf")
-            else:
-                return float(entry["creation_time_ms"])
-
-        result.sort(key=sort_func, reverse=True)
-        result = list(islice(result, option.limit))
-        return ListApiResponse(
-            result=result,
-            partial_failure_warning=partial_failure_warning,
-            total=total_runtime_envs,
-            num_after_truncation=num_after_truncation,
-            num_filtered=num_filtered,
+
+        return await get_or_create_event_loop().run_in_executor(
+            self._thread_pool_executor, transform, replies
         )
 
     async def list_cluster_events(self, *, option: ListApiOptions) -> ListApiResponse:
@@ -644,25 +686,33 @@ async def list_cluster_events(self, *, option: ListApiOptions) -> ListApiRespons
         The schema of returned "dict" is equivalent to the
         `ClusterEventState` protobuf message.
         """
-        result = []
-        all_events = await self._client.get_all_cluster_events()
-        for _, events in all_events.items():
-            for _, event in events.items():
-                event["time"] = str(datetime.fromtimestamp(int(event["timestamp"])))
-                result.append(event)
-
-        num_after_truncation = len(result)
-        result.sort(key=lambda entry: entry["timestamp"])
-        total = len(result)
-        result = self._filter(result, option.filters, ClusterEventState, option.detail)
-        num_filtered = len(result)
-        # Sort to make the output deterministic.
-        result = list(islice(result, option.limit))
-        return ListApiResponse(
-            result=result,
-            total=total,
-            num_after_truncation=num_after_truncation,
-            num_filtered=num_filtered,
+        reply = await self._client.get_all_cluster_events()
+
+        def transform(reply) -> ListApiResponse:
+            result = []
+            for _, events in reply.items():
+                for _, event in events.items():
+                    event["time"] = str(datetime.fromtimestamp(int(event["timestamp"])))
+                    result.append(event)
+
+            num_after_truncation = len(result)
+            result.sort(key=lambda entry: entry["timestamp"])
+            total = len(result)
+            result = self._filter(
+                result, option.filters, ClusterEventState, option.detail
+            )
+            num_filtered = len(result)
+            # Sort to make the output deterministic.
+            result = list(islice(result, option.limit))
+            return ListApiResponse(
+                result=result,
+                total=total,
+                num_after_truncation=num_after_truncation,
+                num_filtered=num_filtered,
+            )
+
+        return await get_or_create_event_loop().run_in_executor(
+            self._thread_pool_executor, transform, reply
        )
 
     async def summarize_tasks(self, option: SummaryApiOptions) -> SummaryApiResponse:
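
Sketch of the pattern (editor-supplied, not part of the patch): every
endpoint above now awaits its RPC reply on the event loop and hands the
CPU-bound proto->dict conversion to a ThreadPoolExecutor via
run_in_executor, so the loop stays responsive while conversion runs.
This is a minimal, runnable illustration only; FakeReply,
StateAggregatorSketch, and main are hypothetical stand-ins, and
asyncio.get_running_loop() stands in for Ray's get_or_create_event_loop().

import asyncio
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class FakeReply:
    """Hypothetical stand-in for a GCS protobuf reply."""

    actor_table_data: List[Dict[str, str]] = field(default_factory=list)
    total: int = 0


class StateAggregatorSketch:
    """Illustrates the offload pattern; not the real manager class."""

    def __init__(self) -> None:
        # Mirrors the self._thread_pool_executor used throughout the diff.
        self._thread_pool_executor = ThreadPoolExecutor(
            max_workers=4, thread_name_prefix="state_api"
        )

    async def list_actors(self, reply: FakeReply) -> List[Dict[str, str]]:
        def transform(reply: FakeReply) -> List[Dict[str, str]]:
            # CPU-bound conversion runs in a worker thread, so the event
            # loop is free to serve other requests in the meantime.
            result = [dict(message) for message in reply.actor_table_data]
            # Sort to make the output deterministic, as the diff does.
            result.sort(key=lambda entry: entry["actor_id"])
            return result

        # Stand-in for get_or_create_event_loop() in the real code.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self._thread_pool_executor, transform, reply
        )


async def main() -> None:
    reply = FakeReply(
        actor_table_data=[{"actor_id": "b"}, {"actor_id": "a"}], total=2
    )
    print(await StateAggregatorSketch().list_actors(reply))


if __name__ == "__main__":
    asyncio.run(main())

Running main() prints the two entries sorted by actor_id. The design
point the patch relies on is that transform is a pure function of the
reply it receives, so moving it off the event loop changes scheduling
but not semantics.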