Skip to content

Commit

Permalink
Replace address with head_ip
Browse files Browse the repository at this point in the history
  • Loading branch information
carolineechen committed Nov 4, 2024
1 parent 2b5e1df commit f2da1f7
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 35 deletions.
49 changes: 24 additions & 25 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,12 @@ def __init__(
self._setup_creds(creds)

@property
def address(self):
def head_ip(self):
return self.ips[0] if isinstance(self.ips, List) else None

@address.setter
def address(self, addr):
self.ips = self.ips or [None]
self.ips[0] = addr
@property
def address(self):
return self.head_ip

@property
def client(self):
Expand All @@ -156,7 +155,7 @@ def check_connect_server():
self._update_from_sky_status(dryrun=False)
if not self._ping(retry=False):
raise ConnectionError(
f"Could not reach {self.name} {self.ips}. Is cluster up?"
f"Could not reach {self.name} {self.head_ip}. Is cluster up?"
)
if not check_connect_server():
raise ConnectionError(
Expand Down Expand Up @@ -480,7 +479,7 @@ def endpoint(self, external: bool = False):
(including the local connected port rather than the sever port). If cluster is not up, returns
`None``. (Default: ``False``)
"""
if not self.address or self.on_this_cluster():
if not self.head_ip or self.on_this_cluster():
return None

client_port = self.client_port or self.server_port
Expand Down Expand Up @@ -529,7 +528,7 @@ def server_address(self):
if self.server_host in [LOCALHOST, "localhost"]:
return LOCALHOST

return self.address
return self.head_ip

@property
def is_shared(self) -> bool:
Expand Down Expand Up @@ -563,7 +562,7 @@ def _command_runner(
"CommandRunner can only be instantiated for individual nodes"
)

node = node or self.address
node = node or self.head_ip

if (
hasattr(self, "launched_properties")
Expand Down Expand Up @@ -708,8 +707,8 @@ def _sync_runhouse_to_cluster(
if self.on_this_cluster():
return

if not self.address:
raise ValueError(f"No address set for cluster <{self.name}>. Is it up?")
if not self.head_ip:
raise ValueError(f"No IPs set for cluster <{self.name}>. Is it up?")

env = env or self.default_env

Expand Down Expand Up @@ -955,8 +954,8 @@ def connect_tunnel(self, force_reconnect=False):
)

def connect_server_client(self, force_reconnect=False):
if not self.address:
raise ValueError(f"No address set for cluster <{self.name}>. Is it up?")
if not self.head_ip:
raise ValueError(f"No IPs set for cluster <{self.name}>. Is it up?")

if self.server_connection_type == ServerConnectionType.SSH:
# For a password cluster, the 'ssh_tunnel' command assumes a Control Master is already set up with
Expand Down Expand Up @@ -1057,7 +1056,7 @@ def ssh_tunnel(
)

return ssh_tunnel(
address=self.address,
address=self.head_ip,
ssh_creds=self.creds_values,
docker_user=self.docker_user,
local_port=local_port,
Expand Down Expand Up @@ -1092,31 +1091,31 @@ def _use_custom_certs(self):

def _start_ray_workers(self, ray_port, env):
for host in self.ips:
if host == self.address:
if host == self.head_ip:
# This is the master node, skip
continue
logger.info(
f"Starting Ray on worker {host} with head node at {self.address}:{ray_port}."
f"Starting Ray on worker {host} with head node at {self.head_ip}:{ray_port}."
)
self.run(
commands=[
f"ray start --address={self.address}:{ray_port} --disable-usage-stats",
f"ray start --address={self.head_ip}:{ray_port} --disable-usage-stats",
],
node=host,
env=env,
)

def _run_cli_commands_on_cluster_helper(self, commands: List[str]):
if self.on_this_cluster():
return self.run(commands=commands, env=self._default_env, node=self.address)
return self.run(commands=commands, env=self._default_env, node=self.head_ip)
else:
if self._default_env:
commands = [self._default_env._full_command(cmd) for cmd in commands]
return self._run_commands_with_runner(
commands=commands,
cmd_prefix="",
env_vars=self._default_env.env_vars if self._default_env else {},
node=self.address,
node=self.head_ip,
require_outputs=False,
)

Expand Down Expand Up @@ -1190,7 +1189,7 @@ def _start_or_restart_helper(
# Rebuild on restart to ensure the correct subject name is included in the cert SAN
# Cert subject name needs to match the target (IP address or domain)
self.cert_config.generate_certs(
address=self.address, domain=self.domain
address=self.head_ip, domain=self.domain
)
self._copy_certs_to_cluster()

Expand Down Expand Up @@ -1503,14 +1502,14 @@ def rsync(
from runhouse.resources.hardware.sky_command_runner import SshMode

# If no address provided explicitly use the head node address
node = node or self.address
node = node or self.head_ip
# FYI, could be useful: https://github.com/gchamon/sysrsync
if contents:
source = source + "/" if not source.endswith("/") else source
dest = dest + "/" if not dest.endswith("/") else dest

# If we're already on this cluster (and node, if multinode), this is just a local rsync
if self.on_this_cluster() and node == self.address:
if self.on_this_cluster() and node == self.head_ip:
if Path(source).expanduser().resolve() == Path(dest).expanduser().resolve():
return

Expand Down Expand Up @@ -1590,15 +1589,15 @@ def ssh(self):
"""
creds = self.creds_values
_run_ssh_command(
address=self.address,
address=self.head_ip,
ssh_user=creds["ssh_user"],
ssh_port=self.ssh_port,
ssh_private_key=creds["ssh_private_key"],
docker_user=self.docker_user,
)

def _ping(self, timeout=5, retry=False):
if not self.address:
if not self.head_ip:
return False

def run_ssh_call():
Expand Down Expand Up @@ -1755,7 +1754,7 @@ def _run_commands_with_runner(
commands = [commands]

# If no address provided explicitly use the head node address
node = node or self.address
node = node or self.head_ip

return_codes = []

Expand Down
20 changes: 10 additions & 10 deletions runhouse/resources/hardware/on_demand_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ def client(self):
try:
return super().client
except ValueError as e:
if not self.address:
if not self.ips:
# Try loading in from local Sky DB
self._update_from_sky_status(dryrun=True)
if not self.address:
if not self.ips:
raise ValueError(
f"Could not determine address for ondemand cluster <{self.name}>. "
f"Could not determine ips for ondemand cluster <{self.name}>. "
"Up the cluster with `cluster.up_if_not`."
)
return super().client
Expand Down Expand Up @@ -197,7 +197,7 @@ def config(self, condensed=True):
return config

def endpoint(self, external: bool = False):
if not self.address or self.on_this_cluster():
if not self.ips or self.on_this_cluster():
return None

try:
Expand All @@ -215,7 +215,7 @@ def _copy_sky_yaml_from_cluster(self, abs_yaml_path: str):
# Save SSH info to the ~/.ssh/config
ray_yaml = yaml.safe_load(open(abs_yaml_path, "r"))
backend_utils.SSHConfigHelper.add_cluster(
self.name, [self.address], ray_yaml["auth"]
self.name, [self.head_up], ray_yaml["auth"]
)

@staticmethod
Expand Down Expand Up @@ -361,7 +361,7 @@ def _start_ray_workers(self, ray_port, env):
"handle"
].stable_internal_external_ips
for internal, external in stable_internal_external_ips:
if external == self.address:
if external == self.head_ip:
internal_head_ip = internal
else:
# NOTE: Using external worker address here because we're running from local
Expand All @@ -385,9 +385,9 @@ def _start_ray_workers(self, ray_port, env):
def _populate_connection_from_status_dict(self, cluster_dict: Dict[str, Any]):
if cluster_dict and cluster_dict["status"].name in ["UP", "INIT"]:
handle = cluster_dict["handle"]
self.address = handle.head_ip
head_ip = handle.head_ip
self.stable_internal_external_ips = handle.stable_internal_external_ips
if self.stable_internal_external_ips is None or self.address is None:
if self.stable_internal_external_ips is None or head_ip is None:
raise ValueError(
"Sky's cluster status does not have the necessary information to connect to the cluster. Please check if the cluster is up via `sky status`. Consider bringing down the cluster with `sky down` if you are still having issues."
)
Expand Down Expand Up @@ -577,7 +577,7 @@ def up(self):
logger.info(
f"Cluster has been launched with the custom domain '{self.domain}'. "
"Please add an A record to your DNS provider to point this domain to the cluster's "
f"public IP address ({self.address}) to ensure successful requests."
f"public IP address ({self.head_ip}) to ensure successful requests."
)

self.restart_server()
Expand Down Expand Up @@ -625,7 +625,7 @@ def teardown(self):

# Stream logs
sky.down(self.name)
self.address = None
self.ips = None

def teardown_and_delete(self):
"""Teardown cluster and delete it from configs.
Expand Down

0 comments on commit f2da1f7

Please sign in to comment.