diff --git a/docs/source/getting-started/installation.rst b/docs/source/getting-started/installation.rst
index cf6115ee9e8..69303a582e2 100644
--- a/docs/source/getting-started/installation.rst
+++ b/docs/source/getting-started/installation.rst
@@ -264,6 +264,7 @@ The :code:`~/.oci/config` file should contain the following fields:
     fingerprint=aa:bb:cc:dd:ee:ff:gg:hh:ii:jj:kk:ll:mm:nn:oo:pp
     tenancy=ocid1.tenancy.oc1..aaaaaaaa
     region=us-sanjose-1
+    # Note: avoid using the full home path for the key_file configuration, e.g. use ~/.oci instead of /home/username/.oci
     key_file=~/.oci/oci_api_key.pem
 
diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py
index 73a27355ccf..b5e854cceac 100644
--- a/sky/provision/kubernetes/instance.py
+++ b/sky/provision/kubernetes/instance.py
@@ -2,7 +2,7 @@
 import copy
 import json
 import time
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 from sky import exceptions
 from sky import sky_logging
@@ -23,6 +23,8 @@
 POLL_INTERVAL = 2
 _TIMEOUT_FOR_POD_TERMINATION = 60  # 1 minutes
+_MAX_RETRIES = 3
+NUM_THREADS = subprocess_utils.get_parallel_threads() * 2
 
 logger = sky_logging.init_logger(__name__)
 TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
 
@@ -303,6 +305,33 @@ def _check_init_containers(pod):
         time.sleep(1)
 
 
+def _run_function_with_retries(func: Callable,
+                               operation_name: str,
+                               max_retries: int = _MAX_RETRIES,
+                               retry_delay: int = 5) -> Any:
+    """Runs a function with retries on Kubernetes errors.
+
+    Args:
+        func: Function to retry
+        operation_name: Name of the operation for logging
+        max_retries: Maximum number of retry attempts
+        retry_delay: Delay between retries in seconds
+
+    Raises:
+        The last exception encountered if all retries fail.
+    """
+    for attempt in range(max_retries + 1):
+        try:
+            return func()
+        except config_lib.KubernetesError:
+            if attempt < max_retries:
+                logger.warning(f'Failed to {operation_name} - '
+                               f'retrying in {retry_delay} seconds.')
+                time.sleep(retry_delay)
+            else:
+                raise
+
+
 def _set_env_vars_in_pods(namespace: str, context: Optional[str],
                           new_pods: List):
     """Setting environment variables in pods.
@@ -322,14 +351,27 @@ def _set_env_vars_in_pods(namespace: str, context: Optional[str],
     """
     set_k8s_env_var_cmd = docker_utils.SETUP_ENV_VARS_CMD
 
-    for new_pod in new_pods:
+    def _set_env_vars_thread(new_pod):
+        pod_name = new_pod.metadata.name
+        logger.info(f'{"-"*20}Start: Set up env vars in pod {pod_name!r} '
+                    f'{"-"*20}')
         runner = command_runner.KubernetesCommandRunner(
-            ((namespace, context), new_pod.metadata.name))
-        rc, stdout, _ = runner.run(set_k8s_env_var_cmd,
-                                   require_outputs=True,
-                                   stream_logs=False)
-        _raise_command_running_error('set env vars', set_k8s_env_var_cmd,
-                                     new_pod.metadata.name, rc, stdout)
+            ((namespace, context), pod_name))
+
+        def _run_env_vars_cmd():
+            rc, stdout, _ = runner.run(set_k8s_env_var_cmd,
+                                       require_outputs=True,
+                                       stream_logs=False)
+            _raise_command_running_error('set env vars', set_k8s_env_var_cmd,
+                                         pod_name, rc, stdout)
+
+        _run_function_with_retries(_run_env_vars_cmd,
+                                   f'set env vars in pod {pod_name}')
+        logger.info(f'{"-"*20}End: Set up env vars in pod {pod_name!r} '
+                    f'{"-"*20}')
+
+    subprocess_utils.run_in_parallel(_set_env_vars_thread, new_pods,
+                                     NUM_THREADS)
 
 
 def _check_user_privilege(namespace: str, context: Optional[str],
@@ -349,23 +391,37 @@ def _check_user_privilege(namespace: str, context: Optional[str],
                               ' fi; '
                               'fi')
 
-    for new_node in new_nodes:
-        runner = command_runner.KubernetesCommandRunner(
-            ((namespace, context), new_node.metadata.name))
+    # This check needs to run on a per-image basis, so running the check on
+    # any one pod is sufficient.
+    new_node = new_nodes[0]
+    pod_name = new_node.metadata.name
+
+    runner = command_runner.KubernetesCommandRunner(
+        ((namespace, context), pod_name))
+    logger.info(f'{"-"*20}Start: Check user privilege in pod {pod_name!r} '
+                f'{"-"*20}')
+
+    def _run_privilege_check():
         rc, stdout, stderr = runner.run(check_k8s_user_sudo_cmd,
                                         require_outputs=True,
                                         separate_stderr=True,
                                         stream_logs=False)
         _raise_command_running_error('check user privilege',
-                                     check_k8s_user_sudo_cmd,
-                                     new_node.metadata.name, rc,
+                                     check_k8s_user_sudo_cmd, pod_name, rc,
                                      stdout + stderr)
-        if stdout == str(exceptions.INSUFFICIENT_PRIVILEGES_CODE):
-            raise config_lib.KubernetesError(
-                'Insufficient system privileges detected. '
-                'Ensure the default user has root access or '
-                '"sudo" is installed and the user is added to the sudoers '
-                'from the image.')
+        return stdout
+
+    stdout = _run_function_with_retries(
+        _run_privilege_check, f'check user privilege in pod {pod_name!r}')
+
+    if stdout == str(exceptions.INSUFFICIENT_PRIVILEGES_CODE):
+        raise config_lib.KubernetesError(
+            'Insufficient system privileges detected. '
+            'Ensure the default user has root access or '
+            '"sudo" is installed and the user is added to the sudoers '
+            'from the image.')
+    logger.info(f'{"-"*20}End: Check user privilege in pod {pod_name!r} '
+                f'{"-"*20}')
 
 
 def _setup_ssh_in_pods(namespace: str, context: Optional[str],
@@ -404,14 +460,19 @@ def _setup_ssh_thread(new_node):
         runner = command_runner.KubernetesCommandRunner(
             ((namespace, context), pod_name))
         logger.info(f'{"-"*20}Start: Set up SSH in pod {pod_name!r} {"-"*20}')
-        rc, stdout, _ = runner.run(set_k8s_ssh_cmd,
-                                   require_outputs=True,
-                                   stream_logs=False)
-        _raise_command_running_error('setup ssh', set_k8s_ssh_cmd, pod_name, rc,
-                                     stdout)
+
+        def _run_ssh_setup():
+            rc, stdout, _ = runner.run(set_k8s_ssh_cmd,
+                                       require_outputs=True,
+                                       stream_logs=False)
+            _raise_command_running_error('setup ssh', set_k8s_ssh_cmd, pod_name,
+                                         rc, stdout)
+
+        _run_function_with_retries(_run_ssh_setup,
+                                   f'setup ssh in pod {pod_name!r}')
         logger.info(f'{"-"*20}End: Set up SSH in pod {pod_name!r} {"-"*20}')
 
-    subprocess_utils.run_in_parallel(_setup_ssh_thread, new_nodes)
+    subprocess_utils.run_in_parallel(_setup_ssh_thread, new_nodes, NUM_THREADS)
 
 
 def _label_pod(namespace: str, context: Optional[str], pod_name: str,
@@ -767,12 +828,17 @@ def terminate_instances(
     def _is_head(pod) -> bool:
         return pod.metadata.labels[constants.TAG_RAY_NODE_KIND] == 'head'
 
-    for pod_name, pod in pods.items():
-        logger.debug(f'Terminating instance {pod_name}: {pod}')
+    def _terminate_pod_thread(pod_info):
+        pod_name, pod = pod_info
         if _is_head(pod) and worker_only:
-            continue
+            return
+        logger.debug(f'Terminating instance {pod_name}: {pod}')
         _terminate_node(namespace, context, pod_name)
 
+    # Run pod termination in parallel
+    subprocess_utils.run_in_parallel(_terminate_pod_thread, pods.items(),
+                                     NUM_THREADS)
+
 
 def get_cluster_info(
     region: str,
 
diff --git a/sky/utils/command_runner.py b/sky/utils/command_runner.py
index 7eae76040d8..25483031038 100644
--- a/sky/utils/command_runner.py
+++ b/sky/utils/command_runner.py
@@ -237,6 +237,23 @@ def _rsync(
             rsync_command.append(prefix_command)
         rsync_command += ['rsync', RSYNC_DISPLAY_OPTION]
 
+        def _get_remote_home_dir_with_retry():
+            backoff = common_utils.Backoff(initial_backoff=1,
+                                           max_backoff_factor=5)
+            retries_left = max_retry
+            assert retries_left > 0, f'max_retry {max_retry} must be positive.'
+            while retries_left >= 0:
+                try:
+                    return get_remote_home_dir()
+                except Exception:  # pylint: disable=broad-except
+                    if retries_left == 0:
+                        raise
+                    sleep_time = backoff.current_backoff()
+                    logger.warning(f'Failed to get remote home dir '
+                                   f'- retrying in {sleep_time} seconds.')
+                    retries_left -= 1
+                    time.sleep(sleep_time)
+
         # --filter
         # The source is a local path, so we need to resolve it.
        resolved_source = pathlib.Path(source).expanduser().resolve()
@@ -261,7 +278,7 @@ def _rsync(
         if up:
             resolved_target = target
             if target.startswith('~'):
-                remote_home_dir = get_remote_home_dir()
+                remote_home_dir = _get_remote_home_dir_with_retry()
                 resolved_target = target.replace('~', remote_home_dir)
             full_source_str = str(resolved_source)
             if resolved_source.is_dir():
@@ -273,7 +290,7 @@ def _rsync(
         else:
             resolved_source = source
             if source.startswith('~'):
-                remote_home_dir = get_remote_home_dir()
+                remote_home_dir = _get_remote_home_dir_with_retry()
                 resolved_source = source.replace('~', remote_home_dir)
             rsync_command.extend([
                 f'{node_destination}:{resolved_source!r}',
@@ -656,6 +673,8 @@ def rsync(
 class KubernetesCommandRunner(CommandRunner):
     """Runner for Kubernetes commands."""
 
+    _MAX_RETRIES_FOR_RSYNC = 3
+
     def __init__(
         self,
         node: Tuple[Tuple[str, Optional[str]], str],
@@ -798,7 +817,7 @@ def rsync(
         # Advanced options.
         log_path: str = os.devnull,
         stream_logs: bool = True,
-        max_retry: int = 1,
+        max_retry: int = _MAX_RETRIES_FOR_RSYNC,
     ) -> None:
         """Uses 'rsync' to sync 'source' to 'target'.
 
diff --git a/sky/utils/subprocess_utils.py b/sky/utils/subprocess_utils.py
index 303e3ddad99..acb8fb9f490 100644
--- a/sky/utils/subprocess_utils.py
+++ b/sky/utils/subprocess_utils.py
@@ -50,17 +50,27 @@ def get_parallel_threads() -> int:
     return max(4, cpu_count - 1)
 
 
-def run_in_parallel(func: Callable, args: Iterable[Any]) -> List[Any]:
+def run_in_parallel(func: Callable,
+                    args: Iterable[Any],
+                    num_threads: Optional[int] = None) -> List[Any]:
     """Run a function in parallel on a list of arguments.
 
     The function 'func' should raise a CommandError if the command fails.
 
+    Args:
+        func: The function to run in parallel
+        args: Iterable of arguments to pass to func
+        num_threads: Number of threads to use. If None, uses
+          get_parallel_threads()
+
     Returns:
       A list of the return values of the function func, in the same order as the
       arguments.
     """
     # Reference: https://stackoverflow.com/questions/25790279/python-multiprocessing-early-termination # pylint: disable=line-too-long
-    with pool.ThreadPool(processes=get_parallel_threads()) as p:
+    processes = num_threads if num_threads is not None else get_parallel_threads(
+    )
+    with pool.ThreadPool(processes=processes) as p:
         # Run the function in parallel on the arguments, keeping the order.
         return list(p.imap(func, args))
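Note on the retry helper: every per-pod setup command in instance.py is now funneled through the new _run_function_with_retries, which retries only on KubernetesError and re-raises once the attempts are exhausted. Below is a minimal, self-contained sketch of that behaviour; the KubernetesError stand-in and flaky_command are illustrative and not SkyPilot code.

import time
from typing import Any, Callable


class KubernetesError(Exception):
    """Stand-in for the provisioner's KubernetesError (illustrative only)."""


def run_with_retries(func: Callable[[], Any],
                     operation_name: str,
                     max_retries: int = 3,
                     retry_delay: int = 5) -> Any:
    """Simplified mirror of the retry loop added in instance.py."""
    for attempt in range(max_retries + 1):
        try:
            return func()
        except KubernetesError:
            if attempt >= max_retries:
                raise  # Out of attempts: surface the last error.
            print(f'Failed to {operation_name} - retrying in {retry_delay}s.')
            time.sleep(retry_delay)


if __name__ == '__main__':
    attempts = {'count': 0}

    def flaky_command() -> str:
        # Fails twice, then succeeds - emulates a transient pod exec error.
        attempts['count'] += 1
        if attempts['count'] < 3:
            raise KubernetesError('transient failure')
        return 'ok'

    print(run_with_retries(flaky_command, 'set env vars', retry_delay=1))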
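Note on the parallelisation pattern: the instance.py changes all share one shape - build a per-pod closure, then hand it to subprocess_utils.run_in_parallel together with the new num_threads argument (NUM_THREADS, i.e. twice get_parallel_threads()), which bounds the thread pool size while imap preserves result order. A self-contained sketch of that shape, using dummy (name, is_head) tuples in place of Kubernetes pod objects:

from multiprocessing import pool
from typing import Any, Callable, Iterable, List, Optional

NUM_THREADS = 8  # In the diff: subprocess_utils.get_parallel_threads() * 2


def run_in_parallel(func: Callable,
                    args: Iterable[Any],
                    num_threads: Optional[int] = None) -> List[Any]:
    """Run func over args in a bounded thread pool, preserving order."""
    processes = num_threads if num_threads is not None else 4
    with pool.ThreadPool(processes=processes) as p:
        return list(p.imap(func, args))


def terminate_pod(pod_info) -> None:
    pod_name, is_head = pod_info
    if is_head:
        return  # Mirrors the worker_only early return for the head pod.
    print(f'Terminating {pod_name}')


pods = [('head-pod', True), ('worker-0', False), ('worker-1', False)]
run_in_parallel(terminate_pod, pods, NUM_THREADS)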
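Note on the rsync retry: _get_remote_home_dir_with_retry wraps get_remote_home_dir() in up to max_retry additional attempts, sleeping for delays produced by common_utils.Backoff. The sketch below assumes Backoff yields exponentially growing delays capped via max_backoff_factor; that is an assumption about its behaviour, since its implementation is not shown in this diff, and the names here are illustrative.

import time
from typing import Callable, TypeVar

T = TypeVar('T')


class Backoff:
    """Illustrative exponential backoff (assumed semantics, not the real
    sky.utils.common_utils.Backoff)."""

    def __init__(self, initial_backoff: int = 1, max_backoff_factor: int = 5):
        self._initial = initial_backoff
        self._cap = initial_backoff * max_backoff_factor
        self._attempt = 0

    def current_backoff(self) -> float:
        delay = min(self._initial * (2**self._attempt), self._cap)
        self._attempt += 1
        return delay


def call_with_retry(fn: Callable[[], T], max_retry: int = 3) -> T:
    """Simplified mirror of _get_remote_home_dir_with_retry."""
    assert max_retry > 0, 'max_retry must be positive.'
    backoff = Backoff(initial_backoff=1, max_backoff_factor=5)
    retries_left = max_retry
    while True:
        try:
            return fn()
        except Exception:  # pylint: disable=broad-except
            if retries_left == 0:
                raise
            sleep_time = backoff.current_backoff()
            print(f'Failed - retrying in {sleep_time} seconds.')
            retries_left -= 1
            time.sleep(sleep_time)

The same diff raises the default max_retry of KubernetesCommandRunner.rsync to _MAX_RETRIES_FOR_RSYNC = 3, so the '~' expansion path that calls this helper is retried instead of failing on the first error.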