Skip to content

Commit

Permalink
Refactor writing down cluster creds
Browse files Browse the repository at this point in the history
  • Loading branch information
carolineechen committed Oct 30, 2024
1 parent 9ad8e1f commit 3e83757
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 30 deletions.
23 changes: 18 additions & 5 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,10 @@ def _should_save_creds(self, folder: str = None) -> bool:
# if not self.rns_address => we are saving the cluster first time in den
# else, need to check if the username of the current saver is included in the rns_address.
should_save_creds = (
not self.rns_address or local_default_folder in self.rns_address
) and isinstance(self._creds, Secret)
(not self.rns_address or local_default_folder in self.rns_address)
and self._creds
and isinstance(self._creds, Secret)
)

if should_save_creds:
# update secret name if it already exists in den w/ different config, avoid overwriting
Expand All @@ -372,7 +374,6 @@ def _save_sub_resources(self, folder: str = None):
from runhouse.resources.envs import Env

if self._should_save_creds(folder):
# TODO - check against existing secrets if already there, rename if conflict
self._creds.save(folder=folder)

if self._default_env and isinstance(self._default_env, Env):
Expand All @@ -397,9 +398,20 @@ def from_name(
_resolve_children=_resolve_children,
)
if cluster and cluster._creds and not dryrun:
from runhouse.resources.secrets.utils import _write_creds_to_local
from runhouse.resources.secrets import Secret
from runhouse.resources.secrets.provider_secrets.ssh_secret import SSHSecret

if isinstance(cluster._creds, SSHSecret):
cluster._creds.write()
elif isinstance(cluster._creds, Secret):
# old version of cluster creds or password only
private_key_path = cluster._creds.values.get("ssh_private_key")
if private_key_path:
SSHSecret._write_to_file(
path=private_key_path,
values=cluster._creds.values,
)

_write_creds_to_local(cluster.creds_values)
return cluster

@classmethod
Expand Down Expand Up @@ -531,6 +543,7 @@ def is_shared(self) -> bool:
if ssh_private_key:
ssh_private_key_path = Path(ssh_private_key).expanduser()
secrets_base_dir = Path(Secret.DEFAULT_DIR).expanduser()

# Check if the key path is saved down in the local .rh directory, which we only do for shared credentials
if str(ssh_private_key_path).startswith(str(secrets_base_dir)):
return True
Expand Down
2 changes: 1 addition & 1 deletion runhouse/resources/secrets/provider_secrets/ssh_secret.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _write_to_file(
priv_key_path = Path(os.path.expanduser(priv_key_path))
pub_key_path = Path(f"{os.path.expanduser(priv_key_path)}.pub")

if priv_key_path.exists() and pub_key_path.exists():
if priv_key_path.exists() or pub_key_path.exists():
if values == self._from_path(path=path):
logger.info(f"Secrets already exist in {path}. Skipping.")
self.path = path
Expand Down
24 changes: 0 additions & 24 deletions runhouse/resources/secrets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,27 +117,3 @@ def _check_file_for_mismatches(path, existing_vals, new_vals, overwrite):
)
return True
return False


def _write_creds_to_local(creds):
if not creds:
return

private_key_path = creds.get("ssh_private_key")
if not private_key_path or Path(private_key_path).expanduser().exists():
return

private_key_value = creds.get("private_key")
public_key_value = creds.get("public_key")
private_file_path = Path(private_key_path).expanduser()

if private_key_value:
with open(str(private_file_path), "w") as f:
f.write(private_key_value)
private_file_path.chmod(0o600)
if public_key_value:
public_file_path = Path(f"{str(private_file_path)}.pub")
if not public_file_path.exists():
with open(str(public_file_path), "w") as f:
f.write(public_key_value)
public_file_path.chmod(0o600)

0 comments on commit 3e83757

Please sign in to comment.