diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index 91f8cd8302c14..99eed6ff33dd2 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -209,6 +209,14 @@ py_test( deps = [":serve_lib"], ) +py_test( + name = "test_deployment_scheduler", + size = "small", + srcs = serve_tests_srcs, + tags = ["exclusive", "team:serve"], + deps = [":serve_lib"], +) + py_test( name = "test_deployment_version", size = "small", diff --git a/python/ray/serve/_private/deployment_scheduler.py b/python/ray/serve/_private/deployment_scheduler.py new file mode 100644 index 0000000000000..644c81f90de19 --- /dev/null +++ b/python/ray/serve/_private/deployment_scheduler.py @@ -0,0 +1,275 @@ +from typing import Callable, Dict, Tuple, List, Union, Set +from dataclasses import dataclass +from collections import defaultdict + +import ray +from ray._raylet import GcsClient +from ray.serve._private.utils import get_all_node_ids +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + + +class SpreadDeploymentSchedulingPolicy: + """A scheduling policy that spreads replicas with best effort.""" + + pass + + +class DriverDeploymentSchedulingPolicy: + """A scheduling policy that schedules exactly one replica on each node.""" + + pass + + +@dataclass +class ReplicaSchedulingRequest: + """Request to schedule a single replica. + + The scheduler is responsible for scheduling + based on the deployment scheduling policy. + """ + + deployment_name: str + replica_name: str + actor_def: ray.actor.ActorClass + actor_resources: Dict + actor_options: Dict + actor_init_args: Tuple + on_scheduled: Callable + + +@dataclass +class DeploymentDownscaleRequest: + """Request to stop a certain number of replicas. + + The scheduler is responsible for + choosing the replicas to stop. + """ + + deployment_name: str + num_to_stop: int + + +class DeploymentScheduler: + """A centralized scheduler for all Serve deployments. + + It makes a batch of scheduling decisions in each update cycle. + """ + + def __init__(self): + # {deployment_name: scheduling_policy} + self._deployments = {} + # Replicas that are waiting to be scheduled. + # {deployment_name: {replica_name: deployment_upscale_request}} + self._pending_replicas = defaultdict(dict) + # Replicas that are being scheduled. + # The underlying actors have been submitted. + # {deployment_name: {replica_name: target_node_id}} + self._launching_replicas = defaultdict(dict) + # Replicas that are recovering. + # We don't know where those replicas are running. + # {deployment_name: {replica_name}} + self._recovering_replicas = defaultdict(set) + # Replicas that are running. + # We know where those replicas are running. 
+ # {deployment_name: {replica_name: running_node_id}} + self._running_replicas = defaultdict(dict) + + self._gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address) + + def on_deployment_created( + self, + deployment_name: str, + scheduling_policy: Union[ + SpreadDeploymentSchedulingPolicy, DriverDeploymentSchedulingPolicy + ], + ) -> None: + """Called whenever a new deployment is created.""" + assert deployment_name not in self._pending_replicas + assert deployment_name not in self._launching_replicas + assert deployment_name not in self._recovering_replicas + assert deployment_name not in self._running_replicas + self._deployments[deployment_name] = scheduling_policy + + def on_deployment_deleted(self, deployment_name: str) -> None: + """Called whenever a deployment is deleted.""" + assert not self._pending_replicas[deployment_name] + self._pending_replicas.pop(deployment_name, None) + + assert not self._launching_replicas[deployment_name] + self._launching_replicas.pop(deployment_name, None) + + assert not self._recovering_replicas[deployment_name] + self._recovering_replicas.pop(deployment_name, None) + + assert not self._running_replicas[deployment_name] + self._running_replicas.pop(deployment_name, None) + + del self._deployments[deployment_name] + + def on_replica_stopping(self, deployment_name: str, replica_name: str) -> None: + """Called whenever a deployment replica is being stopped.""" + self._pending_replicas[deployment_name].pop(replica_name, None) + self._launching_replicas[deployment_name].pop(replica_name, None) + self._recovering_replicas[deployment_name].discard(replica_name) + self._running_replicas[deployment_name].pop(replica_name, None) + + def on_replica_running( + self, deployment_name: str, replica_name: str, node_id: str + ) -> None: + """Called whenever a deployment replica is running with a known node id.""" + assert replica_name not in self._pending_replicas[deployment_name] + + self._launching_replicas[deployment_name].pop(replica_name, None) + self._recovering_replicas[deployment_name].discard(replica_name) + + self._running_replicas[deployment_name][replica_name] = node_id + + def on_replica_recovering(self, deployment_name: str, replica_name: str) -> None: + """Called whenever a deployment replica is recovering.""" + assert replica_name not in self._pending_replicas[deployment_name] + assert replica_name not in self._launching_replicas[deployment_name] + assert replica_name not in self._running_replicas[deployment_name] + assert replica_name not in self._recovering_replicas[deployment_name] + + self._recovering_replicas[deployment_name].add(replica_name) + + def schedule( + self, + upscales: Dict[str, List[ReplicaSchedulingRequest]], + downscales: Dict[str, DeploymentDownscaleRequest], + ) -> Dict[str, Set[str]]: + """Called for each update cycle to do batch scheduling. + + Args: + upscales: a dict of deployment name to a list of replicas to schedule. + downscales: a dict of deployment name to a downscale request. + + Returns: + The name of replicas to stop for each deployment. 
+ """ + for upscale in upscales.values(): + for replica_scheduling_request in upscale: + self._pending_replicas[replica_scheduling_request.deployment_name][ + replica_scheduling_request.replica_name + ] = replica_scheduling_request + + for deployment_name, pending_replicas in self._pending_replicas.items(): + if not pending_replicas: + continue + + deployment_scheduling_policy = self._deployments[deployment_name] + if isinstance( + deployment_scheduling_policy, SpreadDeploymentSchedulingPolicy + ): + self._schedule_spread_deployment(deployment_name) + else: + assert isinstance( + deployment_scheduling_policy, DriverDeploymentSchedulingPolicy + ) + self._schedule_driver_deployment(deployment_name) + + deployment_to_replicas_to_stop = {} + for downscale in downscales.values(): + deployment_to_replicas_to_stop[ + downscale.deployment_name + ] = self._get_replicas_to_stop( + downscale.deployment_name, downscale.num_to_stop + ) + + return deployment_to_replicas_to_stop + + def _schedule_spread_deployment(self, deployment_name: str) -> None: + for pending_replica_name in list( + self._pending_replicas[deployment_name].keys() + ): + replica_scheduling_request = self._pending_replicas[deployment_name][ + pending_replica_name + ] + + actor_handle = replica_scheduling_request.actor_def.options( + scheduling_strategy="SPREAD", + **replica_scheduling_request.actor_options, + ).remote(*replica_scheduling_request.actor_init_args) + del self._pending_replicas[deployment_name][pending_replica_name] + self._launching_replicas[deployment_name][pending_replica_name] = None + replica_scheduling_request.on_scheduled(actor_handle) + + def _schedule_driver_deployment(self, deployment_name: str) -> None: + if self._recovering_replicas[deployment_name]: + # Wait until recovering is done before scheduling new replicas + # so that we can make sure we don't schedule two replicas on the same node. + return + + all_nodes = {node_id for node_id, _ in get_all_node_ids(self._gcs_client)} + scheduled_nodes = set() + for node_id in self._launching_replicas[deployment_name].values(): + assert node_id is not None + scheduled_nodes.add(node_id) + for node_id in self._running_replicas[deployment_name].values(): + assert node_id is not None + scheduled_nodes.add(node_id) + unscheduled_nodes = all_nodes - scheduled_nodes + + for pending_replica_name in list( + self._pending_replicas[deployment_name].keys() + ): + if not unscheduled_nodes: + return + + replica_scheduling_request = self._pending_replicas[deployment_name][ + pending_replica_name + ] + + target_node_id = unscheduled_nodes.pop() + actor_handle = replica_scheduling_request.actor_def.options( + scheduling_strategy=NodeAffinitySchedulingStrategy( + target_node_id, soft=False + ), + **replica_scheduling_request.actor_options, + ).remote(*replica_scheduling_request.actor_init_args) + del self._pending_replicas[deployment_name][pending_replica_name] + self._launching_replicas[deployment_name][ + pending_replica_name + ] = target_node_id + replica_scheduling_request.on_scheduled(actor_handle) + + def _get_replicas_to_stop( + self, deployment_name: str, max_num_to_stop: int + ) -> Set[str]: + """Prioritize replicas that have fewest copies on a node. + + This algorithm helps to scale down more intelligently because it can + relinquish nodes faster. Note that this algorithm doesn't consider other + deployments or other actors on the same node. See more at + https://github.com/ray-project/ray/issues/20599. 
+ """ + replicas_to_stop = set() + + # Replicas not in running state don't have node id. + # We will prioritize those first. + pending_launching_recovering_replicas = set().union( + self._pending_replicas[deployment_name].keys(), + self._launching_replicas[deployment_name].keys(), + self._recovering_replicas[deployment_name], + ) + for ( + pending_launching_recovering_replica + ) in pending_launching_recovering_replicas: + if len(replicas_to_stop) == max_num_to_stop: + return replicas_to_stop + else: + replicas_to_stop.add(pending_launching_recovering_replica) + + node_to_running_replicas = defaultdict(set) + for running_replica, node_id in self._running_replicas[deployment_name].items(): + node_to_running_replicas[node_id].add(running_replica) + for running_replicas in sorted( + node_to_running_replicas.values(), key=lambda lst: len(lst) + ): + for running_replica in running_replicas: + if len(replicas_to_stop) == max_num_to_stop: + return replicas_to_stop + else: + replicas_to_stop.add(running_replica) + + return replicas_to_stop diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index 0f4b6c9cd27a8..603c30957f611 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -import itertools import json import logging import math @@ -10,7 +9,7 @@ from collections import defaultdict, OrderedDict from copy import copy from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple import ray from ray import ObjectRef, cloudpickle @@ -59,10 +58,16 @@ check_obj_ref_ready_nowait, ) from ray.serve._private.version import DeploymentVersion, VersionedReplica +from ray.serve._private import deployment_scheduler +from ray.serve._private.deployment_scheduler import ( + SpreadDeploymentSchedulingPolicy, + DriverDeploymentSchedulingPolicy, + ReplicaSchedulingRequest, + DeploymentDownscaleRequest, +) from ray.serve import metrics from ray._raylet import GcsClient -from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy logger = logging.getLogger(SERVE_LOGGER_NAME) @@ -115,6 +120,14 @@ def from_deployment_info( return cls(info, num_replicas, version, deleting) +@dataclass +class DeploymentStateUpdateResult: + deleted: bool + any_replicas_recovering: bool + upscale: List[ReplicaSchedulingRequest] + downscale: Optional[DeploymentDownscaleRequest] + + CHECKPOINT_KEY = "serve-deployment-state-checkpoint" SLOW_STARTUP_WARNING_S = int(os.environ.get("SERVE_SLOW_STARTUP_WARNING_S", 30)) SLOW_STARTUP_WARNING_PERIOD_S = int( @@ -147,34 +160,6 @@ def print_verbose_scaling_log(): logger.error(f"Scaling information\n{json.dumps(debug_info, indent=2)}") -def rank_replicas_for_stopping( - all_available_replicas: List["DeploymentReplica"], -) -> List["DeploymentReplica"]: - """Prioritize replicas that have fewest copies on a node. - - This algorithm helps to scale down more intelligently because it can - relinquish node faster. Note that this algorithm doesn't consider other - deployments or other actors on the same node. See more at - https://github.com/ray-project/ray/issues/20599. - """ - # Categorize replicas to node they belong to. - node_to_replicas = defaultdict(list) - for replica in all_available_replicas: - node_to_replicas[replica.actor_node_id].append(replica) - - # Replicas not in running state might have _node_id = None. 
- # We will prioritize those first. - node_to_replicas.setdefault(None, []) - return list( - itertools.chain.from_iterable( - [ - node_to_replicas.pop(None), - ] - + sorted(node_to_replicas.values(), key=lambda lst: len(lst)) - ) - ) - - class ActorReplicaWrapper: """Wraps a Ray actor for a deployment replica. @@ -192,10 +177,6 @@ def __init__( replica_tag: ReplicaTag, deployment_name: str, version: DeploymentVersion, - # Spread replicas to avoid correlated failures on a single node. - # This is a soft spread, so if there is only space on a single node - # the replicas will be placed there. - scheduling_strategy: Union[str, NodeAffinitySchedulingStrategy] = "SPREAD", ): self._actor_name = actor_name self._detached = detached @@ -226,11 +207,8 @@ def __init__( self._pid: int = None self._actor_id: str = None self._worker_id: str = None - if isinstance(scheduling_strategy, NodeAffinitySchedulingStrategy): - self._node_id = scheduling_strategy.node_id - else: - # Populated after replica is allocated. - self._node_id: str = None + # Populated after replica is allocated. + self._node_id: str = None self._node_ip: str = None self._log_file_path: str = None @@ -241,8 +219,6 @@ def __init__( self._is_cross_language = False self._deployment_is_cross_language = False - self.scheduling_strategy = scheduling_strategy - @property def replica_tag(self) -> str: return self._replica_tag @@ -339,9 +315,11 @@ def log_file_path(self) -> Optional[str]: """Returns the relative log file path of the actor, None if not placed.""" return self._log_file_path - def start(self, deployment_info: DeploymentInfo): - """ - Start a new actor for current DeploymentReplica instance. + def start(self, deployment_info: DeploymentInfo) -> ReplicaSchedulingRequest: + """Start the current DeploymentReplica instance. + + The replica will be in the STARTING and PENDING_ALLOCATION states + until the deployment scheduler schedules the underlying actor. """ self._actor_resources = deployment_info.replica_config.resource_dict # it is currently not possible to create a placement group @@ -421,17 +399,28 @@ def start(self, deployment_info: DeploymentInfo): self._controller_name, ) - self._actor_handle = actor_def.options( - name=self._actor_name, - namespace=SERVE_NAMESPACE, - lifetime="detached" if self._detached else None, - scheduling_strategy=self.scheduling_strategy, - **deployment_info.replica_config.ray_actor_options, - ).remote(*init_args) + actor_options = { + "name": self._actor_name, + "namespace": SERVE_NAMESPACE, + "lifetime": "detached" if self._detached else None, + } + actor_options.update(deployment_info.replica_config.ray_actor_options) + + return ReplicaSchedulingRequest( + deployment_name=self.deployment_name, + replica_name=self.replica_tag, + actor_def=actor_def, + actor_resources=self._actor_resources, + actor_options=actor_options, + actor_init_args=init_args, + on_scheduled=self.on_scheduled, + ) + def on_scheduled(self, actor_handle: ActorHandle): + self._actor_handle = actor_handle # Perform auto method name translation for java handles. # See https://github.com/ray-project/ray/issues/21474 - deployment_config = copy(deployment_info.deployment_config) + deployment_config = copy(self._version.deployment_config) deployment_config.user_config = self._format_user_config( deployment_config.user_config ) @@ -533,9 +522,28 @@ def check_ready(self) -> Tuple[ReplicaStartupStatus, Optional[str]]: """ # Check whether the replica has been allocated. 
- if not check_obj_ref_ready_nowait(self._allocated_obj_ref): + if self._allocated_obj_ref is None or not check_obj_ref_ready_nowait( + self._allocated_obj_ref + ): return ReplicaStartupStatus.PENDING_ALLOCATION, None + if not self._is_cross_language: + try: + ( + self._pid, + self._actor_id, + self._worker_id, + self._node_id, + self._node_ip, + self._log_file_path, + ) = ray.get(self._allocated_obj_ref) + except RayTaskError as e: + logger.exception( + f"Exception in replica '{self._replica_tag}', " + "the replica will be stopped." + ) + return ReplicaStartupStatus.FAILED, str(e.as_instanceof_cause()) + # Check whether relica initialization has completed. replica_ready = check_obj_ref_ready_nowait(self._ready_obj_ref) # In case of deployment constructor failure, ray.get will help to @@ -555,15 +563,6 @@ def check_ready(self) -> Tuple[ReplicaStartupStatus, Optional[str]]: # If this is checking on a replica that is newly started, this # should return a version that is identical to what's already stored _, self._version = ray.get(self._ready_obj_ref) - - ( - self._pid, - self._actor_id, - self._worker_id, - self._node_id, - self._node_ip, - self._log_file_path, - ) = ray.get(self._allocated_obj_ref) except RayTaskError as e: logger.exception( f"Exception in replica '{self._replica_tag}', " @@ -767,10 +766,6 @@ def __init__( replica_tag: ReplicaTag, deployment_name: str, version: DeploymentVersion, - # Spread replicas to avoid correlated failures on a single node. - # This is a soft spread, so if there is only space on a single node - # the replicas will be placed there. - scheduling_strategy: Union[str, NodeAffinitySchedulingStrategy] = "SPREAD", ): self._actor = ActorReplicaWrapper( f"{ReplicaName.prefix}{format_actor_name(replica_tag)}", @@ -779,7 +774,6 @@ def __init__( replica_tag, deployment_name, version, - scheduling_strategy, ) self._controller_name = controller_name self._deployment_name = deployment_name @@ -837,14 +831,15 @@ def actor_node_id(self) -> Optional[str]: """Returns the node id of the actor, None if not placed.""" return self._actor.node_id - def start(self, deployment_info: DeploymentInfo): + def start(self, deployment_info: DeploymentInfo) -> ReplicaSchedulingRequest: """ Start a new actor for current DeploymentReplica instance. """ - self._actor.start(deployment_info) + replica_scheduling_request = self._actor.start(deployment_info) self._start_time = time.time() self._prev_slow_startup_warning_time = time.time() self.update_actor_details(start_time_s=self._start_time) + return replica_scheduling_request def reconfigure(self, version: DeploymentVersion) -> bool: """ @@ -1108,6 +1103,7 @@ def __init__( controller_name: str, detached: bool, long_poll_host: LongPollHost, + deployment_scheduler: deployment_scheduler.DeploymentScheduler, _save_checkpoint_func: Callable, ): @@ -1115,6 +1111,7 @@ def __init__( self._controller_name: str = controller_name self._detached: bool = detached self._long_poll_host: LongPollHost = long_poll_host + self._deployment_scheduler = deployment_scheduler self._save_checkpoint_func = _save_checkpoint_func # Each time we set a new deployment goal, we're trying to save new @@ -1144,6 +1141,8 @@ def __init__( # time we checked. 
self._multiplexed_model_ids_updated = False + self._last_notified_running_replica_infos: List[RunningReplicaInfo] = [] + def should_autoscale(self) -> bool: """ Check if the deployment is under autoscaling @@ -1197,6 +1196,9 @@ def recover_current_state_from_replica_actor_names( ) new_deployment_replica.recover() self._replicas.add(ReplicaState.RECOVERING, new_deployment_replica) + self._deployment_scheduler.on_replica_recovering( + replica_name.deployment_tag, replica_name.replica_tag + ) logger.debug( f"RECOVERING replica: {new_deployment_replica.replica_tag}, " f"deployment: {self._name}." @@ -1248,11 +1250,20 @@ def get_active_node_ids(self) -> Set[str]: def list_replica_details(self) -> List[ReplicaDetails]: return [replica.actor_details for replica in self._replicas.get()] - def _notify_running_replicas_changed(self): + def notify_running_replicas_changed(self) -> None: + running_replica_infos = self.get_running_replica_infos() + if ( + set(self._last_notified_running_replica_infos) == set(running_replica_infos) + and not self._multiplexed_model_ids_updated + ): + return + self._long_poll_host.notify_changed( (LongPollNamespace.RUNNING_REPLICAS, self._name), - self.get_running_replica_infos(), + running_replica_infos, ) + self._last_notified_running_replica_infos = running_replica_infos + self._multiplexed_model_ids_updated = False def _set_target_state_deleting(self) -> None: """Set the target state for the deployment to be deleted.""" @@ -1416,7 +1427,6 @@ def _stop_or_update_outdated_version_replicas(self, max_to_stop=math.inf) -> boo replicas_to_update = self._replicas.pop( exclude_version=self._target_state.version, states=[ReplicaState.STARTING, ReplicaState.RUNNING], - ranking_function=rank_replicas_for_stopping, ) replicas_changed = False code_version_changes = 0 @@ -1521,14 +1531,19 @@ def _check_and_stop_wrong_version_replicas(self) -> bool: return self._stop_or_update_outdated_version_replicas(max_to_stop) - def _scale_deployment_replicas(self) -> bool: + def _scale_deployment_replicas( + self, + ) -> Tuple[List[ReplicaSchedulingRequest], DeploymentDownscaleRequest]: """Scale the given deployment to the number of replicas.""" assert ( self._target_state.num_replicas >= 0 ), "Number of replicas must be greater than or equal to 0." - replicas_changed = self._check_and_stop_wrong_version_replicas() + upscale = [] + downscale = None + + self._check_and_stop_wrong_version_replicas() current_replicas = self._replicas.count( states=[ReplicaState.STARTING, ReplicaState.UPDATING, ReplicaState.RUNNING] @@ -1539,7 +1554,7 @@ def _scale_deployment_replicas(self) -> bool: self._target_state.num_replicas - current_replicas - recovering_replicas ) if delta_replicas == 0: - return replicas_changed + return (upscale, downscale) elif delta_replicas > 0: # Don't ever exceed self._target_state.num_replicas. 
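(Editor's aside, not part of the patch: the surrounding hunks change `_scale_deployment_replicas` so that it returns upscale/downscale requests instead of starting or stopping replicas itself. For orientation, here is a condensed sketch of how the controller-level loop is meant to consume those results; it mirrors the `DeploymentStateManager.update()` changes that appear later in this diff, and the helper name and parameters below are placeholders rather than code from the patch.)

def reconcile_once(deployment_states, scheduler):
    # Collect per-deployment scheduling requests from each deployment state machine.
    upscales, downscales = {}, {}
    for name, state in deployment_states.items():
        result = state.update()  # a DeploymentStateUpdateResult
        if result.upscale:
            upscales[name] = result.upscale
        if result.downscale:
            downscales[name] = result.downscale
    # The scheduler launches pending replicas and picks which replicas to stop.
    for name, to_stop in scheduler.schedule(upscales, downscales).items():
        deployment_states[name].stop_replicas(to_stop)
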
@@ -1562,7 +1577,7 @@ def _scale_deployment_replicas(self) -> bool: time.time() - self._last_retry < self._backoff_time_s + random.uniform(0, 3) ): - return replicas_changed + return upscale, downscale self._last_retry = time.time() logger.info( @@ -1578,7 +1593,9 @@ def _scale_deployment_replicas(self) -> bool: replica_name.deployment_tag, self._target_state.version, ) - new_deployment_replica.start(self._target_state.info) + upscale.append( + new_deployment_replica.start(self._target_state.info) + ) self._replicas.add(ReplicaState.STARTING, new_deployment_replica) logger.debug( @@ -1587,31 +1604,16 @@ def _scale_deployment_replicas(self) -> bool: ) elif delta_replicas < 0: - replicas_changed = True to_remove = -delta_replicas logger.info( f"Removing {to_remove} replica{'s' if to_remove > 1 else ''} " f"from deployment '{self._name}'." ) - replicas_to_stop = self._replicas.pop( - states=[ - ReplicaState.STARTING, - ReplicaState.UPDATING, - ReplicaState.RECOVERING, - ReplicaState.RUNNING, - ], - max_replicas=to_remove, - ranking_function=rank_replicas_for_stopping, + downscale = DeploymentDownscaleRequest( + deployment_name=self._name, num_to_stop=to_remove ) - for replica in replicas_to_stop: - logger.debug( - f"Adding STOPPING to replica_tag: {replica}, " - f"deployment_name: {self._name}" - ) - self._stop_replica(replica) - - return replicas_changed + return upscale, downscale def _check_curr_status(self) -> Tuple[bool, bool]: """Check the current deployment status. @@ -1690,7 +1692,10 @@ def _check_curr_status(self) -> Tuple[bool, bool]: return True, any_replicas_recovering # Check for a non-zero number of deployments. - if target_replica_count == running_at_target_version_replica_cnt: + if ( + target_replica_count == running_at_target_version_replica_cnt + and running_at_target_version_replica_cnt == all_running_replica_cnt + ): self._curr_status_info = DeploymentStatusInfo( self._name, DeploymentStatus.HEALTHY ) @@ -1700,7 +1705,7 @@ def _check_curr_status(self) -> Tuple[bool, bool]: def _check_startup_replicas( self, original_state: ReplicaState, stop_on_slow=False - ) -> Tuple[List[Tuple[DeploymentReplica, ReplicaStartupStatus]], bool]: + ) -> List[Tuple[DeploymentReplica, ReplicaStartupStatus]]: """ Common helper function for startup actions tracking and status transition: STARTING, UPDATING and RECOVERING. @@ -1710,7 +1715,6 @@ def _check_startup_replicas( slow to reach running state. """ slow_replicas = [] - transitioned_to_running = False replicas_failed = False for replica in self._replicas.pop(states=[original_state]): start_status, error_msg = replica.check_started() @@ -1718,7 +1722,9 @@ def _check_startup_replicas( # This replica should be now be added to handle's replica # set. 
self._replicas.add(ReplicaState.RUNNING, replica) - transitioned_to_running = True + self._deployment_scheduler.on_replica_running( + self._name, replica.replica_tag, replica.actor_node_id + ) logger.info( f"Replica {replica.replica_tag} started successfully " f"on node {replica.actor_node_id}.", @@ -1737,7 +1743,10 @@ def _check_startup_replicas( ReplicaStartupStatus.PENDING_ALLOCATION, ReplicaStartupStatus.PENDING_INITIALIZATION, ]: - + if start_status == ReplicaStartupStatus.PENDING_INITIALIZATION: + self._deployment_scheduler.on_replica_running( + self._name, replica.replica_tag, replica.actor_node_id + ) is_slow = time.time() - replica._start_time > SLOW_STARTUP_WARNING_S if is_slow: slow_replicas.append((replica, start_status)) @@ -1764,7 +1773,14 @@ def _check_startup_replicas( EXPONENTIAL_BACKOFF_FACTOR * self._backoff_time_s, MAX_BACKOFF_TIME_S ) - return slow_replicas, transitioned_to_running + return slow_replicas + + def stop_replicas(self, replicas_to_stop) -> None: + for replica in self._replicas.pop(): + if replica.replica_tag in replicas_to_stop: + self._stop_replica(replica) + else: + self._replicas.add(replica.actor_details.state, replica) def _stop_replica(self, replica, graceful_stop=True): """Stop replica @@ -1772,8 +1788,13 @@ def _stop_replica(self, replica, graceful_stop=True): 2. Change the replica into stopping state. 3. Set the health replica stats to 0. """ + logger.debug( + f"Adding STOPPING to replica_tag: {replica}, " + f"deployment_name: {self._name}" + ) replica.stop(graceful=graceful_stop) self._replicas.add(ReplicaState.STOPPING, replica) + self._deployment_scheduler.on_replica_stopping(self._name, replica.replica_tag) self.health_check_gauge.set( 0, tags={ @@ -1783,16 +1804,13 @@ def _stop_replica(self, replica, graceful_stop=True): }, ) - def _check_and_update_replicas(self) -> bool: + def _check_and_update_replicas(self): """ Check current state of all DeploymentReplica being tracked, and compare with state container from previous update() cycle to see if any state transition happened. - - Returns if any running replicas transitioned to another state. """ - running_replicas_changed = False for replica in self._replicas.pop(states=[ReplicaState.RUNNING]): if replica.check_health(): self._replicas.add(ReplicaState.RUNNING, replica) @@ -1805,7 +1823,6 @@ def _check_and_update_replicas(self) -> bool: }, ) else: - running_replicas_changed = True logger.warning( f"Replica {replica.replica_tag} of deployment " f"{self._name} failed health check, stopping it." 
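(Editor's aside, not part of the patch: the hunks above wire replica state transitions to the new scheduler through `on_replica_recovering`, `on_replica_running`, and `on_replica_stopping`; the deployment-level hooks are added in the `DeploymentStateManager` changes later in this diff. A minimal sketch of the hook sequence for a single replica follows; `request` and `node_id` are placeholders, and the complete runnable version is the new test_deployment_scheduler.py added at the end of this diff.)

# Requires an initialized Ray cluster; `request` is a ReplicaSchedulingRequest and
# `node_id` is the node reported by the replica once it has been allocated.
scheduler = DeploymentScheduler()
scheduler.on_deployment_created("d", SpreadDeploymentSchedulingPolicy())
scheduler.schedule(upscales={"d": [request]}, downscales={})       # submits the actor
scheduler.on_replica_running("d", request.replica_name, node_id)   # from check_ready()
scheduler.on_replica_stopping("d", request.replica_name)           # from _stop_replica()
scheduler.on_deployment_deleted("d")
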
@@ -1832,23 +1849,13 @@ def _check_and_update_replicas(self) -> bool: ) slow_start_replicas = [] - slow_start, starting_to_running = self._check_startup_replicas( - ReplicaState.STARTING - ) - slow_update, updating_to_running = self._check_startup_replicas( - ReplicaState.UPDATING - ) - slow_recover, recovering_to_running = self._check_startup_replicas( + slow_start = self._check_startup_replicas(ReplicaState.STARTING) + slow_update = self._check_startup_replicas(ReplicaState.UPDATING) + slow_recover = self._check_startup_replicas( ReplicaState.RECOVERING, stop_on_slow=True ) slow_start_replicas = slow_start + slow_update + slow_recover - running_replicas_changed = ( - running_replicas_changed - or starting_to_running - or updating_to_running - or recovering_to_running - ) if ( len(slow_start_replicas) @@ -1914,35 +1921,26 @@ def _check_and_update_replicas(self) -> bool: if not stopped: self._replicas.add(ReplicaState.STOPPING, replica) - return running_replicas_changed - - def update(self) -> Tuple[bool, bool]: + def update(self) -> DeploymentStateUpdateResult: """Attempts to reconcile this deployment to match its goal state. This is an asynchronous call; it's expected to be called repeatedly. Also updates the internal DeploymentStatusInfo based on the current state of the system. - - Returns (deleted, any_replicas_recovering). """ deleted, any_replicas_recovering = False, False + upscale = [] + downscale = None try: # Add or remove DeploymentReplica instances in self._replicas. # This should be the only place we adjust total number of replicas # we manage. - running_replicas_changed = self._scale_deployment_replicas() - # Check the state of existing replicas and transition if necessary. - running_replicas_changed |= self._check_and_update_replicas() - - # Check if the model_id has changed. 
- running_replicas_changed |= self._multiplexed_model_ids_updated + self._check_and_update_replicas() - if running_replicas_changed: - self._notify_running_replicas_changed() - self._multiplexed_model_ids_updated = False + upscale, downscale = self._scale_deployment_replicas() deleted, any_replicas_recovering = self._check_curr_status() except Exception: @@ -1956,7 +1954,12 @@ def update(self) -> Tuple[bool, bool]: message="Failed to update deployment:" f"\n{traceback.format_exc()}", ) - return deleted, any_replicas_recovering + return DeploymentStateUpdateResult( + deleted=deleted, + any_replicas_recovering=any_replicas_recovering, + upscale=upscale, + downscale=downscale, + ) def record_multiplexed_model_ids( self, replica_name: str, multiplexed_model_ids: List[str] @@ -1993,11 +1996,17 @@ def __init__( controller_name: str, detached: bool, long_poll_host: LongPollHost, + deployment_scheduler: deployment_scheduler.DeploymentScheduler, _save_checkpoint_func: Callable, gcs_client: GcsClient = None, ): super().__init__( - name, controller_name, detached, long_poll_host, _save_checkpoint_func + name, + controller_name, + detached, + long_poll_host, + deployment_scheduler, + _save_checkpoint_func, ) if gcs_client: self._gcs_client = gcs_client @@ -2008,25 +2017,26 @@ def _get_all_node_ids(self): # Test mock purpose return get_all_node_ids(self._gcs_client) - def _deploy_driver(self) -> bool: + def _deploy_driver(self) -> List[ReplicaSchedulingRequest]: """Deploy the driver deployment to each node.""" - all_nodes = self._get_all_node_ids() - deployed_nodes = set() - for replica in self._replicas.get( - [ - ReplicaState.STARTING, - ReplicaState.RUNNING, - ReplicaState.RECOVERING, - ReplicaState.UPDATING, - ReplicaState.STOPPING, - ] - ): - if replica.actor_node_id: - deployed_nodes.add(replica.actor_node_id) - replica_changed = False - for node_id, _ in all_nodes: - if node_id in deployed_nodes: - continue + num_running_replicas = self._replicas.count(states=[ReplicaState.RUNNING]) + if num_running_replicas >= self._target_state.num_replicas: + # Cancel starting replicas when driver deployment state creates + # more replicas than alive nodes. + # For example, get_all_node_ids returns 4 nodes when + # the driver deployment state decides the target number of replicas + # but later on when the deployment scheduler schedules these 4 replicas, + # there are only 3 alive nodes (1 node dies in between). + # In this case, 1 replica will be in the PENDING_ALLOCATION and we + # cancel it here. 
+ for replica in self._replicas.pop(states=[ReplicaState.STARTING]): + self._stop_replica(replica) + + return [] + + upscale = [] + num_existing_replicas = self._replicas.count() + for _ in range(self._target_state.num_replicas - num_existing_replicas): replica_name = ReplicaName(self._name, get_random_letters()) new_deployment_replica = DeploymentReplica( self._controller_name, @@ -2034,13 +2044,12 @@ def _deploy_driver(self) -> bool: replica_name.replica_tag, replica_name.deployment_tag, self._target_state.version, - NodeAffinitySchedulingStrategy(node_id, soft=False), ) - new_deployment_replica.start(self._target_state.info) + upscale.append(new_deployment_replica.start(self._target_state.info)) self._replicas.add(ReplicaState.STARTING, new_deployment_replica) - replica_changed = True - return replica_changed + + return upscale def _stop_all_replicas(self) -> bool: replica_changed = False @@ -2049,6 +2058,7 @@ def _stop_all_replicas(self) -> bool: ReplicaState.STARTING, ReplicaState.RUNNING, ReplicaState.RECOVERING, + ReplicaState.UPDATING, ] ): self._stop_replica(replica) @@ -2056,8 +2066,8 @@ def _stop_all_replicas(self) -> bool: return replica_changed def _calculate_max_replicas_to_stop(self) -> int: - nums_nodes = len(self._get_all_node_ids()) - rollout_size = max(int(0.2 * nums_nodes), 1) + num_nodes = len(self._get_all_node_ids()) + rollout_size = max(int(0.2 * num_nodes), 1) old_running_replicas = self._replicas.count( exclude_version=self._target_state.version, states=[ReplicaState.STARTING, ReplicaState.UPDATING, ReplicaState.RUNNING], @@ -2065,12 +2075,14 @@ def _calculate_max_replicas_to_stop(self) -> int: new_running_replicas = self._replicas.count( version=self._target_state.version, states=[ReplicaState.RUNNING] ) - pending_replicas = nums_nodes - new_running_replicas - old_running_replicas + pending_replicas = num_nodes - new_running_replicas - old_running_replicas return max(rollout_size - pending_replicas, 0) - def update(self) -> Tuple[bool, bool]: - """Returns (deleted, any_replicas_recovering).""" + def update(self) -> DeploymentStateUpdateResult: try: + self._check_and_update_replicas() + + upscale = [] if self._target_state.deleting: self._stop_all_replicas() else: @@ -2085,18 +2097,28 @@ def update(self) -> Tuple[bool, bool]: if new_config.version is None: new_config.version = self._target_state.version.code_version self._set_target_state(new_config) + max_to_stop = self._calculate_max_replicas_to_stop() self._stop_or_update_outdated_version_replicas(max_to_stop) - self._deploy_driver() - self._check_and_update_replicas() - return self._check_curr_status() + + upscale = self._deploy_driver() + + deleted, any_replicas_recovering = self._check_curr_status() + return DeploymentStateUpdateResult( + deleted=deleted, + any_replicas_recovering=any_replicas_recovering, + upscale=upscale, + downscale=None, + ) except Exception: self._curr_status_info = DeploymentStatusInfo( name=self._name, status=DeploymentStatus.UNHEALTHY, message="Failed to update deployment:" f"\n{traceback.format_exc()}", ) - return False, False + return DeploymentStateUpdateResult( + deleted=False, any_replicas_recovering=False, upscale=[], downscale=None + ) def should_autoscale(self) -> bool: return False @@ -2122,24 +2144,7 @@ def __init__( self._detached = detached self._kv_store = kv_store self._long_poll_host = long_poll_host - - self._create_deployment_state: Callable = lambda name: DeploymentState( - name, - controller_name, - detached, - long_poll_host, - self._save_checkpoint_func, - ) - - 
self._create_driver_deployment_state: Callable = ( - lambda name: DriverDeploymentState( - name, - controller_name, - detached, - long_poll_host, - self._save_checkpoint_func, - ) - ) + self._deployment_scheduler = deployment_scheduler.DeploymentScheduler() self._deployment_states: Dict[str, DeploymentState] = dict() self._deleted_deployment_metadata: Dict[str, DeploymentInfo] = OrderedDict() @@ -2150,6 +2155,34 @@ def __init__( self.autoscaling_metrics_store = InMemoryMetricsStore() self.handle_metrics_store = InMemoryMetricsStore() + def _create_driver_deployment_state(self, name): + self._deployment_scheduler.on_deployment_created( + name, DriverDeploymentSchedulingPolicy() + ) + + return DriverDeploymentState( + name, + self._controller_name, + self._detached, + self._long_poll_host, + self._deployment_scheduler, + self._save_checkpoint_func, + ) + + def _create_deployment_state(self, name): + self._deployment_scheduler.on_deployment_created( + name, SpreadDeploymentSchedulingPolicy() + ) + + return DeploymentState( + name, + self._controller_name, + self._detached, + self._long_poll_host, + self._deployment_scheduler, + self._save_checkpoint_func, + ) + def record_autoscaling_metrics(self, data: Dict[str, float], send_timestamp: float): self.autoscaling_metrics_store.add_metrics_point(data, send_timestamp) @@ -2442,6 +2475,9 @@ def update(self) -> bool: """ deleted_tags = [] any_recovering = False + upscales = {} + downscales = {} + for deployment_name, deployment_state in self._deployment_states.items(): if deployment_state.should_autoscale(): current_num_ongoing_requests = self.get_replica_ongoing_request_metrics( @@ -2455,8 +2491,14 @@ def update(self) -> bool: deployment_state.autoscale( current_num_ongoing_requests, current_handle_queued_queries ) - deleted, recovering = deployment_state.update() - if deleted: + + deployment_state_update_result = deployment_state.update() + if deployment_state_update_result.upscale: + upscales[deployment_name] = deployment_state_update_result.upscale + if deployment_state_update_result.downscale: + downscales[deployment_name] = deployment_state_update_result.downscale + + if deployment_state_update_result.deleted: deleted_tags.append(deployment_name) deployment_info = deployment_state.target_info deployment_info.end_time_ms = int(time.time() * 1000) @@ -2464,9 +2506,19 @@ def update(self) -> bool: self._deleted_deployment_metadata.popitem(last=False) self._deleted_deployment_metadata[deployment_name] = deployment_info - any_recovering |= recovering + any_recovering |= deployment_state_update_result.any_replicas_recovering + + deployment_to_replicas_to_stop = self._deployment_scheduler.schedule( + upscales, downscales + ) + for deployment_name, replicas_to_stop in deployment_to_replicas_to_stop.items(): + self._deployment_states[deployment_name].stop_replicas(replicas_to_stop) + + for deployment_name, deployment_state in self._deployment_states.items(): + deployment_state.notify_running_replicas_changed() for tag in deleted_tags: + self._deployment_scheduler.on_deployment_deleted(tag) del self._deployment_states[tag] if len(deleted_tags): diff --git a/python/ray/serve/tests/test_deployment_scheduler.py b/python/ray/serve/tests/test_deployment_scheduler.py new file mode 100644 index 0000000000000..66e469b8bed13 --- /dev/null +++ b/python/ray/serve/tests/test_deployment_scheduler.py @@ -0,0 +1,270 @@ +import sys + +import pytest + +import ray +from ray.tests.conftest import * # noqa +from ray.serve._private.deployment_scheduler import ( + 
DeploymentScheduler, + SpreadDeploymentSchedulingPolicy, + DriverDeploymentSchedulingPolicy, + ReplicaSchedulingRequest, + DeploymentDownscaleRequest, +) + + +@ray.remote(num_cpus=1) +class Replica: + def get_node_id(self): + return ray.get_runtime_context().get_node_id() + + +def test_spread_deployment_scheduling_policy_upscale(ray_start_cluster): + """Test to make sure replicas are spreaded.""" + cluster = ray_start_cluster + cluster.add_node(num_cpus=3) + cluster.add_node(num_cpus=3) + cluster.wait_for_nodes() + ray.init(address=cluster.address) + + scheduler = DeploymentScheduler() + scheduler.on_deployment_created("deployment1", SpreadDeploymentSchedulingPolicy()) + replica_actor_handles = [] + deployment_to_replicas_to_stop = scheduler.schedule( + upscales={ + "deployment1": [ + ReplicaSchedulingRequest( + deployment_name="deployment1", + replica_name="replica1", + actor_def=Replica, + actor_resources={"CPU": 1}, + actor_options={}, + actor_init_args=(), + on_scheduled=lambda actor_handle: replica_actor_handles.append( + actor_handle + ), + ), + ReplicaSchedulingRequest( + deployment_name="deployment1", + replica_name="replica2", + actor_def=Replica, + actor_resources={"CPU": 1}, + actor_options={}, + actor_init_args=(), + on_scheduled=lambda actor_handle: replica_actor_handles.append( + actor_handle + ), + ), + ] + }, + downscales={}, + ) + assert not deployment_to_replicas_to_stop + assert len(replica_actor_handles) == 2 + assert not scheduler._pending_replicas["deployment1"] + assert len(scheduler._launching_replicas["deployment1"]) == 2 + assert ( + len( + { + ray.get(replica_actor_handles[0].get_node_id.remote()), + ray.get(replica_actor_handles[1].get_node_id.remote()), + } + ) + == 2 + ) + scheduler.on_replica_stopping("deployment1", "replica1") + scheduler.on_replica_stopping("deployment1", "replica2") + scheduler.on_deployment_deleted("deployment1") + + +def test_spread_deployment_scheduling_policy_downscale(ray_start_cluster): + """Test to make sure downscale prefers replicas without node id + and then replicas with fewest copies on a node. 
+ """ + cluster = ray_start_cluster + cluster.add_node(num_cpus=3) + cluster.wait_for_nodes() + ray.init(address=cluster.address) + + scheduler = DeploymentScheduler() + scheduler.on_deployment_created("deployment1", SpreadDeploymentSchedulingPolicy()) + scheduler.on_replica_running("deployment1", "replica1", "node1") + scheduler.on_replica_running("deployment1", "replica2", "node1") + scheduler.on_replica_running("deployment1", "replica3", "node2") + scheduler.on_replica_recovering("deployment1", "replica4") + deployment_to_replicas_to_stop = scheduler.schedule( + upscales={}, + downscales={ + "deployment1": DeploymentDownscaleRequest( + deployment_name="deployment1", num_to_stop=1 + ) + }, + ) + assert len(deployment_to_replicas_to_stop) == 1 + # Prefer replica without node id + assert deployment_to_replicas_to_stop["deployment1"] == {"replica4"} + scheduler.on_replica_stopping("deployment1", "replica4") + + deployment_to_replicas_to_stop = scheduler.schedule( + upscales={ + "deployment1": [ + ReplicaSchedulingRequest( + deployment_name="deployment1", + replica_name="replica5", + actor_def=Replica, + actor_resources={"CPU": 1}, + actor_options={}, + actor_init_args=(), + on_scheduled=lambda actor_handle: actor_handle, + ), + ] + }, + downscales={}, + ) + assert not deployment_to_replicas_to_stop + deployment_to_replicas_to_stop = scheduler.schedule( + upscales={}, + downscales={ + "deployment1": DeploymentDownscaleRequest( + deployment_name="deployment1", num_to_stop=1 + ) + }, + ) + assert len(deployment_to_replicas_to_stop) == 1 + # Prefer replica without node id + assert deployment_to_replicas_to_stop["deployment1"] == {"replica5"} + scheduler.on_replica_stopping("deployment1", "replica5") + + deployment_to_replicas_to_stop = scheduler.schedule( + upscales={}, + downscales={ + "deployment1": DeploymentDownscaleRequest( + deployment_name="deployment1", num_to_stop=1 + ) + }, + ) + assert len(deployment_to_replicas_to_stop) == 1 + # Prefer replica that has fewest copies on a node + assert deployment_to_replicas_to_stop["deployment1"] == {"replica3"} + scheduler.on_replica_stopping("deployment1", "replica3") + + deployment_to_replicas_to_stop = scheduler.schedule( + upscales={}, + downscales={ + "deployment1": DeploymentDownscaleRequest( + deployment_name="deployment1", num_to_stop=2 + ) + }, + ) + assert len(deployment_to_replicas_to_stop) == 1 + # Prefer replica that has fewest copies on a node + assert deployment_to_replicas_to_stop["deployment1"] == {"replica1", "replica2"} + scheduler.on_replica_stopping("deployment1", "replica1") + scheduler.on_replica_stopping("deployment1", "replica2") + scheduler.on_deployment_deleted("deployment1") + + +def test_driver_deployment_scheduling_policy_upscale(ray_start_cluster): + """Test to make sure there is only one replica on each node + for the driver deployment. 
+ """ + cluster = ray_start_cluster + cluster.add_node(num_cpus=3) + cluster.add_node(num_cpus=3) + cluster.wait_for_nodes() + ray.init(address=cluster.address) + + scheduler = DeploymentScheduler() + scheduler.on_deployment_created("deployment1", DriverDeploymentSchedulingPolicy()) + + replica_actor_handles = [] + deployment_to_replicas_to_stop = scheduler.schedule( + upscales={ + "deployment1": [ + ReplicaSchedulingRequest( + deployment_name="deployment1", + replica_name="replica1", + actor_def=Replica, + actor_resources={"CPU": 1}, + actor_options={}, + actor_init_args=(), + on_scheduled=lambda actor_handle: replica_actor_handles.append( + actor_handle + ), + ), + ReplicaSchedulingRequest( + deployment_name="deployment1", + replica_name="replica2", + actor_def=Replica, + actor_resources={"CPU": 1}, + actor_options={}, + actor_init_args=(), + on_scheduled=lambda actor_handle: replica_actor_handles.append( + actor_handle + ), + ), + ReplicaSchedulingRequest( + deployment_name="deployment1", + replica_name="replica3", + actor_def=Replica, + actor_resources={"CPU": 1}, + actor_options={}, + actor_init_args=(), + on_scheduled=lambda actor_handle: replica_actor_handles.append( + actor_handle + ), + ), + ] + }, + downscales={}, + ) + assert not deployment_to_replicas_to_stop + # 2 out of 3 replicas are scheduled since there are only two nodes in the cluster. + assert len(replica_actor_handles) == 2 + assert len(scheduler._pending_replicas["deployment1"]) == 1 + assert len(scheduler._launching_replicas["deployment1"]) == 2 + assert ( + len( + { + ray.get(replica_actor_handles[0].get_node_id.remote()), + ray.get(replica_actor_handles[1].get_node_id.remote()), + } + ) + == 2 + ) + + scheduler.on_replica_recovering("deployment1", "replica4") + cluster.add_node(num_cpus=3) + cluster.wait_for_nodes() + + deployment_to_replicas_to_stop = scheduler.schedule(upscales={}, downscales={}) + assert not deployment_to_replicas_to_stop + # No schduling while some replica is recovering + assert len(replica_actor_handles) == 2 + + scheduler.on_replica_stopping("deployment1", "replica4") + # The last replica is scheduled + deployment_to_replicas_to_stop = scheduler.schedule(upscales={}, downscales={}) + assert not deployment_to_replicas_to_stop + assert not scheduler._pending_replicas["deployment1"] + assert len(scheduler._launching_replicas["deployment1"]) == 3 + assert len(replica_actor_handles) == 3 + assert ( + len( + { + ray.get(replica_actor_handles[0].get_node_id.remote()), + ray.get(replica_actor_handles[1].get_node_id.remote()), + ray.get(replica_actor_handles[2].get_node_id.remote()), + } + ) + == 3 + ) + + scheduler.on_replica_stopping("deployment1", "replica1") + scheduler.on_replica_stopping("deployment1", "replica2") + scheduler.on_replica_stopping("deployment1", "replica3") + scheduler.on_deployment_deleted("deployment1") + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_deployment_state.py b/python/ray/serve/tests/test_deployment_state.py index 187ebe5a5ca25..97bddee4577c4 100644 --- a/python/ray/serve/tests/test_deployment_state.py +++ b/python/ray/serve/tests/test_deployment_state.py @@ -1,8 +1,8 @@ -from dataclasses import dataclass import sys import time from typing import Any, Dict, List, Optional, Tuple from unittest.mock import patch, Mock +from collections import defaultdict import pytest @@ -16,6 +16,9 @@ ReplicaName, ReplicaState, ) +from ray.serve._private.deployment_scheduler import ( + ReplicaSchedulingRequest, +) 
from ray.serve._private.deployment_state import ( ActorReplicaWrapper, DeploymentState, @@ -26,7 +29,6 @@ ReplicaStartupStatus, ReplicaStateContainer, VersionedReplica, - rank_replicas_for_stopping, ) from ray.serve._private.constants import ( DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_S, @@ -37,7 +39,6 @@ ) from ray.serve._private.storage.kv_store import RayInternalKVStore from ray.serve._private.utils import get_random_letters -from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy class FakeRemoteFunction: @@ -72,8 +73,6 @@ def __init__( replica_tag: ReplicaTag, deployment_name: str, version: DeploymentVersion, - scheduling_strategy="SPREAD", - node_id=None, ): self._actor_name = actor_name self._replica_tag = replica_tag @@ -98,9 +97,8 @@ def __init__( # Returned by the health check. self.healthy = True self._is_cross_language = False - self._scheduling_strategy = scheduling_strategy self._actor_handle = MockActorHandle() - self._node_id = node_id + self._node_id = None self._node_id_is_set = False @property @@ -151,8 +149,6 @@ def worker_id(self) -> Optional[str]: def node_id(self) -> Optional[str]: if self._node_id_is_set: return self._node_id - if isinstance(self._scheduling_strategy, NodeAffinitySchedulingStrategy): - return self._scheduling_strategy.node_id if self.ready == ReplicaStartupStatus.SUCCEEDED or self.started: return "node-id" return None @@ -191,6 +187,15 @@ def set_node_id(self, node_id: str): def start(self, deployment_info: DeploymentInfo): self.started = True + return ReplicaSchedulingRequest( + deployment_name=self._deployment_name, + replica_name=self._replica_tag, + actor_def=None, + actor_resources=None, + actor_options=None, + actor_init_args=None, + on_scheduled=None, + ) def reconfigure(self, version: DeploymentVersion): self.started = True @@ -239,10 +244,50 @@ def check_health(self): self.health_check_called = True return self.healthy - def set_scheduling_strategy( - self, scheduling_strategy: NodeAffinitySchedulingStrategy - ): - self._scheduling_strategy = scheduling_strategy + +class MockDeploymentScheduler: + def __init__(self): + self.deployments = set() + self.replicas = defaultdict(set) + + def on_deployment_created(self, deployment_name, scheduling_strategy): + assert deployment_name not in self.deployments + self.deployments.add(deployment_name) + + def on_deployment_deleted(self, deployment_name): + assert deployment_name in self.deployments + self.deployments.remove(deployment_name) + + def on_replica_stopping(self, deployment_name, replica_name): + assert replica_name in self.replicas[deployment_name] + self.replicas[deployment_name].remove(replica_name) + + def on_replica_running(self, deployment_name, replica_name, node_id): + assert replica_name in self.replicas[deployment_name] + + def on_replica_recovering(self, deployment_name, replica_name): + assert replica_name not in self.replicas[deployment_name] + self.replicas[deployment_name].add(replica_name) + + def schedule(self, upscales, downscales): + for upscale in upscales.values(): + for replica_scheduling_request in upscale: + assert ( + replica_scheduling_request.replica_name + not in self.replicas[replica_scheduling_request.deployment_name] + ) + self.replicas[replica_scheduling_request.deployment_name].add( + replica_scheduling_request.replica_name + ) + + deployment_to_replicas_to_stop = defaultdict(set) + for downscale in downscales.values(): + replica_iter = iter(self.replicas[downscale.deployment_name]) + for _ in range(downscale.num_to_stop): + 
deployment_to_replicas_to_stop[downscale.deployment_name].add( + next(replica_iter) + ) + return deployment_to_replicas_to_stop class MockKVStore: @@ -339,13 +384,19 @@ def mock_save_checkpoint_fn(*args, **kwargs): "name", True, mock_long_poll, + MockDeploymentScheduler(), mock_save_checkpoint_fn, mock_client, ) yield deployment_state, timer else: deployment_state = DeploymentState( - "name", "name", True, mock_long_poll, mock_save_checkpoint_fn + "name", + "name", + True, + mock_long_poll, + MockDeploymentScheduler(), + mock_save_checkpoint_fn, ) yield deployment_state, timer @@ -595,7 +646,10 @@ def test_create_delete_single_replica(mock_get_all_node_ids, mock_deployment_sta assert updating # Single replica should be created. - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=1, by_state=[(ReplicaState.STARTING, 1)]) # update() should not transition the state if the replica isn't ready. @@ -611,7 +665,14 @@ def test_create_delete_single_replica(mock_get_all_node_ids, mock_deployment_sta # Removing the replica should transition it to stopping. deployment_state.delete() - deployment_state.update() + deployment_state_update_result = deployment_state.update() + replicas_to_stop = deployment_state._deployment_scheduler.schedule( + {}, + {deployment_state._name: deployment_state_update_result.downscale} + if deployment_state_update_result.downscale + else {}, + )[deployment_state._name] + deployment_state.stop_replicas(replicas_to_stop) check_counts(deployment_state, total=1, by_state=[(ReplicaState.STOPPING, 1)]) assert deployment_state._replicas.get()[0]._actor.stopped assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING @@ -619,8 +680,8 @@ def test_create_delete_single_replica(mock_get_all_node_ids, mock_deployment_sta # Once it's done stopping, replica should be removed. replica = deployment_state._replicas.get()[0] replica._actor.set_done_stopping() - deleted, _ = deployment_state.update() - assert deleted + deployment_state_update_result = deployment_state.update() + assert deployment_state_update_result.deleted check_counts(deployment_state, total=0) @@ -635,11 +696,21 @@ def test_force_kill(mock_get_all_node_ids, mock_deployment_state): # Create and delete the deployment. deployment_state.deploy(b_info_1) - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) deployment_state._replicas.get()[0]._actor.set_ready() deployment_state.update() deployment_state.delete() - deployment_state.update() + deployment_state_update_result = deployment_state.update() + replicas_to_stop = deployment_state._deployment_scheduler.schedule( + {}, + {deployment_state._name: deployment_state_update_result.downscale} + if deployment_state_update_result.downscale + else {}, + )[deployment_state._name] + deployment_state.stop_replicas(replicas_to_stop) # Replica should remain in STOPPING until it finishes. check_counts(deployment_state, total=1, by_state=[(ReplicaState.STOPPING, 1)]) @@ -668,8 +739,8 @@ def test_force_kill(mock_get_all_node_ids, mock_deployment_state): # Once the replica is done stopping, it should be removed. 
replica = deployment_state._replicas.get()[0] replica._actor.set_done_stopping() - deleted, _ = deployment_state.update() - assert deleted + deployment_state_update_result = deployment_state.update() + assert deployment_state_update_result.deleted check_counts(deployment_state, total=0) @@ -684,7 +755,10 @@ def test_redeploy_same_version(mock_get_all_node_ids, mock_deployment_state): updating = deployment_state.deploy(b_info_1) assert updating - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts( deployment_state, version=b_version_1, @@ -742,7 +816,10 @@ def test_redeploy_no_version(mock_get_all_node_ids, mock_deployment_state): updating = deployment_state.deploy(b_info_1) assert updating - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=1, by_state=[(ReplicaState.STARTING, 1)]) assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING @@ -760,11 +837,11 @@ def test_redeploy_no_version(mock_get_all_node_ids, mock_deployment_state): deployment_state._replicas.get(states=[ReplicaState.STOPPING])[ 0 ]._actor.set_done_stopping() - deployment_state.update() - check_counts(deployment_state, total=0) - # Now that the old replica has stopped, the new replica should be started. - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=1, by_state=[(ReplicaState.STARTING, 1)]) deployment_state._replicas.get(states=[ReplicaState.STARTING])[0]._actor.set_ready() assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING @@ -792,16 +869,16 @@ def test_redeploy_no_version(mock_get_all_node_ids, mock_deployment_state): 0 ]._actor.set_done_stopping() - deployment_state.update() - check_counts(deployment_state, total=0) - - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) deployment_state._replicas.get(states=[ReplicaState.STARTING])[0]._actor.set_ready() check_counts(deployment_state, total=1, by_state=[(ReplicaState.STARTING, 1)]) assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING - deleted, _ = deployment_state.update() - assert not deleted + deployment_state_update_result = deployment_state.update() + assert not deployment_state_update_result.deleted check_counts(deployment_state, total=1, by_state=[(ReplicaState.RUNNING, 1)]) assert deployment_state.curr_status_info.status == DeploymentStatus.HEALTHY @@ -817,7 +894,10 @@ def test_redeploy_new_version(mock_get_all_node_ids, mock_deployment_state): updating = deployment_state.deploy(b_info_1) assert updating - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts( deployment_state, version=b_version_1, @@ -846,11 +926,12 @@ def test_redeploy_new_version(mock_get_all_node_ids, mock_deployment_state): 
deployment_state._replicas.get(states=[ReplicaState.STOPPING])[ 0 ]._actor.set_done_stopping() - deployment_state.update() - check_counts(deployment_state, total=0) # Now that the old replica has stopped, the new replica should be started. - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts( deployment_state, version=b_version_2, @@ -892,10 +973,10 @@ def test_redeploy_new_version(mock_get_all_node_ids, mock_deployment_state): 0 ]._actor.set_done_stopping() - deployment_state.update() - check_counts(deployment_state, total=0) - - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) deployment_state._replicas.get(states=[ReplicaState.STARTING])[0]._actor.set_ready() check_counts( deployment_state, @@ -904,8 +985,8 @@ def test_redeploy_new_version(mock_get_all_node_ids, mock_deployment_state): by_state=[(ReplicaState.STARTING, 1)], ) - deleted, _ = deployment_state.update() - assert not deleted + deployment_state_update_result = deployment_state.update() + assert not deployment_state_update_result.deleted check_counts( deployment_state, version=b_version_3, @@ -942,7 +1023,10 @@ def test_deploy_new_config_same_code_version( assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING # Create the replica initially. - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) deployment_state._replicas.get()[0]._actor.set_ready() deployment_state.update() check_counts( @@ -1003,7 +1087,10 @@ def test_deploy_new_config_same_code_version_2( assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING # Create the replica initially. - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts( deployment_state, version=b_version_1, @@ -1028,13 +1115,6 @@ def test_deploy_new_config_same_code_version_2( deployment_state._replicas.get()[0]._actor.set_ready() deployment_state.update() - check_counts( - deployment_state, - version=b_version_1, - total=1, - by_state=[(ReplicaState.RUNNING, 1)], - ) - deployment_state.update() check_counts( deployment_state, version=b_version_2, @@ -1067,7 +1147,10 @@ def test_deploy_new_config_new_version(mock_get_all_node_ids, mock_deployment_st assert updating # Create the replica initially. 
- deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) deployment_state._replicas.get()[0]._actor.set_ready() deployment_state.update() check_counts( @@ -1094,13 +1177,12 @@ def test_deploy_new_config_new_version(mock_get_all_node_ids, mock_deployment_st deployment_state._replicas.get(states=[ReplicaState.STOPPING])[ 0 ]._actor.set_done_stopping() - deployment_state.update() - assert deployment_state._replicas.count() == 0 - check_counts(deployment_state, total=0) - assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING # Now the new version should be started. - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) deployment_state._replicas.get(states=[ReplicaState.STARTING])[0]._actor.set_ready() check_counts( deployment_state, @@ -1108,6 +1190,7 @@ def test_deploy_new_config_new_version(mock_get_all_node_ids, mock_deployment_st total=1, by_state=[(ReplicaState.STARTING, 1)], ) + assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING # Check that the new version is now running. deployment_state.update() @@ -1131,7 +1214,10 @@ def test_initial_deploy_no_throttling(mock_get_all_node_ids, mock_deployment_sta updating = deployment_state.deploy(b_info_1) assert updating - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=10, by_state=[(ReplicaState.STARTING, 10)]) assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING @@ -1159,7 +1245,10 @@ def test_new_version_deploy_throttling(mock_get_all_node_ids, mock_deployment_st updating = deployment_state.deploy(b_info_1) assert updating - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=10, by_state=[(ReplicaState.STARTING, 10)]) assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING @@ -1190,16 +1279,11 @@ def test_new_version_deploy_throttling(mock_get_all_node_ids, mock_deployment_st 0 ]._actor.set_done_stopping() - deployment_state.update() - check_counts( - deployment_state, - version=b_version_1, - total=9, - by_state=[(ReplicaState.RUNNING, 8), (ReplicaState.STOPPING, 1)], - ) - # Now one of the new version replicas should start up. - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=10) check_counts( deployment_state, @@ -1217,7 +1301,6 @@ def test_new_version_deploy_throttling(mock_get_all_node_ids, mock_deployment_st # Mark the new version replica as ready. Another old version replica # should subsequently be stopped. 
deployment_state._replicas.get(states=[ReplicaState.STARTING])[0]._actor.set_ready() - deployment_state.update() deployment_state.update() check_counts(deployment_state, total=10) @@ -1248,24 +1331,11 @@ def test_new_version_deploy_throttling(mock_get_all_node_ids, mock_deployment_st new_replicas = 1 old_replicas = 9 while old_replicas > 3: - deployment_state.update() - - check_counts(deployment_state, total=8) - check_counts( - deployment_state, - version=b_version_1, - total=old_replicas - 2, - by_state=[(ReplicaState.RUNNING, old_replicas - 2)], - ) - check_counts( - deployment_state, - version=b_version_2, - total=new_replicas, - by_state=[(ReplicaState.RUNNING, new_replicas)], - ) - # Replicas starting up. - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=10) check_counts( deployment_state, @@ -1289,21 +1359,6 @@ def test_new_version_deploy_throttling(mock_get_all_node_ids, mock_deployment_st ]._actor.set_ready() new_replicas += 2 - deployment_state.update() - check_counts(deployment_state, total=10) - check_counts( - deployment_state, - version=b_version_1, - total=old_replicas - 2, - by_state=[(ReplicaState.RUNNING, old_replicas - 2)], - ) - check_counts( - deployment_state, - version=b_version_2, - total=new_replicas, - by_state=[(ReplicaState.RUNNING, new_replicas)], - ) - # Two more old replicas should be stopped. old_replicas -= 2 deployment_state.update() @@ -1334,23 +1389,11 @@ def test_new_version_deploy_throttling(mock_get_all_node_ids, mock_deployment_st assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING # 2 left to update. - deployment_state.update() - check_counts(deployment_state, total=8) - check_counts( - deployment_state, - version=b_version_1, - total=1, - by_state=[(ReplicaState.RUNNING, 1)], - ) - check_counts( - deployment_state, - version=b_version_2, - total=new_replicas, - by_state=[(ReplicaState.RUNNING, 7)], - ) - # Replicas starting up. - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=10) check_counts( deployment_state, @@ -1369,22 +1412,6 @@ def test_new_version_deploy_throttling(mock_get_all_node_ids, mock_deployment_st deployment_state._replicas.get(states=[ReplicaState.STARTING])[0]._actor.set_ready() deployment_state._replicas.get(states=[ReplicaState.STARTING])[1]._actor.set_ready() - # One replica remaining to update. - deployment_state.update() - check_counts(deployment_state, total=10) - check_counts( - deployment_state, - version=b_version_1, - total=1, - by_state=[(ReplicaState.RUNNING, 1)], - ) - check_counts( - deployment_state, - version=b_version_2, - total=9, - by_state=[(ReplicaState.RUNNING, 9)], - ) - # The last replica should be stopped. deployment_state.update() check_counts(deployment_state, total=10) @@ -1405,17 +1432,11 @@ def test_new_version_deploy_throttling(mock_get_all_node_ids, mock_deployment_st 0 ]._actor.set_done_stopping() - deployment_state.update() - check_counts(deployment_state, total=9) - check_counts( - deployment_state, - version=b_version_2, - total=9, - by_state=[(ReplicaState.RUNNING, 9)], - ) - # The last replica should start up. 
- deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=10) check_counts( deployment_state, @@ -1453,7 +1474,10 @@ def test_reconfigure_throttling(mock_get_all_node_ids, mock_deployment_state): updating = deployment_state.deploy(b_info_1) assert updating - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=2, by_state=[(ReplicaState.STARTING, 2)]) assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING @@ -1491,21 +1515,6 @@ def test_reconfigure_throttling(mock_get_all_node_ids, mock_deployment_state): deployment_state._replicas.get(states=[ReplicaState.UPDATING])[0]._actor.set_ready() # The updated replica should now be RUNNING. - deployment_state.update() - check_counts( - deployment_state, - version=b_version_1, - total=1, - by_state=[(ReplicaState.RUNNING, 1)], - ) - check_counts( - deployment_state, - version=b_version_2, - total=1, - by_state=[(ReplicaState.RUNNING, 1)], - ) - assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING - # The second replica should now be updated. deployment_state.update() check_counts( @@ -1543,7 +1552,10 @@ def test_new_version_and_scale_down(mock_get_all_node_ids, mock_deployment_state updating = deployment_state.deploy(b_info_1) assert updating - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=10, by_state=[(ReplicaState.STARTING, 10)]) assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING @@ -1560,7 +1572,11 @@ def test_new_version_and_scale_down(mock_get_all_node_ids, mock_deployment_state b_info_2, b_version_2 = deployment_info(num_replicas=2, version="2") updating = deployment_state.deploy(b_info_2) assert updating - deployment_state.update() + deployment_state_update_result = deployment_state.update() + replicas_to_stop = deployment_state._deployment_scheduler.schedule( + {}, {deployment_state._name: deployment_state_update_result.downscale} + )[deployment_state._name] + deployment_state.stop_replicas(replicas_to_stop) check_counts( deployment_state, version=b_version_1, @@ -1589,15 +1605,6 @@ def test_new_version_and_scale_down(mock_get_all_node_ids, mock_deployment_state replica._actor.set_done_stopping() # Now the rolling update should trigger, stopping one of the old replicas. - deployment_state.update() - check_counts(deployment_state, total=2) - check_counts( - deployment_state, - version=b_version_1, - total=2, - by_state=[(ReplicaState.RUNNING, 2)], - ) - deployment_state.update() check_counts(deployment_state, total=2) check_counts( @@ -1611,17 +1618,11 @@ def test_new_version_and_scale_down(mock_get_all_node_ids, mock_deployment_state 0 ]._actor.set_done_stopping() - deployment_state.update() - check_counts(deployment_state, total=1) - check_counts( - deployment_state, - version=b_version_1, - total=1, - by_state=[(ReplicaState.RUNNING, 1)], - ) - # Old version stopped, new version should start up. 
- deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=2) check_counts( deployment_state, @@ -1637,21 +1638,6 @@ def test_new_version_and_scale_down(mock_get_all_node_ids, mock_deployment_state ) deployment_state._replicas.get(states=[ReplicaState.STARTING])[0]._actor.set_ready() - deployment_state.update() - check_counts(deployment_state, total=2) - check_counts( - deployment_state, - version=b_version_1, - total=1, - by_state=[(ReplicaState.RUNNING, 1)], - ) - check_counts( - deployment_state, - version=b_version_2, - total=1, - by_state=[(ReplicaState.RUNNING, 1)], - ) - # New version is started, final old version replica should be stopped. deployment_state.update() check_counts(deployment_state, total=2) @@ -1671,18 +1657,13 @@ def test_new_version_and_scale_down(mock_get_all_node_ids, mock_deployment_state deployment_state._replicas.get(states=[ReplicaState.STOPPING])[ 0 ]._actor.set_done_stopping() - deployment_state.update() - check_counts(deployment_state, total=1) - check_counts( - deployment_state, - version=b_version_2, - total=1, - by_state=[(ReplicaState.RUNNING, 1)], - ) # Final old version replica is stopped, final new version replica # should be started. - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=2) check_counts( deployment_state, @@ -1714,7 +1695,10 @@ def test_new_version_and_scale_up(mock_deployment_state): updating = deployment_state.deploy(b_info_1) assert updating - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=2, by_state=[(ReplicaState.STARTING, 2)]) assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING @@ -1731,7 +1715,10 @@ def test_new_version_and_scale_up(mock_deployment_state): b_info_2, b_version_2 = deployment_info(num_replicas=10, version="2") updating = deployment_state.deploy(b_info_2) assert updating - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts( deployment_state, version=b_version_1, @@ -1748,19 +1735,6 @@ def test_new_version_and_scale_up(mock_deployment_state): # Mark the new replicas as ready. for replica in deployment_state._replicas.get(states=[ReplicaState.STARTING]): replica._actor.set_ready() - deployment_state.update() - check_counts( - deployment_state, - version=b_version_1, - total=2, - by_state=[(ReplicaState.RUNNING, 2)], - ) - check_counts( - deployment_state, - version=b_version_2, - total=8, - by_state=[(ReplicaState.RUNNING, 8)], - ) # Now that the new version replicas are up, rolling update should start. 
deployment_state.update() @@ -1781,17 +1755,11 @@ def test_new_version_and_scale_up(mock_deployment_state): for replica in deployment_state._replicas.get(states=[ReplicaState.STOPPING]): replica._actor.set_done_stopping() - deployment_state.update() - check_counts(deployment_state, total=8) - check_counts( - deployment_state, - version=b_version_2, - total=8, - by_state=[(ReplicaState.RUNNING, 8)], - ) - # The remaining replicas should be started. - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=10) check_counts( deployment_state, @@ -1828,7 +1796,10 @@ def test_health_check(mock_get_all_node_ids, mock_deployment_state): updating = deployment_state.deploy(b_info_1) assert updating - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=2, by_state=[(ReplicaState.STARTING, 2)]) assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING @@ -1860,16 +1831,16 @@ def test_health_check(mock_get_all_node_ids, mock_deployment_state): replica = deployment_state._replicas.get(states=[ReplicaState.STOPPING])[0] replica._actor.set_done_stopping() - deployment_state.update() - check_counts(deployment_state, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - assert deployment_state.curr_status_info.status == DeploymentStatus.UNHEALTHY - - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts( deployment_state, total=2, by_state=[(ReplicaState.RUNNING, 1), (ReplicaState.STARTING, 1)], ) + assert deployment_state.curr_status_info.status == DeploymentStatus.UNHEALTHY replica = deployment_state._replicas.get(states=[ReplicaState.STARTING])[0] replica._actor.set_ready() @@ -1890,7 +1861,10 @@ def test_update_while_unhealthy(mock_get_all_node_ids, mock_deployment_state): updating = deployment_state.deploy(b_info_1) assert updating - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=2, by_state=[(ReplicaState.STARTING, 2)]) assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING @@ -1922,17 +1896,16 @@ def test_update_while_unhealthy(mock_get_all_node_ids, mock_deployment_state): replica = deployment_state._replicas.get(states=[ReplicaState.STOPPING])[0] replica._actor.set_done_stopping() - deployment_state.update() - check_counts(deployment_state, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - assert deployment_state.curr_status_info.status == DeploymentStatus.UNHEALTHY - # Now deploy a new version (e.g., a rollback). This should update the status # to UPDATING and then it should eventually become healthy. 
b_info_2, b_version_2 = deployment_info(num_replicas=2, version="2") updating = deployment_state.deploy(b_info_2) assert updating - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts( deployment_state, version=b_version_1, @@ -1971,8 +1944,10 @@ def test_update_while_unhealthy(mock_get_all_node_ids, mock_deployment_state): replica._actor.set_done_stopping() # Another replica of the new version should get started. - deployment_state.update() - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts( deployment_state, version=b_version_2, @@ -1998,8 +1973,11 @@ def test_update_while_unhealthy(mock_get_all_node_ids, mock_deployment_state): def _constructor_failure_loop_two_replica(deployment_state, num_loops): """Helper function to exact constructor failure loops.""" for i in range(num_loops): - # Single replica should be created. - deployment_state.update() + # Two replicas should be created. + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=2, by_state=[(ReplicaState.STARTING, 2)]) assert deployment_state._replica_constructor_retry_counter == i * 2 @@ -2016,8 +1994,6 @@ def _constructor_failure_loop_two_replica(deployment_state, num_loops): # Once it's done stopping, replica should be removed. replica_1._actor.set_done_stopping() replica_2._actor.set_done_stopping() - deployment_state.update() - check_counts(deployment_state, total=0) @pytest.mark.parametrize("mock_deployment_state", [True, False], indirect=True) @@ -2041,7 +2017,7 @@ def test_deploy_with_consistent_constructor_failure( assert deployment_state._replica_constructor_retry_counter == 6 assert deployment_state.curr_status_info.status == DeploymentStatus.UNHEALTHY - check_counts(deployment_state, total=0) + check_counts(deployment_state, total=2) assert deployment_state.curr_status_info.message != "" @@ -2074,7 +2050,10 @@ def test_deploy_with_partial_constructor_failure( _constructor_failure_loop_two_replica(deployment_state, 2) assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=2, by_state=[(ReplicaState.STARTING, 2)]) assert deployment_state._replica_constructor_retry_counter == 4 assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING @@ -2095,12 +2074,11 @@ def test_deploy_with_partial_constructor_failure( check_counts(deployment_state, total=2, by_state=[(ReplicaState.STOPPING, 1)]) replica_2._actor.set_done_stopping() - deployment_state.update() - check_counts(deployment_state, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - check_counts(deployment_state, total=1, by_state=[(ReplicaState.STARTING, 0)]) - # New update cycle should spawn new replica after previous one is removed - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + 
{deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=2, by_state=[(ReplicaState.RUNNING, 1)]) check_counts(deployment_state, total=2, by_state=[(ReplicaState.STARTING, 1)]) @@ -2123,11 +2101,10 @@ def test_deploy_with_partial_constructor_failure( starting_replica = deployment_state._replicas.get(states=[ReplicaState.STOPPING])[0] starting_replica._actor.set_done_stopping() - deployment_state.update() - check_counts(deployment_state, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - check_counts(deployment_state, total=1, by_state=[(ReplicaState.STARTING, 0)]) - - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=2, by_state=[(ReplicaState.RUNNING, 1)]) check_counts(deployment_state, total=2, by_state=[(ReplicaState.STARTING, 1)]) @@ -2170,7 +2147,10 @@ def test_deploy_with_transient_constructor_failure( assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING # Let both replicas succeed in last try. - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=2, by_state=[(ReplicaState.STARTING, 2)]) assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING @@ -2216,7 +2196,10 @@ def test_exponential_backoff(mock_get_all_node_ids, mock_deployment_state): # Set new replicas to fail consecutively check_counts(deployment_state, total=0) # No replicas - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) last_retry = timer.time() # This should be time at which replicas were retried check_counts(deployment_state, total=2) # Two new replicas replica_1 = deployment_state._replicas.get()[0] @@ -2251,7 +2234,12 @@ def mock_deployment_state_manager_full(request) -> Tuple[DeploymentStateManager, with patch( "ray.serve._private.deployment_state.ActorReplicaWrapper", new=MockReplicaActorWrapper, - ), patch("time.time", new=timer.time), patch( + ), patch( + "ray.serve._private.deployment_scheduler.DeploymentScheduler", + new=MockDeploymentScheduler, + ), patch( + "time.time", new=timer.time + ), patch( "ray.serve._private.long_poll.LongPollHost" ) as mock_long_poll, patch( "ray.serve._private.deployment_state.GcsClient" @@ -2452,8 +2440,6 @@ def test_recover_during_rolling_update( by_state=[(ReplicaState.STOPPING, 1)], ) new_mocked_replica._actor.set_done_stopping() - new_deployment_state_manager.update() - check_counts(new_deployment_state, total=0) # Now that the replica of version "1" has been stopped, a new # replica of version "2" should be started @@ -2484,7 +2470,12 @@ def mock_deployment_state_manager(request) -> Tuple[DeploymentStateManager, Mock with patch( "ray.serve._private.deployment_state.ActorReplicaWrapper", new=MockReplicaActorWrapper, - ), patch("time.time", new=timer.time), patch( + ), patch( + "ray.serve._private.deployment_scheduler.DeploymentScheduler", + new=MockDeploymentScheduler, + ), patch( + "time.time", new=timer.time + ), patch( "ray.serve._private.long_poll.LongPollHost" ) as mock_long_poll: @@ -2497,15 +2488,8 @@ def 
mock_deployment_state_manager(request) -> Tuple[DeploymentStateManager, Mock mock_long_poll, all_current_actor_names, ) - deployment_state = DeploymentState( - "test", - "name", - True, - mock_long_poll, - deployment_state_manager._save_checkpoint_func, - ) - yield deployment_state_manager, deployment_state, timer + yield deployment_state_manager, timer ray.shutdown() @@ -2515,19 +2499,19 @@ def test_shutdown(mock_deployment_state_manager, is_driver_deployment): Test that shutdown waits for all deployments to be deleted and they are force-killed without a grace period. """ - deployment_state_manager, deployment_state, timer = mock_deployment_state_manager + deployment_state_manager, timer = mock_deployment_state_manager - tag = "test" + deployment_name = "test" grace_period_s = 10 b_info_1, b_version_1 = deployment_info( graceful_shutdown_timeout_s=grace_period_s, is_driver_deployment=is_driver_deployment, ) - updating = deployment_state.deploy(b_info_1) + updating = deployment_state_manager.deploy(deployment_name, b_info_1) assert updating - deployment_state_manager._deployment_states[tag] = deployment_state + deployment_state = deployment_state_manager._deployment_states[deployment_name] # Single replica should be created. deployment_state_manager.update() @@ -2558,29 +2542,7 @@ def test_shutdown(mock_deployment_state_manager, is_driver_deployment): assert len(deployment_state_manager.get_deployment_statuses()) == 0 -def test_stopping_replicas_ranking(): - @dataclass - class MockReplica: - actor_node_id: str - - def compare(before, after): - before_replicas = [MockReplica(item) for item in before] - after_replicas = [MockReplica(item) for item in after] - result_replicas = rank_replicas_for_stopping(before_replicas) - assert result_replicas == after_replicas - - compare( - [None, 1, None], [None, None, 1] - ) # replicas not allocated should be stopped first - compare( - [3, 3, 3, 2, 2, 1], [1, 2, 2, 3, 3, 3] - ) # prefer to stop dangling replicas first - compare([2, 2, 3, 3], [2, 2, 3, 3]) # if equal, ordering should be kept - - -@pytest.mark.parametrize("mock_deployment_state", [True, False], indirect=True) -@patch.object(DriverDeploymentState, "_get_all_node_ids") -def test_resource_requirements_none(mock_get_all_node_ids, mock_deployment_state): +def test_resource_requirements_none(): """Ensure resource_requirements doesn't break if a requirement is None""" class FakeActor: @@ -2596,6 +2558,41 @@ class FakeActor: replica.resource_requirements() +@pytest.mark.parametrize("mock_deployment_state", [True], indirect=True) +@patch.object(DriverDeploymentState, "_get_all_node_ids") +def test_cancel_extra_replicas_for_driver_deployment( + mock_get_all_node_ids, mock_deployment_state +): + """Test to make sure the driver deployment state + can cancel extra starting replicas. 
+ """ + + deployment_state, timer = mock_deployment_state + mock_get_all_node_ids.return_value = [("0", "0"), ("1", "1")] + + b_info_1, b_version_1 = deployment_info() + updating = deployment_state.deploy(b_info_1) + assert updating + assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING + + deployment_state_update_result = deployment_state.update() + # 1 node dies, now the cluster only has 1 node + mock_get_all_node_ids.return_value = [("0", "0")] + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) + check_counts(deployment_state, total=2, by_state=[(ReplicaState.STARTING, 2)]) + # only 1 replica is scheduled successfully, the other is PENDING_ALLOCATION + deployment_state._replicas.get(states=[ReplicaState.STARTING])[0]._actor.set_ready() + # the other replica should be cancelled + deployment_state.update() + check_counts( + deployment_state, + total=2, + by_state=[(ReplicaState.RUNNING, 1), (ReplicaState.STOPPING, 1)], + ) + + @pytest.mark.parametrize("mock_deployment_state", [True], indirect=True) @patch.object(DriverDeploymentState, "_get_all_node_ids") def test_add_and_remove_nodes_for_driver_deployment( @@ -2610,12 +2607,18 @@ def test_add_and_remove_nodes_for_driver_deployment( assert updating assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=1, by_state=[(ReplicaState.STARTING, 1)]) # Add a node when previous one is in STARTING state mock_get_all_node_ids.return_value = [("0", "0"), ("1", "1")] - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts(deployment_state, total=2, by_state=[(ReplicaState.STARTING, 2)]) for replica in deployment_state._replicas.get(states=[ReplicaState.STARTING]): replica._actor.set_ready() @@ -2624,7 +2627,10 @@ def test_add_and_remove_nodes_for_driver_deployment( # Add another two nodes mock_get_all_node_ids.return_value = [(str(i), str(i)) for i in range(4)] - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts( deployment_state, total=4, @@ -2642,24 +2648,28 @@ def test_add_and_remove_nodes_for_driver_deployment( deployment_state._replicas.get(states=[ReplicaState.RUNNING])[ 3 ]._actor.set_unhealthy() - deployment_state.update() + deployment_state_update_result = deployment_state.update() + deployment_state._deployment_scheduler.schedule( + {deployment_state._name: deployment_state_update_result.upscale}, {} + ) check_counts( deployment_state, - total=5, + total=4, by_state=[ (ReplicaState.RUNNING, 3), - (ReplicaState.STARTING, 1), (ReplicaState.STOPPING, 1), ], ) - # Mark stopped replica finish stopping step and starting replica - # finish starting step - for replica in deployment_state._replicas.get(states=[ReplicaState.STARTING]): - replica._actor.set_ready() + # Mark stopped replica finish stopping step. 
for replica in deployment_state._replicas.get(states=[ReplicaState.STOPPING]): replica._actor.set_done_stopping() deployment_state.update() + + # Make starting replica finish starting step. + for replica in deployment_state._replicas.get(states=[ReplicaState.STARTING]): + replica._actor.set_ready() + deployment_state.update() check_counts(deployment_state, total=4, by_state=[(ReplicaState.RUNNING, 4)]) @@ -2729,9 +2739,7 @@ def test_get_active_node_ids(mock_get_all_node_ids, mock_deployment_state_manage ) mocked_replicas = deployment_state._replicas.get() for idx, mocked_replica in enumerate(mocked_replicas): - mocked_replica._actor.set_scheduling_strategy( - NodeAffinitySchedulingStrategy(node_id=node_ids[idx], soft=True) - ) + mocked_replica._actor.set_node_id(node_ids[idx]) assert deployment_state.get_active_node_ids() == set(node_ids) assert deployment_state_manager.get_active_node_ids() == set(node_ids) @@ -2797,9 +2805,7 @@ def test_get_active_node_ids_none( ) mocked_replicas = deployment_state._replicas.get() for idx, mocked_replica in enumerate(mocked_replicas): - mocked_replica._actor.set_scheduling_strategy( - NodeAffinitySchedulingStrategy(node_id=node_ids[idx], soft=True) - ) + mocked_replica._actor.set_node_id(node_ids[idx]) assert deployment_state.get_active_node_ids() == set(node_ids) assert deployment_state_manager.get_active_node_ids() == set(node_ids)
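
The edited tests above repeatedly drive the same two-step handshake: DeploymentState.update() now returns a result object (with upscale, downscale, and deleted fields) instead of a (deleted, _) tuple, and the caller forwards those fields to the scheduler's schedule() call, which reports back the names of replicas to stop. The sketch below is a minimal, self-contained illustration of that flow, not Ray Serve code: FakeUpdateResult and FakeScheduler are invented stand-ins, "my_deployment" and the replica names are made up, and an upscale is simplified to a list of replica names rather than full scheduling requests; the real scheduler also decides node placement.

from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Set


@dataclass
class FakeUpdateResult:
    # Mirrors the fields the tests read off the result of update():
    # .deleted, .upscale (replicas to start) and .downscale (how many to stop).
    deleted: bool = False
    upscale: List[str] = field(default_factory=list)
    downscale: int = 0


class FakeScheduler:
    # Invented stand-in for the centralized scheduler the tests patch in
    # (MockDeploymentScheduler plays this role); the real one also picks nodes.

    def __init__(self):
        self._running = defaultdict(list)  # {deployment_name: [replica_name]}

    def schedule(
        self,
        upscales: Dict[str, List[str]],
        downscales: Dict[str, int],
    ) -> Dict[str, Set[str]]:
        # Launch the requested replicas.
        for deployment, replicas in upscales.items():
            self._running[deployment].extend(replicas)
        # Choose which replicas to stop and hand their names back to the caller.
        replicas_to_stop = defaultdict(set)
        for deployment, num_to_stop in downscales.items():
            for _ in range(min(num_to_stop, len(self._running[deployment]))):
                replicas_to_stop[deployment].add(self._running[deployment].pop())
        return replicas_to_stop


scheduler = FakeScheduler()

# Upscale: update() asked for two new replicas, so they are passed to schedule().
result = FakeUpdateResult(upscale=["replica-1", "replica-2"])
scheduler.schedule({"my_deployment": result.upscale}, {})

# Downscale: schedule() chooses the replica(s) to stop; the caller would then
# pass the returned names to deployment_state.stop_replicas(...).
result = FakeUpdateResult(downscale=1)
print(scheduler.schedule({}, {"my_deployment": result.downscale})["my_deployment"])

This is why each former bare deployment_state.update() call in the tests is now followed by an explicit _deployment_scheduler.schedule(...) call, and each downscale path by a stop_replicas(...) call on the names the scheduler returns.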
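
The driver-deployment tests at the end (test_cancel_extra_replicas_for_driver_deployment and test_add_and_remove_nodes_for_driver_deployment) exercise the policy of keeping exactly one replica per live node and cancelling any surplus once a node disappears. The helper below is an illustrative reconstruction of that reconciliation rule under assumed names, not Ray Serve code; it also simplifies by assuming every replica already has a node assignment, whereas in the test the extra replica is still pending allocation when it gets cancelled.

from typing import Dict, List, Set, Tuple


def reconcile_driver_deployment(
    target_nodes: List[str],
    replica_to_node: Dict[str, str],
) -> Tuple[List[str], Set[str]]:
    # Keep exactly one replica per live node; anything else is surplus.
    live = set(target_nodes)
    covered: Set[str] = set()
    to_cancel: Set[str] = set()
    for replica, node in replica_to_node.items():
        if node not in live or node in covered:
            # The node went away, or it already has a replica: cancel this one.
            to_cancel.add(replica)
        else:
            covered.add(node)
    # Nodes that still have no replica need a new one scheduled.
    to_start = [node for node in target_nodes if node not in covered]
    return to_start, to_cancel


# Mirrors the cancel-extra-replicas scenario: replicas were requested for
# nodes "0" and "1", then node "1" died, so its replica should be cancelled.
print(reconcile_driver_deployment(["0"], {"r0": "0", "r1": "1"}))  # ([], {'r1'})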