[Spot] Let cancel interrupt the spot job #1414

Merged (12 commits) on Nov 18, 2022
Changes from 9 commits
239 changes: 129 additions & 110 deletions sky/spot/controller.py
@@ -6,6 +6,7 @@

import colorama
import filelock
import ray
Collaborator:

Out of curiosity - any particular reason to use Ray? Could this have been implemented with multiprocessing?

Collaborator Author:

Ah, there is no particular reason for choosing ray, other than that I am more familiar with the busy loop using ray.wait, and our VMs already have a ray cluster running in the background.

Do you think it would be better to implement it with multiprocessing?

Collaborator:

I think multiprocessing might be faster, but ray is also fine if it's faster for us to implement.

Collaborator Author (@Michaelvll, Nov 17, 2022):

Just tried out multiprocessing, concurrent.futures, and asyncio. They all support running the process in parallel, but each has drawbacks:

  1. multiprocessing.Process: We can start the process and check its liveness with is_alive, and we can cancel it with os.kill(p.pid, signal.SIGINT). However, we cannot easily catch exceptions raised inside the process from the main process; additional code is needed to do so, as mentioned here.
  2. concurrent.futures: We can start the process with Executor.submit and catch exceptions from the main process, but the future object has no way to stop an already running process, unlike ray.cancel.
  3. asyncio.create_task: Based on their doc, cancellation of the task is not guaranteed, but we need that guarantee to cancel the _controller_run function when the CANCEL signal is received.

Based on the drawbacks above, I think it is fine to use ray to handle the async execution and cancellation, as done in the code (see the sketch below). Wdyt @romilbhardwaj?
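
A minimal, self-contained sketch of the ray.wait/ray.cancel pattern under discussion (illustrative only, not code from this PR; long_running_job and cancel_requested are hypothetical stand-ins):

```python
import time

import ray


@ray.remote(num_cpus=0)
def long_running_job():
    """Stand-in for the controller's monitoring loop."""
    while True:
        time.sleep(1)


def cancel_requested() -> bool:
    """Stand-in for checking the user's CANCEL signal file."""
    return False


ray.init()
task = long_running_job.remote()
ready, _ = ray.wait([task], timeout=0)
while not ready:
    if cancel_requested():
        # ray.cancel interrupts the running task; ray.get then surfaces
        # the cancellation as an exception.
        ray.cancel(task)
        try:
            ray.get(task)
        except (ray.exceptions.RayTaskError,
                ray.exceptions.TaskCancelledError):
            pass
        break
    ready, _ = ray.wait([task], timeout=1)
```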

Collaborator:

Ah, for multiprocessing I was thinking of using multiprocessing.Pool since it provides nice exception handling, but it turns out it's not easy to cancel running processes in it.

I think it's fine to use ray for now. Thanks for the extensive investigation into the different methods!
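
For comparison, a rough sketch of the multiprocessing.Process approach from drawback 1 above (POSIX-only; worker is a hypothetical stand-in), showing that cancellation is straightforward but exceptions in the child need extra plumbing:

```python
import multiprocessing
import os
import signal
import time


def worker():
    """Stand-in for the controller loop; exceptions raised here are not
    automatically propagated to the parent process."""
    time.sleep(60)


if __name__ == '__main__':
    p = multiprocessing.Process(target=worker)
    p.start()
    print(p.is_alive())            # liveness check
    os.kill(p.pid, signal.SIGINT)  # cancellation via signal (POSIX)
    p.join()
    # Only the exit code is visible here; surfacing the child's exception
    # would require extra plumbing such as a multiprocessing.Queue or Pipe.
    print(p.exitcode)
```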


import sky
from sky import exceptions
@@ -28,9 +29,9 @@ class SpotController:

def __init__(self, job_id: int, task_yaml: str,
retry_until_up: bool) -> None:
self._job_id = job_id
self._task_name = pathlib.Path(task_yaml).stem
self._task = sky.Task.from_yaml(task_yaml)
self.job_id = job_id
self.task_name = pathlib.Path(task_yaml).stem
self.task = sky.Task.from_yaml(task_yaml)

self._retry_until_up = retry_until_up
# TODO(zhwu): this assumes the specific backend.
@@ -39,145 +40,73 @@ def __init__(self, job_id: int, task_yaml: str,
# Add a unique identifier to the task environment variables, so that
# the user can have the same id for multiple recoveries.
# Example value: sky-2022-10-04-22-46-52-467694_id-17
task_envs = self._task.envs or {}
task_envs = self.task.envs or {}
job_id_env_var = common_utils.get_global_job_id(
self.backend.run_timestamp, 'spot', self._job_id)
self.backend.run_timestamp, 'spot', self.job_id)
task_envs[constants.JOB_ID_ENV_VAR] = job_id_env_var
self._task.set_envs(task_envs)
self.task.set_envs(task_envs)

spot_state.set_submitted(
self._job_id,
self._task_name,
self.job_id,
self.task_name,
self.backend.run_timestamp,
resources_str=backend_utils.get_task_resources_str(self._task))
resources_str=backend_utils.get_task_resources_str(self.task))
logger.info(f'Submitted spot job; SKYPILOT_JOB_ID: {job_id_env_var}')
self._cluster_name = spot_utils.generate_spot_cluster_name(
self._task_name, self._job_id)
self._strategy_executor = recovery_strategy.StrategyExecutor.make(
self._cluster_name, self.backend, self._task, retry_until_up,
self._handle_signal)

def _run(self):
"""Busy loop monitoring spot cluster status and handling recovery."""
logger.info(f'Started monitoring spot task {self._task_name} '
f'(id: {self._job_id})')
spot_state.set_starting(self._job_id)
start_at = self._strategy_executor.launch()

spot_state.set_started(self._job_id, start_time=start_at)
while True:
time.sleep(spot_utils.JOB_STATUS_CHECK_GAP_SECONDS)
# Handle the signal if it is sent by the user.
self._handle_signal()

# Check the network connection to avoid false alarm for job failure.
# Network glitch was observed even in the VM.
try:
backend_utils.check_network_connection()
except exceptions.NetworkError:
logger.info(
'Network is not available. Retrying again in '
f'{spot_utils.JOB_STATUS_CHECK_GAP_SECONDS} seconds.')
continue

# NOTE: we do not check cluster status first because race condition
# can occur, i.e. cluster can be down during the job status check.
job_status = spot_utils.get_job_status(self.backend,
self._cluster_name)

if job_status is not None and not job_status.is_terminal():
need_recovery = False
if self._task.num_nodes > 1:
# Check the cluster status for multi-node jobs, since the
# job may not be set to FAILED immediately when only some
# of the nodes are preempted.
(cluster_status,
handle) = backend_utils.refresh_cluster_status_handle(
self._cluster_name, force_refresh=True)
if cluster_status != global_user_state.ClusterStatus.UP:
# recover the cluster if it is not up.
logger.info(f'Cluster status {cluster_status.value}. '
'Recovering...')
need_recovery = True
if not need_recovery:
# The job and cluster are healthy, continue to monitor the
# job status.
continue

if job_status == job_lib.JobStatus.SUCCEEDED:
end_time = spot_utils.get_job_timestamp(self.backend,
self._cluster_name,
get_end_time=True)
# The job is done.
spot_state.set_succeeded(self._job_id, end_time=end_time)
break

if job_status == job_lib.JobStatus.FAILED:
# Check the status of the spot cluster. If it is not UP,
# the cluster is preempted.
(cluster_status,
handle) = backend_utils.refresh_cluster_status_handle(
self._cluster_name, force_refresh=True)
if cluster_status == global_user_state.ClusterStatus.UP:
# The user code has probably crashed.
end_time = spot_utils.get_job_timestamp(self.backend,
self._cluster_name,
get_end_time=True)
logger.info(
'The user job failed. Please check the logs below.\n'
f'== Logs of the user job (ID: {self._job_id}) ==\n')
self.backend.tail_logs(handle,
None,
spot_job_id=self._job_id)
logger.info(f'\n== End of logs (ID: {self._job_id}) ==')
spot_state.set_failed(
self._job_id,
failure_type=spot_state.SpotStatus.FAILED,
end_time=end_time)
break
# cluster can be down, INIT or STOPPED, based on the interruption
# behavior of the cloud.
# Failed to connect to the cluster or the cluster is partially down.
# job_status is None or job_status == job_lib.JobStatus.FAILED
logger.info('The cluster is preempted.')
spot_state.set_recovering(self._job_id)
recovered_time = self._strategy_executor.recover()
spot_state.set_recovered(self._job_id,
recovered_time=recovered_time)
self.cluster_name = spot_utils.generate_spot_cluster_name(
self.task_name, self.job_id)
self.strategy_executor = recovery_strategy.StrategyExecutor.make(
self.cluster_name, self.backend, self.task, retry_until_up)

def start(self):
"""Start the controller."""
try:
self._run()
self._handle_signal()
controller_task = _controller_run.remote(self)
# Signal can interrupt the underlying controller process.
ready, _ = ray.wait([controller_task], timeout=0)
while not ready:
try:
self._handle_signal()
except exceptions.SpotUserCancelledError as e:
logger.info('Cancelling...')
try:
ray.cancel(controller_task)
ray.get(controller_task)
except ray.exceptions.RayTaskError:
pass
raise e
ready, _ = ray.wait([controller_task], timeout=1)
# Need this to get the exception from the controller task.
ray.get(controller_task)
except exceptions.SpotUserCancelledError as e:
logger.info(e)
spot_state.set_cancelled(self._job_id)
spot_state.set_cancelled(self.job_id)
except exceptions.ResourcesUnavailableError as e:
logger.error(f'Resources unavailable: {colorama.Fore.RED}{e}'
f'{colorama.Style.RESET_ALL}')
spot_state.set_failed(
self._job_id,
self.job_id,
failure_type=spot_state.SpotStatus.FAILED_NO_RESOURCE)
except (Exception, SystemExit) as e: # pylint: disable=broad-except
logger.error(traceback.format_exc())
logger.error(f'Unexpected error occurred: {type(e).__name__}: {e}')
finally:
self._strategy_executor.terminate_cluster()
job_status = spot_state.get_status(self._job_id)
self.strategy_executor.terminate_cluster()
job_status = spot_state.get_status(self.job_id)
# The job can be non-terminal if the controller exited abnormally,
# e.g. failed to launch cluster after reaching the MAX_RETRY.
if not job_status.is_terminal():
spot_state.set_failed(
self._job_id,
self.job_id,
failure_type=spot_state.SpotStatus.FAILED_CONTROLLER)

# Clean up Storages with persistent=False.
self.backend.teardown_ephemeral_storage(self._task)
self.backend.teardown_ephemeral_storage(self.task)

def _handle_signal(self):
"""Handle the signal if the user sent it."""
signal_file = pathlib.Path(
spot_utils.SIGNAL_FILE_PREFIX.format(self._job_id))
spot_utils.SIGNAL_FILE_PREFIX.format(self.job_id))
signal = None
if signal_file.exists():
# Filelock is needed to prevent race condition with concurrent
@@ -200,7 +129,97 @@ def _handle_signal(self):
raise RuntimeError(f'Unknown SkyPilot signal received: {signal.value}.')


@ray.remote(num_cpus=0)
def _controller_run(spot_controller: SpotController):
"""Busy loop monitoring spot cluster status and handling recovery."""
logger.info(f'Started monitoring spot task {spot_controller.task_name} '
f'(id: {spot_controller.job_id})')
spot_state.set_starting(spot_controller.job_id)
start_at = spot_controller.strategy_executor.launch()

spot_state.set_started(spot_controller.job_id, start_time=start_at)
while True:
time.sleep(spot_utils.JOB_STATUS_CHECK_GAP_SECONDS)

# Check the network connection to avoid false alarm for job failure.
# Network glitch was observed even in the VM.
try:
backend_utils.check_network_connection()
except exceptions.NetworkError:
logger.info('Network is not available. Retrying again in '
f'{spot_utils.JOB_STATUS_CHECK_GAP_SECONDS} seconds.')
continue

# NOTE: we do not check cluster status first because race condition
# can occur, i.e. cluster can be down during the job status check.
job_status = spot_utils.get_job_status(spot_controller.backend,
spot_controller.cluster_name)

if job_status is not None and not job_status.is_terminal():
need_recovery = False
if spot_controller.task.num_nodes > 1:
# Check the cluster status for multi-node jobs, since the
# job may not be set to FAILED immediately when only some
# of the nodes are preempted.
(cluster_status,
handle) = backend_utils.refresh_cluster_status_handle(
spot_controller.cluster_name, force_refresh=True)
if cluster_status != global_user_state.ClusterStatus.UP:
# recover the cluster if it is not up.
logger.info(f'Cluster status {cluster_status.value}. '
'Recovering...')
need_recovery = True
if not need_recovery:
# The job and cluster are healthy, continue to monitor the
# job status.
continue

if job_status == job_lib.JobStatus.SUCCEEDED:
end_time = spot_utils.get_job_timestamp(
spot_controller.backend,
spot_controller.cluster_name,
get_end_time=True)
# The job is done.
spot_state.set_succeeded(spot_controller.job_id, end_time=end_time)
break

if job_status == job_lib.JobStatus.FAILED:
# Check the status of the spot cluster. If it is not UP,
# the cluster is preempted.
(cluster_status,
handle) = backend_utils.refresh_cluster_status_handle(
spot_controller.cluster_name, force_refresh=True)
if cluster_status == global_user_state.ClusterStatus.UP:
# The user code has probably crashed.
end_time = spot_utils.get_job_timestamp(
spot_controller.backend,
spot_controller.cluster_name,
get_end_time=True)
logger.info(
'The user job failed. Please check the logs below.\n'
'== Logs of the user job (ID: '
f'{spot_controller.job_id}) ==\n')
spot_controller.backend.tail_logs(
handle, None, spot_job_id=spot_controller.job_id)
logger.info(
f'\n== End of logs (ID: {spot_controller.job_id}) ==')
spot_state.set_failed(spot_controller.job_id,
failure_type=spot_state.SpotStatus.FAILED,
end_time=end_time)
break
# cluster can be down, INIT or STOPPED, based on the interruption
# behavior of the cloud.
# Failed to connect to the cluster or the cluster is partially down.
# job_status is None or job_status == job_lib.JobStatus.FAILED
logger.info('The cluster is preempted.')
spot_state.set_recovering(spot_controller.job_id)
recovered_time = spot_controller.strategy_executor.recover()
spot_state.set_recovered(spot_controller.job_id,
recovered_time=recovered_time)


if __name__ == '__main__':
ray.init('auto')
Collaborator:

What do you think about using ray.init() here instead of ray.init('auto')?

Pros:

  • Provides isolation: if something goes wrong in the ray cluster because of this script, it won't affect the main ray cluster.
  • Removes dependency on having ray running, if we ever need to run this controller in isolation in the future.

Cons:

  • Added resource cost/overheads

Collaborator Author (@Michaelvll, Nov 17, 2022):

Good point! I think the second Pro holds, but regarding the first Pro, ray has changed the behavior of ray.init() (since ray>=2.0) to the following (doc):

This will autodetect an existing Ray cluster or start a new Ray instance if no existing cluster is found

That is to say, using ray.init() instead of ray.init('auto') will have neither the Con nor the first Pro mentioned above, since it simply attaches to the existing cluster when one is running.
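
A small sketch (not from the PR) of the behavioral difference, assuming ray >= 2.0:

```python
import ray

# ray.init('auto') requires an already running cluster and raises
# ConnectionError if none is reachable.
# ray.init() first tries to attach to an existing local cluster and only
# starts a fresh single-node instance if none is found.
try:
    ray.init(address='auto')  # equivalent to ray.init('auto')
except ConnectionError:
    ray.init()                # falls back to starting a local instance
```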

parser = argparse.ArgumentParser()
parser.add_argument('--job-id',
required=True,
16 changes: 4 additions & 12 deletions sky/spot/recovery_strategy.py
@@ -1,7 +1,7 @@
"""The strategy to handle launching/recovery/termination of spot clusters."""
import time
import typing
from typing import Callable, Optional
from typing import Optional

import sky
from sky import exceptions
@@ -34,24 +34,20 @@ class StrategyExecutor:
RETRY_INIT_GAP_SECONDS = 60

def __init__(self, cluster_name: str, backend: 'backends.Backend',
task: 'task_lib.Task', retry_until_up: bool,
signal_handler: Callable) -> None:
task: 'task_lib.Task', retry_until_up: bool) -> None:
"""Initialize the strategy executor.

Args:
cluster_name: The name of the cluster.
backend: The backend to use. Only CloudVMRayBackend is supported.
task: The task to execute.
retry_until_up: Whether to retry until the cluster is up.
signal_handler: The signal handler that will raise an exception if a
SkyPilot signal is received.
"""
self.dag = sky.Dag()
self.dag.add(task)
self.cluster_name = cluster_name
self.backend = backend
self.retry_until_up = retry_until_up
self.signal_handler = signal_handler

def __init_subclass__(cls, name: str, default: bool = False):
SPOT_STRATEGIES[name] = cls
@@ -63,8 +59,7 @@ def __init_subclass__(cls, name: str, default: bool = False):

@classmethod
def make(cls, cluster_name: str, backend: 'backends.Backend',
task: 'task_lib.Task', retry_until_up: bool,
signal_handler: Callable) -> 'StrategyExecutor':
task: 'task_lib.Task', retry_until_up: bool) -> 'StrategyExecutor':
"""Create a strategy from a task."""
resources = task.resources
assert len(resources) == 1, 'Only one resource is supported.'
@@ -77,7 +72,7 @@ def make(cls, cluster_name: str, backend: 'backends.Backend',
# will be handled by the strategy class.
task.set_resources({resources.copy(spot_recovery=None)})
return SPOT_STRATEGIES[spot_recovery](cluster_name, backend, task,
retry_until_up, signal_handler)
retry_until_up)

def launch(self) -> Optional[float]:
"""Launch the spot cluster for the first time.
@@ -147,9 +142,6 @@ def _launch(self, max_retry=3, raise_on_failure=True) -> Optional[float]:
backoff = common_utils.Backoff(self.RETRY_INIT_GAP_SECONDS)
while True:
retry_cnt += 1
# Check the signal every time to be more responsive to user
# signals, such as Cancel.
self.signal_handler()
retry_launch = False
exception = None
try:
3 changes: 1 addition & 2 deletions sky/spot/spot_utils.py
@@ -178,8 +178,7 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str:
cancelled_job_ids_str = ', '.join(map(str, cancelled_job_ids))
identity_str = f'Jobs with IDs {cancelled_job_ids_str} are'

return (f'{identity_str} scheduled to be cancelled within '
f'{JOB_STATUS_CHECK_GAP_SECONDS} seconds.')
return f'{identity_str} scheduled to be cancelled.'


def cancel_job_by_name(job_name: str) -> str: