Skip to content

Commit

Permalink
support starting server in docker container for on-demand clusters
Browse files Browse the repository at this point in the history
  • Loading branch information
carolineechen committed Jun 10, 2024
1 parent 849535f commit e294372
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 15 deletions.
1 change: 1 addition & 0 deletions runhouse/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
TEST_ORG = "test-org"

EMPTY_DEFAULT_ENV_NAME = "_cluster_default_env"
# Name of the docker container that `get_docker_user` (sky_ssh_runner.py) targets
# with `sudo docker exec <name> whoami` to look up the in-container username.
DEFAULT_DOCKER_CONTAINER_NAME = "sky_container"

# Constants for the status check
DOUBLE_SPACE_UNICODE = "\u00A0\u00A0"
Expand Down
4 changes: 2 additions & 2 deletions runhouse/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ def _start_server(
try:
# Open and read the lines of the server logfile so we only print the most recent lines after starting
f = None
if screen and Path(SERVER_LOGFILE).exists():
if (screen or nohup) and Path(SERVER_LOGFILE).exists():
f = open(SERVER_LOGFILE, "r")
f.readlines() # Discard these, they're from the previous times the server was started

Expand All @@ -646,7 +646,7 @@ def _start_server(

server_started_str = "Uvicorn running on"
# Read and print the server logs until the server-started message appears
if screen:
if screen or nohup:
while not Path(SERVER_LOGFILE).exists():
time.sleep(1)
f = f or open(SERVER_LOGFILE, "r")
Expand Down
8 changes: 8 additions & 0 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ def creds_values(self) -> Dict:

return self._creds.values

@property
def docker_user(self) -> Optional[str]:
    """Username to use inside a docker container on the cluster.

    This base implementation always reports ``None``; the on-demand cluster
    variant shown in this commit provides its own ``docker_user`` that looks
    the user up on the cluster.
    """
    return None

@property
def default_env(self):
from runhouse.resources.envs import Env
Expand Down Expand Up @@ -727,6 +731,7 @@ def ssh_tunnel(
return ssh_tunnel(
address=self.address,
ssh_creds=self.creds_values,
docker_user=self.docker_user,
local_port=local_port,
ssh_port=self.ssh_port,
remote_port=remote_port,
Expand Down Expand Up @@ -1107,6 +1112,7 @@ def _rsync(
**ssh_credentials,
ssh_control_name=ssh_control_name,
port=self.ssh_port,
docker_user=self.docker_user,
)
if not pwd:
if up:
Expand Down Expand Up @@ -1181,6 +1187,7 @@ def ssh(self):
ssh_user=creds["ssh_user"],
port=self.ssh_port,
ssh_private_key=creds["ssh_private_key"],
docker_user=self.docker_user,
)
subprocess.run(
runner._ssh_base_command(ssh_mode=SshMode.INTERACTIVE, port_forward=None)
Expand Down Expand Up @@ -1359,6 +1366,7 @@ def _run_commands_with_ssh(
**ssh_credentials,
ssh_control_name=ssh_control_name,
port=self.ssh_port,
docker_user=self.docker_user,
)

env_var_prefix = (
Expand Down
3 changes: 2 additions & 1 deletion runhouse/resources/hardware/cluster_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,8 @@ def ondemand_cluster(
autostop_mins (int, optional): Number of minutes to keep the cluster up after inactivity,
or ``-1`` to keep cluster up indefinitely.
use_spot (bool, optional): Whether or not to use spot instance.
image_id (str, optional): Custom image ID for the cluster.
image_id (str, optional): Custom image ID for the cluster. If using a docker image, please use the following
string format: "docker:<registry>/<image>:<tag>".
region (str, optional): The region to use for the cluster.
memory (int or str, optional): Amount of memory to use for the cluster, e.g. "16" or "16+".
disk_size (int or str, optional): Amount of disk space to use for the cluster, e.g. "100" or "100+".
Expand Down
30 changes: 25 additions & 5 deletions runhouse/resources/hardware/on_demand_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(
self.stable_internal_external_ips = kwargs.get(
"stable_internal_external_ips", None
)
self._docker_user = None

# Checks if state info is in local sky db, populates if so.
if not dryrun and not self.ips and not self.creds_values:
Expand All @@ -123,6 +124,22 @@ def autostop_mins(self, mins):
sky.autostop(self.name, mins, down=True)
self._autostop_mins = mins

@property
def docker_user(self) -> str:
    """Username inside the cluster's docker container, if it runs in one.

    The value is looked up on the cluster via ``get_docker_user`` the first
    time it is needed and cached in ``self._docker_user`` afterwards.

    Returns:
        The container username, or ``None`` when any of the following holds:
        the cluster has no ``image_id``; the ``image_id`` is not a docker
        image (docker images are specified as
        ``"docker:<registry>/<image>:<tag>"``); or no credentials are
        available to query the cluster.
    """
    if self._docker_user:
        return self._docker_user

    # Only docker images have a container to query. Any other custom image
    # ID would make the `docker exec` lookup fail, so skip the lookup.
    image_id = self.image_id
    if not image_id or not str(image_id).startswith("docker:"):
        return None

    # Without credentials there is no way to reach the cluster and ask.
    if not self._creds:
        return None

    # Function-scope import kept from the original.
    # NOTE(review): presumably avoids a circular import — confirm.
    from runhouse.resources.hardware.sky_ssh_runner import get_docker_user

    self._docker_user = get_docker_user(self, self._creds.values)
    return self._docker_user

def config(self, condensed=True):
config = super().config(condensed)
self.save_attrs_to_config(
Expand Down Expand Up @@ -546,10 +563,13 @@ def ssh(self, node: str = None):
ip=node or self.address,
ssh_user=ssh_user,
port=self.ssh_port,
ssh_private_key=sky_key,
ssh_private_key=str(sky_key),
docker_user=self.docker_user,
)
subprocess.run(
runner._ssh_base_command(
ssh_mode=SshMode.INTERACTIVE, port_forward=None
)
cmd = runner.run(
cmd="bash --rcfile <(echo '. ~/.bashrc; conda deactivate')",
ssh_mode=SshMode.INTERACTIVE,
port_forward=None,
return_cmd=True,
)
subprocess.run(cmd, shell=True)
56 changes: 49 additions & 7 deletions runhouse/resources/hardware/sky_ssh_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
from typing import Dict, List, Optional, Tuple, Union

from runhouse.constants import LOCALHOST
from runhouse.constants import DEFAULT_DOCKER_CONTAINER_NAME, LOCALHOST

from runhouse.globals import sky_ssh_runner_cache

Expand Down Expand Up @@ -45,6 +45,32 @@ def is_port_in_use(port: int) -> bool:
return s.connect_ex(("localhost", port)) == 0


def get_docker_user(cluster: "Cluster", ssh_creds: Dict) -> str:
    """Return the username in use inside the cluster's docker container.

    Opens a ``SkySSHRunner`` to the cluster host (no ``docker_user`` is
    passed to the runner) and runs ``whoami`` in the container named by
    ``DEFAULT_DOCKER_CONTAINER_NAME`` through ``sudo docker exec``.

    Args:
        cluster: Cluster whose ``address`` and ``ssh_port`` are used to connect.
        ssh_creds: Credential values; the ``ssh_user``, ``ssh_private_key`` and
            ``ssh_control_name`` keys are read, each optional.

    Returns:
        The stripped stdout of ``whoami``.

    Raises:
        AssertionError: If the ``whoami`` command exits with a non-zero code.
    """
    # Control name falls back to "<address>:<port>" when the creds carry none.
    fallback_control_name = f"{cluster.address}:{cluster.ssh_port}"
    runner_kwargs = {
        "ip": cluster.address,
        "ssh_user": ssh_creds.get("ssh_user", None),
        "port": cluster.ssh_port,
        "ssh_private_key": ssh_creds.get("ssh_private_key", None),
        "ssh_control_name": ssh_creds.get("ssh_control_name", fallback_control_name),
    }
    runner = SkySSHRunner(**runner_kwargs)

    container_name = DEFAULT_DOCKER_CONTAINER_NAME
    whoami_result = runner.run(
        f"sudo docker exec {container_name} whoami",
        stream_logs=False,
        require_outputs=True,
    )
    whoami_returncode, whoami_stdout, whoami_stderr = whoami_result

    assert whoami_returncode == 0, (
        f"Failed to get docker container user. Return "
        f"code: {whoami_returncode}, Error: {whoami_stderr}"
    )

    docker_user = whoami_stdout.strip()
    logger.debug(f"Docker container user: {docker_user}")
    return docker_user


class SkySSHRunner(SSHCommandRunner):
def __init__(
self,
Expand All @@ -68,6 +94,9 @@ def __init__(
docker_user,
disable_control_master,
)

# RH modified
self.docker_user = docker_user
self.tunnel_proc = None
self.local_bind_port = local_bind_port
self.remote_bind_port = None
Expand All @@ -90,7 +119,7 @@ def _ssh_base_command(
local, remote = fwd, fwd
else:
local, remote = fwd
logger.info(f"Forwarding port {local} to port {remote} on localhost.")
logger.debug(f"Forwarding port {local} to port {remote} on localhost.")
ssh += ["-L", f"{local}:localhost:{remote}"]
if self._docker_ssh_proxy_command is not None:
docker_ssh_proxy_command = self._docker_ssh_proxy_command(ssh)
Expand Down Expand Up @@ -179,6 +208,8 @@ def run(
"-i",
]

cmd = f"conda deactivate && {cmd}" if self.docker_user else cmd

command += [
shlex.quote(
f"true && source ~/.bashrc && export OMP_NUM_THREADS=1 "
Expand Down Expand Up @@ -271,6 +302,14 @@ def terminate(self):
port_forward=[(self.local_bind_port, self.remote_bind_port)],
)
)

self.tunnel_proc = None
self.local_bind_port = None
self.remote_bind_port = None

if "ControlMaster" not in port_fwd_cmd:
return

cancel_port_fwd = port_fwd_cmd.replace("-T", "-O cancel")
logger.debug(f"Running cancel command: {cancel_port_fwd}")
completed_cancel_cmd = subprocess.run(
Expand All @@ -285,10 +324,6 @@ def terminate(self):
f"Error: {completed_cancel_cmd.stderr}"
)

self.tunnel_proc = None
self.local_bind_port = None
self.remote_bind_port = None

def rsync(
self,
source: str,
Expand Down Expand Up @@ -434,6 +469,7 @@ def ssh_tunnel(
ssh_port: int = 22,
remote_port: Optional[int] = None,
num_ports_to_try: int = 0,
docker_user: Optional[str] = None,
) -> SkySSHRunner:
"""Initialize an ssh tunnel from a remote server to localhost
Expand Down Expand Up @@ -465,7 +501,12 @@ def ssh_tunnel(
remote_port = remote_port or local_port

tunnel = get_existing_sky_ssh_runner(address, ssh_port)
if tunnel and tunnel.ip == address and tunnel.remote_bind_port == remote_port:
tunnel_address = address if not docker_user else "localhost"
if (
tunnel
and tunnel.ip == tunnel_address
and tunnel.remote_bind_port == remote_port
):
logger.info(
f"SSH tunnel on to server's port {remote_port} "
f"via server's ssh port {ssh_port} already created with the cluster."
Expand Down Expand Up @@ -495,6 +536,7 @@ def ssh_tunnel(
ssh_private_key=ssh_creds.get("ssh_private_key"),
ssh_proxy_command=ssh_creds.get("ssh_proxy_command"),
ssh_control_name=ssh_control_name,
docker_user=docker_user,
port=ssh_port,
)
runner.tunnel(local_port, remote_port)
Expand Down

0 comments on commit e294372

Please sign in to comment.