[Spot] Let cancel interrupt the spot job #1414
@@ -6,6 +6,7 @@
 import colorama
 import filelock
+import ray
 
 import sky
 from sky import exceptions
@@ -28,9 +29,9 @@ class SpotController:
 
     def __init__(self, job_id: int, task_yaml: str,
                  retry_until_up: bool) -> None:
-        self._job_id = job_id
-        self._task_name = pathlib.Path(task_yaml).stem
-        self._task = sky.Task.from_yaml(task_yaml)
+        self.job_id = job_id
+        self.task_name = pathlib.Path(task_yaml).stem
+        self.task = sky.Task.from_yaml(task_yaml)
 
         self._retry_until_up = retry_until_up
         # TODO(zhwu): this assumes the specific backend.
@@ -39,145 +40,77 @@ def __init__(self, job_id: int, task_yaml: str,
         # Add a unique identifier to the task environment variables, so that
         # the user can have the same id for multiple recoveries.
         # Example value: sky-2022-10-04-22-46-52-467694_id-17
-        task_envs = self._task.envs or {}
+        task_envs = self.task.envs or {}
         job_id_env_var = common_utils.get_global_job_id(
-            self.backend.run_timestamp, 'spot', self._job_id)
+            self.backend.run_timestamp, 'spot', self.job_id)
         task_envs[constants.JOB_ID_ENV_VAR] = job_id_env_var
-        self._task.set_envs(task_envs)
+        self.task.set_envs(task_envs)
 
         spot_state.set_submitted(
-            self._job_id,
-            self._task_name,
+            self.job_id,
+            self.task_name,
             self.backend.run_timestamp,
-            resources_str=backend_utils.get_task_resources_str(self._task))
+            resources_str=backend_utils.get_task_resources_str(self.task))
         logger.info(f'Submitted spot job; SKYPILOT_JOB_ID: {job_id_env_var}')
-        self._cluster_name = spot_utils.generate_spot_cluster_name(
-            self._task_name, self._job_id)
-        self._strategy_executor = recovery_strategy.StrategyExecutor.make(
-            self._cluster_name, self.backend, self._task, retry_until_up,
-            self._handle_signal)
-
-    def _run(self):
-        """Busy loop monitoring spot cluster status and handling recovery."""
-        logger.info(f'Started monitoring spot task {self._task_name} '
-                    f'(id: {self._job_id})')
-        spot_state.set_starting(self._job_id)
-        start_at = self._strategy_executor.launch()
-
-        spot_state.set_started(self._job_id, start_time=start_at)
-        while True:
-            time.sleep(spot_utils.JOB_STATUS_CHECK_GAP_SECONDS)
-            # Handle the signal if it is sent by the user.
-            self._handle_signal()
-
-            # Check the network connection to avoid false alarm for job failure.
-            # Network glitch was observed even in the VM.
-            try:
-                backend_utils.check_network_connection()
-            except exceptions.NetworkError:
-                logger.info(
-                    'Network is not available. Retrying again in '
-                    f'{spot_utils.JOB_STATUS_CHECK_GAP_SECONDS} seconds.')
-                continue
-
-            # NOTE: we do not check cluster status first because race condition
-            # can occur, i.e. cluster can be down during the job status check.
-            job_status = spot_utils.get_job_status(self.backend,
-                                                   self._cluster_name)
-
-            if job_status is not None and not job_status.is_terminal():
-                need_recovery = False
-                if self._task.num_nodes > 1:
-                    # Check the cluster status for multi-node jobs, since the
-                    # job may not be set to FAILED immediately when only some
-                    # of the nodes are preempted.
-                    (cluster_status,
-                     handle) = backend_utils.refresh_cluster_status_handle(
-                         self._cluster_name, force_refresh=True)
-                    if cluster_status != global_user_state.ClusterStatus.UP:
-                        # recover the cluster if it is not up.
-                        logger.info(f'Cluster status {cluster_status.value}. '
-                                    'Recovering...')
-                        need_recovery = True
-                if not need_recovery:
-                    # The job and cluster are healthy, continue to monitor the
-                    # job status.
-                    continue
-
-            if job_status == job_lib.JobStatus.SUCCEEDED:
-                end_time = spot_utils.get_job_timestamp(self.backend,
-                                                        self._cluster_name,
-                                                        get_end_time=True)
-                # The job is done.
-                spot_state.set_succeeded(self._job_id, end_time=end_time)
-                break
-
-            if job_status == job_lib.JobStatus.FAILED:
-                # Check the status of the spot cluster. If it is not UP,
-                # the cluster is preempted.
-                (cluster_status,
-                 handle) = backend_utils.refresh_cluster_status_handle(
-                     self._cluster_name, force_refresh=True)
-                if cluster_status == global_user_state.ClusterStatus.UP:
-                    # The user code has probably crashed.
-                    end_time = spot_utils.get_job_timestamp(self.backend,
-                                                            self._cluster_name,
-                                                            get_end_time=True)
-                    logger.info(
-                        'The user job failed. Please check the logs below.\n'
-                        f'== Logs of the user job (ID: {self._job_id}) ==\n')
-                    self.backend.tail_logs(handle,
-                                           None,
-                                           spot_job_id=self._job_id)
-                    logger.info(f'\n== End of logs (ID: {self._job_id}) ==')
-                    spot_state.set_failed(
-                        self._job_id,
-                        failure_type=spot_state.SpotStatus.FAILED,
-                        end_time=end_time)
-                    break
-            # cluster can be down, INIT or STOPPED, based on the interruption
-            # behavior of the cloud.
-            # Failed to connect to the cluster or the cluster is partially down.
-            # job_status is None or job_status == job_lib.JobStatus.FAILED
-            logger.info('The cluster is preempted.')
-            spot_state.set_recovering(self._job_id)
-            recovered_time = self._strategy_executor.recover()
-            spot_state.set_recovered(self._job_id,
-                                     recovered_time=recovered_time)
+        self.cluster_name = spot_utils.generate_spot_cluster_name(
+            self.task_name, self.job_id)
+        self.strategy_executor = recovery_strategy.StrategyExecutor.make(
+            self.cluster_name, self.backend, self.task, retry_until_up)
 
     def start(self):
         """Start the controller."""
         try:
-            self._run()
+            self._handle_signal()
+            controller_task = _controller_run.remote(self)
+            # Signal can interrupt the underlying controller process.
+            ready, _ = ray.wait([controller_task], timeout=0)
+            while not ready:
+                try:
+                    self._handle_signal()
+                except exceptions.SpotUserCancelledError as e:
+                    logger.info(f'Cancelling spot job {self.job_id}...')
+                    try:
+                        ray.cancel(controller_task)
+                        ray.get(controller_task)
+                    except ray.exceptions.RayTaskError:
+                        # When the controller task is cancelled, it will raise
+                        # ray.exceptions.RayTaskError, which can be ignored,
+                        # since the SpotUserCancelledError will be raised and
+                        # handled later.
+                        pass
+                    raise e
+                ready, _ = ray.wait([controller_task], timeout=1)
+            # Need this to get the exception from the controller task.
+            ray.get(controller_task)
         except exceptions.SpotUserCancelledError as e:
             logger.info(e)
-            spot_state.set_cancelled(self._job_id)
+            spot_state.set_cancelled(self.job_id)
         except exceptions.ResourcesUnavailableError as e:
             logger.error(f'Resources unavailable: {colorama.Fore.RED}{e}'
                          f'{colorama.Style.RESET_ALL}')
             spot_state.set_failed(
-                self._job_id,
+                self.job_id,
                 failure_type=spot_state.SpotStatus.FAILED_NO_RESOURCE)
         except (Exception, SystemExit) as e:  # pylint: disable=broad-except
             logger.error(traceback.format_exc())
             logger.error(f'Unexpected error occurred: {type(e).__name__}: {e}')
         finally:
-            self._strategy_executor.terminate_cluster()
-            job_status = spot_state.get_status(self._job_id)
+            self.strategy_executor.terminate_cluster()
+            job_status = spot_state.get_status(self.job_id)
             # The job can be non-terminal if the controller exited abnormally,
             # e.g. failed to launch cluster after reaching the MAX_RETRY.
             if not job_status.is_terminal():
                 spot_state.set_failed(
-                    self._job_id,
+                    self.job_id,
                     failure_type=spot_state.SpotStatus.FAILED_CONTROLLER)
 
             # Clean up Storages with persistent=False.
-            self.backend.teardown_ephemeral_storage(self._task)
+            self.backend.teardown_ephemeral_storage(self.task)
 
     def _handle_signal(self):
         """Handle the signal if the user sent it."""
         signal_file = pathlib.Path(
-            spot_utils.SIGNAL_FILE_PREFIX.format(self._job_id))
+            spot_utils.SIGNAL_FILE_PREFIX.format(self.job_id))
         signal = None
         if signal_file.exists():
             # Filelock is needed to prevent race condition with concurrent
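The _handle_signal() path above reads a per-job signal file under a filelock so that whoever writes the signal (e.g. a cancel request) cannot race with the controller reading it. Below is a minimal sketch of that handshake, not SkyPilot's actual implementation: the file path and the helper names (send_cancel_signal, read_and_clear_signal) are made up for illustration, and the real prefix lives in spot_utils.SIGNAL_FILE_PREFIX.

# Illustrative only -- not SkyPilot code. The path and helper names below are
# assumptions; the real signal-file prefix is spot_utils.SIGNAL_FILE_PREFIX.
import pathlib

import filelock

SIGNAL_FILE = '/tmp/spot_signal_{job_id}'  # assumed path for this sketch


def send_cancel_signal(job_id: int) -> None:
    """Writer side: drop a CANCEL signal file for the job."""
    signal_file = pathlib.Path(SIGNAL_FILE.format(job_id=job_id))
    # The lock prevents a reader from seeing a half-written file.
    with filelock.FileLock(str(signal_file) + '.lock'):
        signal_file.write_text('CANCEL')


def read_and_clear_signal(job_id: int):
    """Reader side: return the pending signal (if any) and remove the file."""
    signal_file = pathlib.Path(SIGNAL_FILE.format(job_id=job_id))
    if not signal_file.exists():
        return None
    with filelock.FileLock(str(signal_file) + '.lock'):
        signal = signal_file.read_text().strip()
        signal_file.unlink()
    return signal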
@@ -200,7 +133,97 @@ def _handle_signal(self):
         raise RuntimeError(f'Unknown SkyPilot signal received: {signal.value}.')
 
 
+@ray.remote(num_cpus=0)
+def _controller_run(spot_controller: SpotController):
+    """Busy loop monitoring spot cluster status and handling recovery."""
+    logger.info(f'Started monitoring spot task {spot_controller.task_name} '
+                f'(id: {spot_controller.job_id})')
+    spot_state.set_starting(spot_controller.job_id)
+    start_at = spot_controller.strategy_executor.launch()
+
+    spot_state.set_started(spot_controller.job_id, start_time=start_at)
+    while True:
+        time.sleep(spot_utils.JOB_STATUS_CHECK_GAP_SECONDS)
+
+        # Check the network connection to avoid false alarm for job failure.
+        # Network glitch was observed even in the VM.
+        try:
+            backend_utils.check_network_connection()
+        except exceptions.NetworkError:
+            logger.info('Network is not available. Retrying again in '
+                        f'{spot_utils.JOB_STATUS_CHECK_GAP_SECONDS} seconds.')
+            continue
+
+        # NOTE: we do not check cluster status first because race condition
+        # can occur, i.e. cluster can be down during the job status check.
+        job_status = spot_utils.get_job_status(spot_controller.backend,
+                                               spot_controller.cluster_name)
+
+        if job_status is not None and not job_status.is_terminal():
+            need_recovery = False
+            if spot_controller.task.num_nodes > 1:
+                # Check the cluster status for multi-node jobs, since the
+                # job may not be set to FAILED immediately when only some
+                # of the nodes are preempted.
+                (cluster_status,
+                 handle) = backend_utils.refresh_cluster_status_handle(
+                     spot_controller.cluster_name, force_refresh=True)
+                if cluster_status != global_user_state.ClusterStatus.UP:
+                    # recover the cluster if it is not up.
+                    logger.info(f'Cluster status {cluster_status.value}. '
+                                'Recovering...')
+                    need_recovery = True
+            if not need_recovery:
+                # The job and cluster are healthy, continue to monitor the
+                # job status.
+                continue
+
+        if job_status == job_lib.JobStatus.SUCCEEDED:
+            end_time = spot_utils.get_job_timestamp(
+                spot_controller.backend,
+                spot_controller.cluster_name,
+                get_end_time=True)
+            # The job is done.
+            spot_state.set_succeeded(spot_controller.job_id, end_time=end_time)
+            break
+
+        if job_status == job_lib.JobStatus.FAILED:
+            # Check the status of the spot cluster. If it is not UP,
+            # the cluster is preempted.
+            (cluster_status,
+             handle) = backend_utils.refresh_cluster_status_handle(
+                 spot_controller.cluster_name, force_refresh=True)
+            if cluster_status == global_user_state.ClusterStatus.UP:
+                # The user code has probably crashed.
+                end_time = spot_utils.get_job_timestamp(
+                    spot_controller.backend,
+                    spot_controller.cluster_name,
+                    get_end_time=True)
+                logger.info(
+                    'The user job failed. Please check the logs below.\n'
+                    '== Logs of the user job (ID: '
+                    f'{spot_controller.job_id}) ==\n')
+                spot_controller.backend.tail_logs(
+                    handle, None, spot_job_id=spot_controller.job_id)
+                logger.info(
+                    f'\n== End of logs (ID: {spot_controller.job_id}) ==')
+                spot_state.set_failed(spot_controller.job_id,
+                                      failure_type=spot_state.SpotStatus.FAILED,
+                                      end_time=end_time)
+                break
+        # cluster can be down, INIT or STOPPED, based on the interruption
+        # behavior of the cloud.
+        # Failed to connect to the cluster or the cluster is partially down.
+        # job_status is None or job_status == job_lib.JobStatus.FAILED
+        logger.info('The cluster is preempted.')
+        spot_state.set_recovering(spot_controller.job_id)
+        recovered_time = spot_controller.strategy_executor.recover()
+        spot_state.set_recovered(spot_controller.job_id,
+                                 recovered_time=recovered_time)
+
+
 if __name__ == '__main__':
+    ray.init('auto')
Review comments on this line:

What do you think about using …
Pros: …
Cons: …

Good point! I think the second Pro holds, but for the first point, ray has changed the behavior for … That is to say, using …
     parser = argparse.ArgumentParser()
     parser.add_argument('--job-id',
                         required=True,
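The new start() method in this diff is essentially a poll-and-cancel loop around a Ray task. Here is a stripped-down, self-contained sketch of that pattern, under the assumption of a throwaway work function; _long_running_work, run_until_cancelled and should_cancel are placeholders for illustration and are not part of the PR.

# Sketch of the ray.wait()/ray.cancel() pattern used by SpotController.start().
# _long_running_work and should_cancel are placeholders for this sketch only.
import time

import ray


@ray.remote(num_cpus=0)
def _long_running_work():
    # Stand-in for _controller_run: loops until it is cancelled.
    while True:
        time.sleep(1)


def run_until_cancelled(should_cancel) -> None:
    ray.init(ignore_reinit_error=True)
    task = _long_running_work.remote()
    ready, _ = ray.wait([task], timeout=0)
    while not ready:
        if should_cancel():
            # Interrupt the remote task; the error raised by ray.get() for a
            # cancelled task is expected and can be swallowed here.
            ray.cancel(task)
            try:
                ray.get(task)
            except (ray.exceptions.RayTaskError,
                    ray.exceptions.TaskCancelledError):
                pass
            return
        # Re-check roughly once per second, like the controller loop does.
        ready, _ = ray.wait([task], timeout=1)
    # Surfaces any exception raised inside the remote task to the caller.
    ray.get(task)

The first ray.wait() call with timeout=0 mirrors the PR: it returns immediately, so the signal check runs at least once before the loop starts blocking.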
Review discussion:

Out of curiosity - any particular reason to use Ray? Could this have been implemented with multiprocessing?

Ahh, there is no particular reason for implementing it with ray. The only reason is that I am more familiar with the busy loop with ray.wait, and our VMs already have a ray cluster running in the background. Do you think it would be better to implement it with multiprocessing?

I think multiprocessing might be faster, but ray is also fine if it's faster for us to implement.

Just tried out multiprocessing, concurrent.futures and asyncio, and it seems they all have the functionality to run the process in parallel, but with drawbacks:
- multiprocessing.Process: We can start the process and check its liveness with is_alive, and cancel it with os.kill(p.pid, signal.SIGINT). However, we cannot easily catch exceptions happening inside the process from the main process; additional code is needed to do so, as mentioned here.
- concurrent.futures: We can start the process with Executor.submit and catch exceptions from the main process, but the future object does not have a way to stop the running process, unlike ray.cancel.
- asyncio.create_task: Based on their doc, the task is not guaranteed to be cancelled, but we need that guarantee to cancel the _controller_run function when the CANCEL signal is received.

Based on the drawbacks above, I think it might be fine to use ray to handle the async and future as we did in the code. Wdyt @romilbhardwaj?

Ah, for multiprocessing, I was thinking of using multiprocessing.Pool since it provides nice exception handling, but it turns out it's not easy to cancel running processes in it. I think it's fine to use ray for now. Thanks for the extensive investigation into different methods to do this!
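To make the first drawback above concrete, here is a rough sketch (not SkyPilot code) of the extra plumbing multiprocessing.Process would need so the parent can see an exception raised in the child, which is what ray.get(controller_task) provides for free. The function names are made up for the sketch.

# Rough sketch only: forwarding a child-process exception to the parent via a
# Queue -- the "additional code" the comment above refers to.
import multiprocessing


def _child(result_queue):
    try:
        # Stand-in for the real controller work; raise to exercise the path.
        raise RuntimeError('failure inside the child process')
    except BaseException as e:  # forward the exception to the parent
        result_queue.put(e)
    else:
        result_queue.put(None)  # success sentinel


def run_and_reraise() -> None:
    queue = multiprocessing.Queue()
    proc = multiprocessing.Process(target=_child, args=(queue,))
    proc.start()
    outcome = queue.get()  # blocks until the child reports back
    proc.join()
    if isinstance(outcome, BaseException):
        # Without this manual forwarding, the parent would only see a
        # non-zero proc.exitcode, not the original exception.
        raise outcome


if __name__ == '__main__':
    try:
        run_and_reraise()
    except RuntimeError as e:
        print(f'Caught child exception in parent: {e}')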