From 8ac1e99369c8493cd8dfc6d47f49ca9d361ef0bb Mon Sep 17 00:00:00 2001 From: Donny Greenberg <15992114+dongreenberg@users.noreply.github.com> Date: Tue, 9 Jul 2024 16:09:05 -0400 Subject: [PATCH] Allow passing Sky kwargs to ondemand_cluster (#978) (cherry picked from commit 50ca4c9e46a889ec106397432f87a035794a4108) --- .../resources/hardware/cluster_factory.py | 9 +++ .../resources/hardware/on_demand_cluster.py | 75 ++++++++++++------- tests/fixtures/on_demand_cluster_fixtures.py | 3 +- 3 files changed, 57 insertions(+), 30 deletions(-) diff --git a/runhouse/resources/hardware/cluster_factory.py b/runhouse/resources/hardware/cluster_factory.py index e7995d4e9..ca9940eaa 100644 --- a/runhouse/resources/hardware/cluster_factory.py +++ b/runhouse/resources/hardware/cluster_factory.py @@ -313,6 +313,7 @@ def ondemand_cluster( memory: Union[int, str, None] = None, disk_size: Union[int, str, None] = None, open_ports: Union[int, str, List[int], None] = None, + sky_kwargs: Dict = None, server_port: int = None, server_host: int = None, server_connection_type: Union[ServerConnectionType, str] = None, @@ -346,6 +347,14 @@ def ondemand_cluster( disk_size (int or str, optional): Amount of disk space to use for the cluster, e.g. "100" or "100+". open_ports (int or str or List[int], optional): Ports to open in the cluster's security group. Note that you are responsible for ensuring that the applications listening on these ports are secure. + sky_kwargs (dict, optional): Additional keyword arguments to pass to the SkyPilot `Resource` or + `launch` APIs. Should be a dict of the form + `{"resources": {<resources_kwargs>}, "launch": {<launch_kwargs>}}`, where resources_kwargs and + launch_kwargs will be passed to the SkyPilot Resources API + (See `SkyPilot docs <https://skypilot.readthedocs.io/en/latest/reference/api.html#resources>`_) + and `launch` API (See + `SkyPilot docs <https://skypilot.readthedocs.io/en/latest/reference/api.html#launch>`_), respectively. + Any arguments which duplicate those passed to the `ondemand_cluster` factory method will raise an error. server_port (bool, optional): Port to use for the server. 
If not provided will use 80 for a ``server_connection_type`` of ``none``, 443 for ``tls`` and ``32300`` for all other SSH connection types. server_host (bool, optional): Host from which the server listens for traffic (i.e. the --host argument diff --git a/runhouse/resources/hardware/on_demand_cluster.py b/runhouse/resources/hardware/on_demand_cluster.py index 2a038e809..bc57bf66d 100644 --- a/runhouse/resources/hardware/on_demand_cluster.py +++ b/runhouse/resources/hardware/on_demand_cluster.py @@ -58,6 +58,7 @@ def __init__( domain: str = None, den_auth: bool = False, region=None, + sky_kwargs: Dict = None, **kwargs, # We have this here to ignore extra arguments when calling from from_config ): """ @@ -95,6 +96,7 @@ def __init__( self.region = region self.memory = memory self.disk_size = disk_size + self.sky_kwargs = sky_kwargs or {} self.stable_internal_external_ips = kwargs.get( "stable_internal_external_ips", None @@ -154,6 +156,9 @@ def config(self, condensed=True): "image_id", "region", "stable_internal_external_ips", + "memory", + "disk_size", + "sky_kwargs", ], ) config["autostop_mins"] = self._autostop_mins @@ -431,36 +436,48 @@ def up(self): if self.provider != "cheapest" else None ) - task.set_resources( - sky.Resources( - # TODO: confirm if passing instance type in old way (without --) works when provider is k8s - cloud=cloud_provider, - instance_type=self.get_instance_type(), - accelerators=self.accelerators(), - cpus=self.num_cpus(), - memory=self.memory, - region=self.region or configs.get("default_region"), - disk_size=self.disk_size, - ports=self.open_ports, - image_id=self.image_id, - use_spot=self.use_spot, + try: + task.set_resources( + sky.Resources( + # TODO: confirm if passing instance type in old way (without --) works when provider is k8s + cloud=cloud_provider, + instance_type=self.get_instance_type(), + accelerators=self.accelerators(), + cpus=self.num_cpus(), + memory=self.memory, + region=self.region or configs.get("default_region"), + 
disk_size=self.disk_size, + ports=self.open_ports, + image_id=self.image_id, + use_spot=self.use_spot, + **self.sky_kwargs.get("resources", {}), + ) ) - ) - if self.image_id: - import os - - docker_env_vars = {} - for env_var in DOCKER_LOGIN_ENV_VARS: - if os.getenv(env_var): - docker_env_vars[env_var] = os.getenv(env_var) - if docker_env_vars: - task.update_envs(docker_env_vars) - sky.launch( - task, - cluster_name=self.name, - idle_minutes_to_autostop=self._autostop_mins, - down=True, - ) + if self.image_id: + import os + + docker_env_vars = {} + for env_var in DOCKER_LOGIN_ENV_VARS: + if os.getenv(env_var): + docker_env_vars[env_var] = os.getenv(env_var) + if docker_env_vars: + task.update_envs(docker_env_vars) + sky.launch( + task, + cluster_name=self.name, + idle_minutes_to_autostop=self._autostop_mins, + down=True, + **self.sky_kwargs.get("launch", {}), + ) + # Make sure no args are passed both in sky_kwargs and as explicit args + except TypeError as e: + if "got multiple values for keyword argument" in str(e): + raise TypeError( + f"{str(e)}. If argument is in `sky_kwargs`, it may need to be passed directly through the " + f"ondemand_cluster constructor (see `ondemand_cluster docs " + f"`_)." + ) + raise e self._update_from_sky_status() diff --git a/tests/fixtures/on_demand_cluster_fixtures.py b/tests/fixtures/on_demand_cluster_fixtures.py index eec8da49b..7808461a8 100644 --- a/tests/fixtures/on_demand_cluster_fixtures.py +++ b/tests/fixtures/on_demand_cluster_fixtures.py @@ -55,7 +55,8 @@ def ondemand_aws_cluster(request): "provider": "aws", "image_id": "docker:nvcr.io/nvidia/pytorch:23.10-py3", "region": "us-east-2", - "default_env": rh.env(reqs=["skypilot"], working_dir=None), + "default_env": rh.env(reqs=["ray==2.30.0"], working_dir=None), + "sky_kwargs": {"launch": {"retry_until_up": True}}, } cluster = setup_test_cluster(args, request, create_env=True) return cluster