Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rebase multithread #8

Merged
merged 10 commits into from
Nov 5, 2024
1 change: 1 addition & 0 deletions docs/source/getting-started/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ The :code:`~/.oci/config` file should contain the following fields:
fingerprint=aa:bb:cc:dd:ee:ff:gg:hh:ii:jj:kk:ll:mm:nn:oo:pp
tenancy=ocid1.tenancy.oc1..aaaaaaaa
region=us-sanjose-1
# Note: avoid using the full home path in the key_file configuration; e.g. use ~/.oci instead of /home/username/.oci
key_file=~/.oci/oci_api_key.pem


Expand Down
122 changes: 94 additions & 28 deletions sky/provision/kubernetes/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import copy
import json
import time
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

from sky import exceptions
from sky import sky_logging
Expand All @@ -23,6 +23,8 @@

POLL_INTERVAL = 2
_TIMEOUT_FOR_POD_TERMINATION = 60 # 1 minutes
_MAX_RETRIES = 3
NUM_THREADS = subprocess_utils.get_parallel_threads() * 2

logger = sky_logging.init_logger(__name__)
TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
Expand Down Expand Up @@ -303,6 +305,33 @@ def _check_init_containers(pod):
time.sleep(1)


def _run_function_with_retries(func: Callable,
operation_name: str,
max_retries: int = _MAX_RETRIES,
retry_delay: int = 5) -> Any:
"""Runs a function with retries on Kubernetes errors.

Args:
func: Function to retry
operation_name: Name of the operation for logging
max_retries: Maximum number of retry attempts
retry_delay: Delay between retries in seconds

Raises:
The last exception encountered if all retries fail.
"""
for attempt in range(max_retries + 1):
try:
return func()
except config_lib.KubernetesError:
if attempt < max_retries:
logger.warning(f'Failed to {operation_name} - '
f'retrying in {retry_delay} seconds.')
time.sleep(retry_delay)
else:
raise


def _set_env_vars_in_pods(namespace: str, context: Optional[str],
new_pods: List):
"""Setting environment variables in pods.
Expand All @@ -322,14 +351,27 @@ def _set_env_vars_in_pods(namespace: str, context: Optional[str],
"""
set_k8s_env_var_cmd = docker_utils.SETUP_ENV_VARS_CMD

for new_pod in new_pods:
def _set_env_vars_thread(new_pod):
pod_name = new_pod.metadata.name
logger.info(f'{"-"*20}Start: Set up env vars in pod {pod_name!r} '
f'{"-"*20}')
runner = command_runner.KubernetesCommandRunner(
((namespace, context), new_pod.metadata.name))
rc, stdout, _ = runner.run(set_k8s_env_var_cmd,
require_outputs=True,
stream_logs=False)
_raise_command_running_error('set env vars', set_k8s_env_var_cmd,
new_pod.metadata.name, rc, stdout)
((namespace, context), pod_name))

def _run_env_vars_cmd():
rc, stdout, _ = runner.run(set_k8s_env_var_cmd,
require_outputs=True,
stream_logs=False)
_raise_command_running_error('set env vars', set_k8s_env_var_cmd,
pod_name, rc, stdout)

_run_function_with_retries(_run_env_vars_cmd,
f'set env vars in pod {pod_name}')
logger.info(f'{"-"*20}End: Set up env vars in pod {pod_name!r} '
f'{"-"*20}')

subprocess_utils.run_in_parallel(_set_env_vars_thread, new_pods,
NUM_THREADS)


def _check_user_privilege(namespace: str, context: Optional[str],
Expand All @@ -349,23 +391,37 @@ def _check_user_privilege(namespace: str, context: Optional[str],
' fi; '
'fi')

for new_node in new_nodes:
runner = command_runner.KubernetesCommandRunner(
((namespace, context), new_node.metadata.name))
# This check needs to run on a per-image basis, so running the check on
# any one pod is sufficient.
new_node = new_nodes[0]
pod_name = new_node.metadata.name

runner = command_runner.KubernetesCommandRunner(
((namespace, context), pod_name))
logger.info(f'{"-"*20}Start: Check user privilege in pod {pod_name!r} '
f'{"-"*20}')

def _run_privilege_check():
rc, stdout, stderr = runner.run(check_k8s_user_sudo_cmd,
require_outputs=True,
separate_stderr=True,
stream_logs=False)
_raise_command_running_error('check user privilege',
check_k8s_user_sudo_cmd,
new_node.metadata.name, rc,
check_k8s_user_sudo_cmd, pod_name, rc,
stdout + stderr)
if stdout == str(exceptions.INSUFFICIENT_PRIVILEGES_CODE):
raise config_lib.KubernetesError(
'Insufficient system privileges detected. '
'Ensure the default user has root access or '
'"sudo" is installed and the user is added to the sudoers '
'from the image.')
return stdout

stdout = _run_function_with_retries(
_run_privilege_check, f'check user privilege in pod {pod_name!r}')

if stdout == str(exceptions.INSUFFICIENT_PRIVILEGES_CODE):
raise config_lib.KubernetesError(
'Insufficient system privileges detected. '
'Ensure the default user has root access or '
'"sudo" is installed and the user is added to the sudoers '
'from the image.')
logger.info(f'{"-"*20}End: Check user privilege in pod {pod_name!r} '
f'{"-"*20}')


def _setup_ssh_in_pods(namespace: str, context: Optional[str],
Expand Down Expand Up @@ -404,14 +460,19 @@ def _setup_ssh_thread(new_node):
runner = command_runner.KubernetesCommandRunner(
((namespace, context), pod_name))
logger.info(f'{"-"*20}Start: Set up SSH in pod {pod_name!r} {"-"*20}')
rc, stdout, _ = runner.run(set_k8s_ssh_cmd,
require_outputs=True,
stream_logs=False)
_raise_command_running_error('setup ssh', set_k8s_ssh_cmd, pod_name, rc,
stdout)

def _run_ssh_setup():
rc, stdout, _ = runner.run(set_k8s_ssh_cmd,
require_outputs=True,
stream_logs=False)
_raise_command_running_error('setup ssh', set_k8s_ssh_cmd, pod_name,
rc, stdout)

_run_function_with_retries(_run_ssh_setup,
f'setup ssh in pod {pod_name!r}')
logger.info(f'{"-"*20}End: Set up SSH in pod {pod_name!r} {"-"*20}')

subprocess_utils.run_in_parallel(_setup_ssh_thread, new_nodes)
subprocess_utils.run_in_parallel(_setup_ssh_thread, new_nodes, NUM_THREADS)


def _label_pod(namespace: str, context: Optional[str], pod_name: str,
Expand Down Expand Up @@ -767,12 +828,17 @@ def terminate_instances(
def _is_head(pod) -> bool:
return pod.metadata.labels[constants.TAG_RAY_NODE_KIND] == 'head'

for pod_name, pod in pods.items():
logger.debug(f'Terminating instance {pod_name}: {pod}')
def _terminate_pod_thread(pod_info):
pod_name, pod = pod_info
if _is_head(pod) and worker_only:
continue
return
logger.debug(f'Terminating instance {pod_name}: {pod}')
_terminate_node(namespace, context, pod_name)

# Run pod termination in parallel
subprocess_utils.run_in_parallel(_terminate_pod_thread, pods.items(),
NUM_THREADS)


def get_cluster_info(
region: str,
Expand Down
25 changes: 22 additions & 3 deletions sky/utils/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,23 @@ def _rsync(
rsync_command.append(prefix_command)
rsync_command += ['rsync', RSYNC_DISPLAY_OPTION]

def _get_remote_home_dir_with_retry():
backoff = common_utils.Backoff(initial_backoff=1,
max_backoff_factor=5)
retries_left = max_retry
assert retries_left > 0, f'max_retry {max_retry} must be positive.'
while retries_left >= 0:
try:
return get_remote_home_dir()
except Exception: # pylint: disable=broad-except
if retries_left == 0:
raise
sleep_time = backoff.current_backoff()
logger.warning(f'Failed to get remote home dir '
f'- retrying in {sleep_time} seconds.')
retries_left -= 1
time.sleep(sleep_time)

# --filter
# The source is a local path, so we need to resolve it.
resolved_source = pathlib.Path(source).expanduser().resolve()
Expand All @@ -261,7 +278,7 @@ def _rsync(
if up:
resolved_target = target
if target.startswith('~'):
remote_home_dir = get_remote_home_dir()
remote_home_dir = _get_remote_home_dir_with_retry()
resolved_target = target.replace('~', remote_home_dir)
full_source_str = str(resolved_source)
if resolved_source.is_dir():
Expand All @@ -273,7 +290,7 @@ def _rsync(
else:
resolved_source = source
if source.startswith('~'):
remote_home_dir = get_remote_home_dir()
remote_home_dir = _get_remote_home_dir_with_retry()
resolved_source = source.replace('~', remote_home_dir)
rsync_command.extend([
f'{node_destination}:{resolved_source!r}',
Expand Down Expand Up @@ -656,6 +673,8 @@ def rsync(
class KubernetesCommandRunner(CommandRunner):
"""Runner for Kubernetes commands."""

_MAX_RETRIES_FOR_RSYNC = 3

def __init__(
self,
node: Tuple[Tuple[str, Optional[str]], str],
Expand Down Expand Up @@ -798,7 +817,7 @@ def rsync(
# Advanced options.
log_path: str = os.devnull,
stream_logs: bool = True,
max_retry: int = 1,
max_retry: int = _MAX_RETRIES_FOR_RSYNC,
) -> None:
"""Uses 'rsync' to sync 'source' to 'target'.

Expand Down
14 changes: 12 additions & 2 deletions sky/utils/subprocess_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,27 @@ def get_parallel_threads() -> int:
return max(4, cpu_count - 1)


def run_in_parallel(func: Callable, args: Iterable[Any]) -> List[Any]:
def run_in_parallel(func: Callable,
args: Iterable[Any],
num_threads: Optional[int] = None) -> List[Any]:
"""Run a function in parallel on a list of arguments.

The function 'func' should raise a CommandError if the command fails.

Args:
func: The function to run in parallel
args: Iterable of arguments to pass to func
num_threads: Number of threads to use. If None, uses
get_parallel_threads()

Returns:
A list of the return values of the function func, in the same order as the
arguments.
"""
# Reference: https://stackoverflow.com/questions/25790279/python-multiprocessing-early-termination # pylint: disable=line-too-long
with pool.ThreadPool(processes=get_parallel_threads()) as p:
processes = num_threads if num_threads is not None else get_parallel_threads(
)
with pool.ThreadPool(processes=processes) as p:
# Run the function in parallel on the arguments, keeping the order.
return list(p.imap(func, args))

Expand Down
Loading