Commit d235f8e

update
Signed-off-by: kaihsun <kaihsun@anyscale.com>
kevin85421 committed Nov 24, 2024
1 parent 4a29571 commit d235f8e
Showing 2 changed files with 59 additions and 7 deletions.
@@ -209,21 +209,22 @@ def _initialize_scale_request(
         cur_instances = self.instances
 
         # Get the worker groups that have pending deletes and the worker groups that
-        # have finished deletes.
+        # have finished deletes, and the set of workers included in the workersToDelete
+        # field of any worker group.
         (
             worker_groups_with_pending_deletes,
             worker_groups_without_pending_deletes,
+            worker_to_delete_set,
         ) = self._get_workers_groups_with_deletes(
             ray_cluster, set(cur_instances.keys())
         )
 
         # Calculate the desired number of workers by type.
         num_workers_dict = defaultdict(int)
-        for _, cur_instance in cur_instances.items():
-            if cur_instance.node_kind == NodeKind.HEAD:
-                # Only track workers.
-                continue
-            num_workers_dict[cur_instance.node_type] += 1
+        worker_groups = ray_cluster["spec"].get("workerGroupSpecs", [])
+        for worker_group in worker_groups:
+            node_type = worker_group["groupName"]
+            num_workers_dict[node_type] = worker_group["replicas"]
 
         # Add to launch nodes.
         for node_type, count in to_launch.items():
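
This hunk is the heart of the commit: the desired worker count per group is now read from the CR's replicas field rather than recomputed by counting live Pods, so a scale request stays additive even while the cluster is mid-reconciliation. A minimal sketch of the new bookkeeping, assuming a simplified ray_cluster dict (the surrounding provider class is elided):

    from collections import defaultdict

    # Hypothetical, simplified RayCluster CR for illustration only.
    ray_cluster = {
        "spec": {
            "workerGroupSpecs": [
                {"groupName": "small-group", "replicas": 2},
                {"groupName": "gpu-group", "replicas": 1},
            ]
        }
    }

    num_workers_dict = defaultdict(int)
    for worker_group in ray_cluster["spec"].get("workerGroupSpecs", []):
        # Trust the declared replicas, not the observed Pod count.
        num_workers_dict[worker_group["groupName"]] = worker_group["replicas"]

    assert dict(num_workers_dict) == {"small-group": 2, "gpu-group": 1}
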
@@ -242,6 +243,11 @@ def _initialize_scale_request(
                 # Not possible to delete head node.
                 continue
 
+            if to_delete_instance.cloud_instance_id in worker_to_delete_set:
+                # If the instance is already in the workersToDelete field of
+                # any worker group, skip it.
+                continue
+
             num_workers_dict[to_delete_instance.node_type] -= 1
             assert num_workers_dict[to_delete_instance.node_type] >= 0
             to_delete_instances_by_type[to_delete_instance.node_type].append(
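
The new guard prevents double-counting: an instance whose Pod is already listed in some group's workersToDelete has already had its replica decrement applied, so it is skipped here. A standalone sketch of the skip logic, with hypothetical dicts standing in for the provider's instance objects:

    from collections import defaultdict

    # Hypothetical stand-ins for the provider's to-delete instances.
    to_delete_instances = [
        {"cloud_instance_id": "pod-a", "node_type": "small-group"},
        {"cloud_instance_id": "pod-b", "node_type": "small-group"},
    ]
    worker_to_delete_set = {"pod-a"}  # already queued in workersToDelete
    num_workers_dict = {"small-group": 2}

    to_delete_instances_by_type = defaultdict(list)
    for inst in to_delete_instances:
        if inst["cloud_instance_id"] in worker_to_delete_set:
            # Already pending deletion in the CR; skip to avoid decrementing twice.
            continue
        num_workers_dict[inst["node_type"]] -= 1
        assert num_workers_dict[inst["node_type"]] >= 0
        to_delete_instances_by_type[inst["node_type"]].append(inst)

    # Only pod-b is newly scheduled for deletion; replicas drop from 2 to 1.
    assert num_workers_dict == {"small-group": 1}
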
@@ -321,6 +327,7 @@ def _submit_scale_request(
             # No patch required.
             return
 
+        logger.info(f"Submitting a scale request: {scale_request}")
         self._patch(f"rayclusters/{self._cluster_name}", patch_payload)
 
     def _add_launch_errors(
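
For context, the payload handed to _patch appears to be a JSON Patch document, judging from what the new test in the second file asserts; the real construction happens in the elided body of _submit_scale_request. A hedged sketch of such a payload:

    # Hypothetical scale decision: grow worker group 0 from 2 to 3 replicas.
    patch_payload = [
        {
            "op": "replace",
            "path": "/spec/workerGroupSpecs/0/replicas",
            "value": 3,
        }
    ]
    # self._patch(f"rayclusters/{self._cluster_name}", patch_payload)
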
@@ -404,10 +411,13 @@ def _get_workers_groups_with_deletes(
                 deletes.
             worker_groups_with_finished_deletes: The worker groups that have finished
                 deletes.
+            worker_to_delete_set: A set of Pods that are included in the workersToDelete
+                field of any worker group.
         """
 
         worker_groups_with_pending_deletes = set()
         worker_groups_with_deletes = set()
+        worker_to_delete_set = set()
 
         worker_groups = ray_cluster_spec["spec"].get("workerGroupSpecs", [])
         for worker_group in worker_groups:
@@ -422,14 +432,19 @@ def _get_workers_groups_with_deletes(
                 worker_groups_with_deletes.add(node_type)
 
             for worker in workersToDelete:
+                worker_to_delete_set.add(worker)
                 if worker in node_set:
                     worker_groups_with_pending_deletes.add(node_type)
                     break
 
         worker_groups_with_finished_deletes = (
             worker_groups_with_deletes - worker_groups_with_pending_deletes
         )
-        return worker_groups_with_pending_deletes, worker_groups_with_finished_deletes
+        return (
+            worker_groups_with_pending_deletes,
+            worker_groups_with_finished_deletes,
+            worker_to_delete_set,
+        )
 
     def _fetch_instances(self) -> Dict[CloudInstanceId, CloudInstance]:
         """
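
Taken together, the helper now reports three things: groups with deletes still pending, groups whose deletes have finished, and every worker named in any workersToDelete list. A self-contained sketch of these semantics with a fake CR and node set (assuming workersToDelete sits under the group's scaleStrategy field, as in the KubeRay CRD; the early break mirrors the code above, so workers after the first pending one in a group are not recorded):

    ray_cluster = {
        "spec": {
            "workerGroupSpecs": [
                {
                    "groupName": "small-group",
                    "scaleStrategy": {"workersToDelete": ["pod-a", "pod-b"]},
                },
                {
                    "groupName": "gpu-group",
                    "scaleStrategy": {"workersToDelete": ["pod-c"]},
                },
            ]
        }
    }
    node_set = {"pod-a", "pod-head"}  # pod-b and pod-c no longer exist

    pending, with_deletes, worker_to_delete_set = set(), set(), set()
    for group in ray_cluster["spec"].get("workerGroupSpecs", []):
        node_type = group["groupName"]
        workers_to_delete = group.get("scaleStrategy", {}).get("workersToDelete", [])
        if workers_to_delete:
            with_deletes.add(node_type)
        for worker in workers_to_delete:
            worker_to_delete_set.add(worker)
            if worker in node_set:
                pending.add(node_type)  # Pod still present: delete is pending
                break

    finished = with_deletes - pending
    assert pending == {"small-group"}
    assert finished == {"gpu-group"}
    assert {"pod-a", "pod-c"} <= worker_to_delete_set
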
python/ray/autoscaler/v2/tests/test_node_provider.py (37 additions, 0 deletions)
@@ -492,6 +492,43 @@ def test_pending_deletes(self):
             },
         ]
 
+    def test_inconsistent_pods_raycr(self):
+        """
+        Test the case where the cluster state has not yet reached the desired state.
+        Specifically, the replicas field in the RayCluster CR does not match the actual
+        number of Pods.
+        """
+        # Check the assumptions of the test.
+        small_group = "small-group"
+        num_pods = 0
+        for pod in self.mock_client._pod_list["items"]:
+            if pod["metadata"]["labels"]["ray.io/group"] == small_group:
+                num_pods += 1
+
+        assert (
+            self.mock_client._ray_cluster["spec"]["workerGroupSpecs"][0]["groupName"]
+            == small_group
+        )
+        desired_replicas = num_pods + 1
+        self.mock_client._ray_cluster["spec"]["workerGroupSpecs"][0][
+            "replicas"
+        ] = desired_replicas
+
+        # Launch a new node. The replicas field should be incremented by 1, even though
+        # the cluster state has not yet reached the goal state.
+        launch_request = {"small-group": 1}
+        self.provider.launch(shape=launch_request, request_id="launch-1")
+
+        patches = self.mock_client.get_patches(
+            f"rayclusters/{self.provider._cluster_name}"
+        )
+        assert len(patches) == 1
+        assert patches[0] == {
+            "op": "replace",
+            "path": "/spec/workerGroupSpecs/0/replicas",
+            "value": desired_replicas + 1,
+        }
+
 
 if __name__ == "__main__":
     if os.environ.get("PARALLEL_CI"):
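
The expected value follows directly from the CR-based accounting introduced above: launching one worker adds one to the declared replicas, not to the observed Pod count. A minimal sketch of the arithmetic under the test's assumptions (the fixture's actual Pod count may differ):

    num_pods = 1                      # hypothetical live Pods in small-group
    desired_replicas = num_pods + 1   # CR deliberately ahead of the Pods

    # The old Pod-counting logic would have patched replicas back to
    # num_pods + 1 == 2, silently undoing the pending scale-up. Reading
    # replicas from the CR instead yields:
    expected_value = desired_replicas + 1
    assert expected_value == 3
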
