diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index ecdfb227729..a4308c216bd 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -2047,7 +2047,7 @@ def _update_cluster_status( def _refresh_cluster_record( cluster_name: str, *, - force_refresh: bool = False, + force_refresh_statuses: Optional[Set[status_lib.ClusterStatus]] = None, acquire_per_cluster_status_lock: bool = True ) -> Optional[Dict[str, Any]]: """Refresh the cluster, and return the possibly updated record. @@ -2058,8 +2058,10 @@ def _refresh_cluster_record( Args: cluster_name: The name of the cluster. - force_refresh: if True, refresh the cluster status even if it may be - skipped. Otherwise (the default), only refresh if the cluster: + force_refresh_statuses: if specified, refresh the cluster if it has one of + the specified statuses. Additionally, clusters satisfying the + following conditions will always be refreshed no matter the + argument is specified or not: 1. is a spot cluster, or 2. is a non-spot cluster, is not STOPPED, and autostop is set. acquire_per_cluster_status_lock: Whether to acquire the per-cluster lock @@ -2089,7 +2091,9 @@ def _refresh_cluster_record( use_spot = handle.launched_resources.use_spot has_autostop = (record['status'] != status_lib.ClusterStatus.STOPPED and record['autostop'] >= 0) - if force_refresh or has_autostop or use_spot: + force_refresh_for_cluster = (force_refresh_statuses is not None and + record['status'] in force_refresh_statuses) + if force_refresh_for_cluster or has_autostop or use_spot: record = _update_cluster_status( cluster_name, acquire_per_cluster_status_lock=acquire_per_cluster_status_lock) @@ -2100,7 +2104,7 @@ def _refresh_cluster_record( def refresh_cluster_status_handle( cluster_name: str, *, - force_refresh: bool = False, + force_refresh_statuses: Optional[Set[status_lib.ClusterStatus]] = None, acquire_per_cluster_status_lock: bool = True, ) -> Tuple[Optional[status_lib.ClusterStatus], Optional[backends.ResourceHandle]]: @@ -2112,7 +2116,7 @@ def refresh_cluster_status_handle( """ record = _refresh_cluster_record( cluster_name, - force_refresh=force_refresh, + force_refresh_statuses=force_refresh_statuses, acquire_per_cluster_status_lock=acquire_per_cluster_status_lock) if record is None: return None, None @@ -2331,7 +2335,7 @@ def _refresh_cluster(cluster_name): try: record = _refresh_cluster_record( cluster_name, - force_refresh=True, + force_refresh_statuses=set(status_lib.ClusterStatus), acquire_per_cluster_status_lock=True) except (exceptions.ClusterStatusFetchingError, exceptions.CloudUserIdentityError, diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index d589487cffa..63d33a9a442 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -3758,11 +3758,29 @@ def _check_existing_cluster( exceptions.InvalidClusterNameError: If the cluster name is invalid. # TODO(zhwu): complete the list of exceptions. """ - handle_before_refresh = global_user_state.get_handle_from_cluster_name( - cluster_name) + record = global_user_state.get_cluster_from_name(cluster_name) + handle_before_refresh = None if record is None else record['handle'] + status_before_refresh = None if record is None else record['status'] + prev_cluster_status, handle = ( backend_utils.refresh_cluster_status_handle( - cluster_name, acquire_per_cluster_status_lock=False)) + cluster_name, + # We force refresh for the init status to determine the actual + # state of a previous cluster in INIT state. + # + # This is important for the case, where an existing cluster is + # transitioned into INIT state due to key interruption during + # launching, with the following steps: + # (1) launch, after answering prompt immediately ctrl-c; + # (2) launch again. + # If we don't refresh the state of the cluster and reset it back + # to STOPPED, our failover logic will consider it as an abnormal + # cluster after hitting resources capacity limit on the cloud, + # and will start failover. This is not desired, because the user + # may want to keep the data on the disk of that cluster. + force_refresh_statuses={status_lib.ClusterStatus.INIT}, + acquire_per_cluster_status_lock=False, + )) if prev_cluster_status is not None: assert handle is not None # Cluster already exists. @@ -3784,10 +3802,6 @@ def _check_existing_cluster( task_cloud.check_cluster_name_is_valid(cluster_name) if to_provision is None: - logger.info( - f'The cluster {cluster_name!r} was autodowned or manually ' - 'terminated on the cloud console. Using the same resources ' - 'as the previously terminated one to provision a new cluster.') # The cluster is recently terminated either by autostop or manually # terminated on the cloud. We should use the previously terminated # resources to provision the cluster. @@ -3796,6 +3810,16 @@ def _check_existing_cluster( f'Trying to launch cluster {cluster_name!r} recently ' 'terminated on the cloud, but the handle is not a ' f'CloudVmRayResourceHandle ({handle_before_refresh}).') + status_before_refresh_str = None + if status_before_refresh is not None: + status_before_refresh_str = status_before_refresh.value + + logger.info( + f'The cluster {cluster_name!r} (status: ' + f'{status_before_refresh_str}) was not found on the cloud: it ' + 'may be autodowned, manually terminated, or its launch never ' + 'succeeded. Provisioning a new cluster by using the same ' + 'resources as its original launch.') to_provision = handle_before_refresh.launched_resources self.check_resources_fit_cluster(handle_before_refresh, task) diff --git a/sky/spot/controller.py b/sky/spot/controller.py index 1f4df44585e..268c8cee3dd 100644 --- a/sky/spot/controller.py +++ b/sky/spot/controller.py @@ -201,7 +201,8 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: # determine whether the cluster is preempted. (cluster_status, handle) = backend_utils.refresh_cluster_status_handle( - cluster_name, force_refresh=True) + cluster_name, + force_refresh_statuses=set(status_lib.ClusterStatus)) if cluster_status != status_lib.ClusterStatus.UP: # The cluster is (partially) preempted. It can be down, INIT diff --git a/sky/spot/recovery_strategy.py b/sky/spot/recovery_strategy.py index fe2f05eda07..c2b1caa31ba 100644 --- a/sky/spot/recovery_strategy.py +++ b/sky/spot/recovery_strategy.py @@ -182,7 +182,8 @@ def _wait_until_job_starts_on_cluster(self) -> Optional[float]: try: cluster_status, _ = ( backend_utils.refresh_cluster_status_handle( - self.cluster_name, force_refresh=True)) + self.cluster_name, + force_refresh_statuses=set(status_lib.ClusterStatus))) except Exception as e: # pylint: disable=broad-except # If any unexpected error happens, retry the job checking # loop. diff --git a/sky/spot/spot_utils.py b/sky/spot/spot_utils.py index 678b6db42dd..f9ce806b57d 100644 --- a/sky/spot/spot_utils.py +++ b/sky/spot/spot_utils.py @@ -749,13 +749,14 @@ def is_spot_controller_up( identity. """ try: - # Set force_refresh=False to make sure the refresh only happens when the - # controller is INIT/UP. This optimization avoids unnecessary costly - # refresh when the controller is already stopped. This optimization is - # based on the assumption that the user will not start the controller - # manually from the cloud console. + # Set force_refresh_statuses=None to make sure the refresh only happens + # when the controller is INIT/UP (triggered in these statuses as the + # autostop is always set for spot controller). This optimization avoids + # unnecessary costly refresh when the controller is already stopped. + # This optimization is based on the assumption that the user will not + # start the controller manually from the cloud console. controller_status, handle = backend_utils.refresh_cluster_status_handle( - SPOT_CONTROLLER_NAME, force_refresh=False) + SPOT_CONTROLLER_NAME, force_refresh_statuses=None) except exceptions.ClusterStatusFetchingError as e: # We do not catch the exceptions related to the cluster owner identity # mismatch, please refer to the comment in diff --git a/tests/test_spot.py b/tests/test_spot.py index 9c3b89a747f..d0f36dc52a0 100644 --- a/tests/test_spot.py +++ b/tests/test_spot.py @@ -103,7 +103,7 @@ def test_down_spot_controller(self, _mock_cluster_state, monkeypatch): def mock_cluster_refresh_up( cluster_name: str, *, - force_refresh: bool = False, + force_refresh_statuses: bool = False, acquire_per_cluster_status_lock: bool = True, ): record = global_user_state.get_cluster_from_name(cluster_name)