Skip to content

Commit

Permalink
[k8s] Parallelize setup for faster multi-node provisioning (skypilot-…
Browse files Browse the repository at this point in the history
…org#4240)

* parallelize setup

* lint

* Add retries

* lint

* retry for get_remote_home_dir

* optimize privilege check

* parallelize termination

* increase num threads

* comments

* lint
  • Loading branch information
romilbhardwaj authored Nov 5, 2024
1 parent c24a0b3 commit 877d77f
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 33 deletions.
122 changes: 94 additions & 28 deletions sky/provision/kubernetes/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import copy
import json
import time
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional
import uuid

from sky import exceptions
Expand All @@ -24,6 +24,8 @@

POLL_INTERVAL = 2
_TIMEOUT_FOR_POD_TERMINATION = 60 # 1 minutes
_MAX_RETRIES = 3
NUM_THREADS = subprocess_utils.get_parallel_threads() * 2

logger = sky_logging.init_logger(__name__)
TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
Expand Down Expand Up @@ -304,6 +306,33 @@ def _check_init_containers(pod):
time.sleep(1)


def _run_function_with_retries(func: Callable,
operation_name: str,
max_retries: int = _MAX_RETRIES,
retry_delay: int = 5) -> Any:
"""Runs a function with retries on Kubernetes errors.
Args:
func: Function to retry
operation_name: Name of the operation for logging
max_retries: Maximum number of retry attempts
retry_delay: Delay between retries in seconds
Raises:
The last exception encountered if all retries fail.
"""
for attempt in range(max_retries + 1):
try:
return func()
except config_lib.KubernetesError:
if attempt < max_retries:
logger.warning(f'Failed to {operation_name} - '
f'retrying in {retry_delay} seconds.')
time.sleep(retry_delay)
else:
raise


def _set_env_vars_in_pods(namespace: str, context: Optional[str],
new_pods: List):
"""Setting environment variables in pods.
Expand All @@ -323,14 +352,27 @@ def _set_env_vars_in_pods(namespace: str, context: Optional[str],
"""
set_k8s_env_var_cmd = docker_utils.SETUP_ENV_VARS_CMD

for new_pod in new_pods:
def _set_env_vars_thread(new_pod):
pod_name = new_pod.metadata.name
logger.info(f'{"-"*20}Start: Set up env vars in pod {pod_name!r} '
f'{"-"*20}')
runner = command_runner.KubernetesCommandRunner(
((namespace, context), new_pod.metadata.name))
rc, stdout, _ = runner.run(set_k8s_env_var_cmd,
require_outputs=True,
stream_logs=False)
_raise_command_running_error('set env vars', set_k8s_env_var_cmd,
new_pod.metadata.name, rc, stdout)
((namespace, context), pod_name))

def _run_env_vars_cmd():
rc, stdout, _ = runner.run(set_k8s_env_var_cmd,
require_outputs=True,
stream_logs=False)
_raise_command_running_error('set env vars', set_k8s_env_var_cmd,
pod_name, rc, stdout)

_run_function_with_retries(_run_env_vars_cmd,
f'set env vars in pod {pod_name}')
logger.info(f'{"-"*20}End: Set up env vars in pod {pod_name!r} '
f'{"-"*20}')

subprocess_utils.run_in_parallel(_set_env_vars_thread, new_pods,
NUM_THREADS)


def _check_user_privilege(namespace: str, context: Optional[str],
Expand All @@ -350,23 +392,37 @@ def _check_user_privilege(namespace: str, context: Optional[str],
' fi; '
'fi')

for new_node in new_nodes:
runner = command_runner.KubernetesCommandRunner(
((namespace, context), new_node.metadata.name))
# This check needs to run on a per-image basis, so running the check on
# any one pod is sufficient.
new_node = new_nodes[0]
pod_name = new_node.metadata.name

runner = command_runner.KubernetesCommandRunner(
((namespace, context), pod_name))
logger.info(f'{"-"*20}Start: Check user privilege in pod {pod_name!r} '
f'{"-"*20}')

def _run_privilege_check():
rc, stdout, stderr = runner.run(check_k8s_user_sudo_cmd,
require_outputs=True,
separate_stderr=True,
stream_logs=False)
_raise_command_running_error('check user privilege',
check_k8s_user_sudo_cmd,
new_node.metadata.name, rc,
check_k8s_user_sudo_cmd, pod_name, rc,
stdout + stderr)
if stdout == str(exceptions.INSUFFICIENT_PRIVILEGES_CODE):
raise config_lib.KubernetesError(
'Insufficient system privileges detected. '
'Ensure the default user has root access or '
'"sudo" is installed and the user is added to the sudoers '
'from the image.')
return stdout

stdout = _run_function_with_retries(
_run_privilege_check, f'check user privilege in pod {pod_name!r}')

if stdout == str(exceptions.INSUFFICIENT_PRIVILEGES_CODE):
raise config_lib.KubernetesError(
'Insufficient system privileges detected. '
'Ensure the default user has root access or '
'"sudo" is installed and the user is added to the sudoers '
'from the image.')
logger.info(f'{"-"*20}End: Check user privilege in pod {pod_name!r} '
f'{"-"*20}')


def _setup_ssh_in_pods(namespace: str, context: Optional[str],
Expand Down Expand Up @@ -405,14 +461,19 @@ def _setup_ssh_thread(new_node):
runner = command_runner.KubernetesCommandRunner(
((namespace, context), pod_name))
logger.info(f'{"-"*20}Start: Set up SSH in pod {pod_name!r} {"-"*20}')
rc, stdout, _ = runner.run(set_k8s_ssh_cmd,
require_outputs=True,
stream_logs=False)
_raise_command_running_error('setup ssh', set_k8s_ssh_cmd, pod_name, rc,
stdout)

def _run_ssh_setup():
rc, stdout, _ = runner.run(set_k8s_ssh_cmd,
require_outputs=True,
stream_logs=False)
_raise_command_running_error('setup ssh', set_k8s_ssh_cmd, pod_name,
rc, stdout)

_run_function_with_retries(_run_ssh_setup,
f'setup ssh in pod {pod_name!r}')
logger.info(f'{"-"*20}End: Set up SSH in pod {pod_name!r} {"-"*20}')

subprocess_utils.run_in_parallel(_setup_ssh_thread, new_nodes)
subprocess_utils.run_in_parallel(_setup_ssh_thread, new_nodes, NUM_THREADS)


def _label_pod(namespace: str, context: Optional[str], pod_name: str,
Expand Down Expand Up @@ -765,12 +826,17 @@ def terminate_instances(
def _is_head(pod) -> bool:
return pod.metadata.labels[constants.TAG_RAY_NODE_KIND] == 'head'

for pod_name, pod in pods.items():
logger.debug(f'Terminating instance {pod_name}: {pod}')
def _terminate_pod_thread(pod_info):
pod_name, pod = pod_info
if _is_head(pod) and worker_only:
continue
return
logger.debug(f'Terminating instance {pod_name}: {pod}')
_terminate_node(namespace, context, pod_name)

# Run pod termination in parallel
subprocess_utils.run_in_parallel(_terminate_pod_thread, pods.items(),
NUM_THREADS)


def get_cluster_info(
region: str,
Expand Down
25 changes: 22 additions & 3 deletions sky/utils/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,23 @@ def _rsync(
rsync_command.append(prefix_command)
rsync_command += ['rsync', RSYNC_DISPLAY_OPTION]

def _get_remote_home_dir_with_retry():
backoff = common_utils.Backoff(initial_backoff=1,
max_backoff_factor=5)
retries_left = max_retry
assert retries_left > 0, f'max_retry {max_retry} must be positive.'
while retries_left >= 0:
try:
return get_remote_home_dir()
except Exception: # pylint: disable=broad-except
if retries_left == 0:
raise
sleep_time = backoff.current_backoff()
logger.warning(f'Failed to get remote home dir '
f'- retrying in {sleep_time} seconds.')
retries_left -= 1
time.sleep(sleep_time)

# --filter
# The source is a local path, so we need to resolve it.
resolved_source = pathlib.Path(source).expanduser().resolve()
Expand All @@ -261,7 +278,7 @@ def _rsync(
if up:
resolved_target = target
if target.startswith('~'):
remote_home_dir = get_remote_home_dir()
remote_home_dir = _get_remote_home_dir_with_retry()
resolved_target = target.replace('~', remote_home_dir)
full_source_str = str(resolved_source)
if resolved_source.is_dir():
Expand All @@ -273,7 +290,7 @@ def _rsync(
else:
resolved_source = source
if source.startswith('~'):
remote_home_dir = get_remote_home_dir()
remote_home_dir = _get_remote_home_dir_with_retry()
resolved_source = source.replace('~', remote_home_dir)
rsync_command.extend([
f'{node_destination}:{resolved_source!r}',
Expand Down Expand Up @@ -656,6 +673,8 @@ def rsync(
class KubernetesCommandRunner(CommandRunner):
"""Runner for Kubernetes commands."""

_MAX_RETRIES_FOR_RSYNC = 3

def __init__(
self,
node: Tuple[Tuple[str, Optional[str]], str],
Expand Down Expand Up @@ -798,7 +817,7 @@ def rsync(
# Advanced options.
log_path: str = os.devnull,
stream_logs: bool = True,
max_retry: int = 1,
max_retry: int = _MAX_RETRIES_FOR_RSYNC,
) -> None:
"""Uses 'rsync' to sync 'source' to 'target'.
Expand Down
14 changes: 12 additions & 2 deletions sky/utils/subprocess_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,27 @@ def get_parallel_threads() -> int:
return max(4, cpu_count - 1)


def run_in_parallel(func: Callable, args: Iterable[Any]) -> List[Any]:
def run_in_parallel(func: Callable,
args: Iterable[Any],
num_threads: Optional[int] = None) -> List[Any]:
"""Run a function in parallel on a list of arguments.
The function 'func' should raise a CommandError if the command fails.
Args:
func: The function to run in parallel
args: Iterable of arguments to pass to func
num_threads: Number of threads to use. If None, uses
get_parallel_threads()
Returns:
A list of the return values of the function func, in the same order as the
arguments.
"""
# Reference: https://stackoverflow.com/questions/25790279/python-multiprocessing-early-termination # pylint: disable=line-too-long
with pool.ThreadPool(processes=get_parallel_threads()) as p:
processes = num_threads if num_threads is not None else get_parallel_threads(
)
with pool.ThreadPool(processes=processes) as p:
# Run the function in parallel on the arguments, keeping the order.
return list(p.imap(func, args))

Expand Down

0 comments on commit 877d77f

Please sign in to comment.