From d235f8e6963fd14cbd2f77ba28c1296331ea0dec Mon Sep 17 00:00:00 2001 From: kaihsun Date: Sun, 24 Nov 2024 10:18:46 +0000 Subject: [PATCH] update Signed-off-by: kaihsun --- .../cloud_providers/kuberay/cloud_provider.py | 29 +++++++++++---- .../autoscaler/v2/tests/test_node_provider.py | 37 +++++++++++++++++++ 2 files changed, 59 insertions(+), 7 deletions(-) diff --git a/python/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/cloud_provider.py b/python/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/cloud_provider.py index 19e236cb4d19..500bcd6a7188 100644 --- a/python/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/cloud_provider.py +++ b/python/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/cloud_provider.py @@ -209,21 +209,22 @@ def _initialize_scale_request( cur_instances = self.instances # Get the worker groups that have pending deletes and the worker groups that - # have finished deletes. + # have finished deletes, and the set of workers included in the workersToDelete + # field of any worker group. ( worker_groups_with_pending_deletes, worker_groups_without_pending_deletes, + worker_to_delete_set, ) = self._get_workers_groups_with_deletes( ray_cluster, set(cur_instances.keys()) ) # Calculate the desired number of workers by type. num_workers_dict = defaultdict(int) - for _, cur_instance in cur_instances.items(): - if cur_instance.node_kind == NodeKind.HEAD: - # Only track workers. - continue - num_workers_dict[cur_instance.node_type] += 1 + worker_groups = ray_cluster["spec"].get("workerGroupSpecs", []) + for worker_group in worker_groups: + node_type = worker_group["groupName"] + num_workers_dict[node_type] = worker_group["replicas"] # Add to launch nodes. for node_type, count in to_launch.items(): @@ -242,6 +243,11 @@ def _initialize_scale_request( # Not possible to delete head node. continue + if to_delete_instance.cloud_instance_id in worker_to_delete_set: + # If the instance is already in the workersToDelete field of + # any worker group, skip it. + continue + num_workers_dict[to_delete_instance.node_type] -= 1 assert num_workers_dict[to_delete_instance.node_type] >= 0 to_delete_instances_by_type[to_delete_instance.node_type].append( @@ -321,6 +327,7 @@ def _submit_scale_request( # No patch required. return + logger.info(f"Submitting a scale request: {scale_request}") self._patch(f"rayclusters/{self._cluster_name}", patch_payload) def _add_launch_errors( @@ -404,10 +411,13 @@ def _get_workers_groups_with_deletes( deletes. worker_groups_with_finished_deletes: The worker groups that have finished deletes. + worker_to_delete_set: A set of Pods that are included in the workersToDelete + field of any worker group. """ worker_groups_with_pending_deletes = set() worker_groups_with_deletes = set() + worker_to_delete_set = set() worker_groups = ray_cluster_spec["spec"].get("workerGroupSpecs", []) for worker_group in worker_groups: @@ -422,6 +432,7 @@ def _get_workers_groups_with_deletes( worker_groups_with_deletes.add(node_type) for worker in workersToDelete: + worker_to_delete_set.add(worker) if worker in node_set: worker_groups_with_pending_deletes.add(node_type) break @@ -429,7 +440,11 @@ def _get_workers_groups_with_deletes( worker_groups_with_finished_deletes = ( worker_groups_with_deletes - worker_groups_with_pending_deletes ) - return worker_groups_with_pending_deletes, worker_groups_with_finished_deletes + return ( + worker_groups_with_pending_deletes, + worker_groups_with_finished_deletes, + worker_to_delete_set, + ) def _fetch_instances(self) -> Dict[CloudInstanceId, CloudInstance]: """ diff --git a/python/ray/autoscaler/v2/tests/test_node_provider.py b/python/ray/autoscaler/v2/tests/test_node_provider.py index 02d84e376b8d..a8e010f58f65 100644 --- a/python/ray/autoscaler/v2/tests/test_node_provider.py +++ b/python/ray/autoscaler/v2/tests/test_node_provider.py @@ -492,6 +492,43 @@ def test_pending_deletes(self): }, ] + def test_inconsistent_pods_raycr(self): + """ + Test the case where the cluster state has not yet reached the desired state. + Specifically, the replicas field in the RayCluster CR does not match the actual + number of Pods. + """ + # Check the assumptions of the test + small_group = "small-group" + num_pods = 0 + for pod in self.mock_client._pod_list["items"]: + if pod["metadata"]["labels"]["ray.io/group"] == small_group: + num_pods += 1 + + assert ( + self.mock_client._ray_cluster["spec"]["workerGroupSpecs"][0]["groupName"] + == small_group + ) + desired_replicas = num_pods + 1 + self.mock_client._ray_cluster["spec"]["workerGroupSpecs"][0][ + "replicas" + ] = desired_replicas + + # Launch a new node. The replicas field should be incremented by 1, even though + # the cluster state has not yet reached the goal state. + launch_request = {"small-group": 1} + self.provider.launch(shape=launch_request, request_id="launch-1") + + patches = self.mock_client.get_patches( + f"rayclusters/{self.provider._cluster_name}" + ) + assert len(patches) == 1 + assert patches[0] == { + "op": "replace", + "path": "/spec/workerGroupSpecs/0/replicas", + "value": desired_replicas + 1, + } + if __name__ == "__main__": if os.environ.get("PARALLEL_CI"):