Skip to content

Commit

Permalink
Improvements for adding SSH instances (#1202)
Browse files Browse the repository at this point in the history
* Change the GPU's full name to a short one

* SSH instances must be filtered along with the other instances

* Fixed the region for the instance

* Check whether the instance already exists before adding it
  • Loading branch information
Sergey Mezentsev authored May 8, 2024
1 parent c6315ab commit b615a12
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 10 deletions.
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
BackendType.AWS,
BackendType.AZURE,
BackendType.GCP,
BackendType.REMOTE,
]
BACKENDS_WITH_CREATE_INSTANCE_SUPPORT = [
BackendType.AWS,
Expand Down
4 changes: 3 additions & 1 deletion src/dstack/_internal/core/backends/remote/provisioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
InstanceType,
Resources,
)
from dstack._internal.utils.gpu import convert_gpu_name
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -188,9 +189,10 @@ def get_shim_healthcheck(client: paramiko.SSHClient) -> str:


def host_info_to_instance_type(host_info: Dict[str, Any]) -> InstanceType:
gpu_name = convert_gpu_name(host_info["gpu_name"])
if host_info.get("gpu_count", 0):
gpu_memory = int(host_info["gpu_memory"].lower().replace("mib", "").strip())
gpus = [Gpu(name=host_info["gpu_name"], memory_mib=gpu_memory)] * host_info["gpu_count"]
gpus = [Gpu(name=gpu_name, memory_mib=gpu_memory)] * host_info["gpu_count"]
else:
gpus = []
instance_type = InstanceType(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,12 +252,14 @@ async def add_remote(instance_id: UUID) -> None:
continue
internal_ip = addresses[0] if addresses else None

region = instance.region

jpd = JobProvisioningData(
backend=BackendType.REMOTE,
instance_type=instance_type,
instance_id="instance_id",
hostname=remote_details.host,
region="remote",
region=region,
price=0,
internal_ip=internal_ip,
username=remote_details.ssh_user,
Expand All @@ -270,12 +272,10 @@ async def add_remote(instance_id: UUID) -> None:
instance.status = InstanceStatus.IDLE if health else InstanceStatus.PROVISIONING
instance.backend = BackendType.REMOTE

instance.region = "remote"

instance_offer = InstanceOfferWithAvailability(
backend=BackendType.REMOTE,
instance=instance_type,
region="remote",
region=region,
price=0,
availability=InstanceAvailability.AVAILABLE,
instance_runtime=InstanceRuntime.SHIM,
Expand Down
14 changes: 9 additions & 5 deletions src/dstack/_internal/server/services/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ async def add_remote(
pools = await list_project_pool_models(session, project)
for pool in pools:
for instance in pool.instances:
if instance.deleted:
continue
if instance.remote_connection_info is not None:
rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info)
if rci.host == host and rci.port == port and rci.ssh_user == ssh_user:
Expand All @@ -272,12 +274,14 @@ async def add_remote(
instance_resource = Resources(cpus=2, memory_mib=8, gpus=[], spot=False)
instance_type = InstanceType(name="ssh", resources=instance_resource)

local = JobProvisioningData(
host_region = region if region is not None else "remote"

remote = JobProvisioningData(
backend=BackendType.REMOTE,
instance_type=instance_type,
instance_id=instance_name,
hostname=host,
region=region or "remote",
region=host_region,
internal_ip=None,
price=0,
username=ssh_user,
Expand All @@ -289,7 +293,7 @@ async def add_remote(
offer = InstanceOfferWithAvailability(
backend=BackendType.REMOTE,
instance=instance_type,
region=region or "remote",
region=host_region,
price=0.0,
availability=InstanceAvailability.AVAILABLE,
)
Expand All @@ -306,7 +310,7 @@ async def add_remote(
created_at=common_utils.get_current_datetime(),
started_at=common_utils.get_current_datetime(),
status=InstanceStatus.PENDING,
job_provisioning_data=local.json(),
job_provisioning_data=remote.json(),
remote_connection_info=ssh_connection_info,
offer=offer.json(),
region=offer.region,
Expand Down Expand Up @@ -342,7 +346,7 @@ def filter_pool_instances(
continue

if instance.backend == BackendType.REMOTE:
instances.append(instance)
candidates.append(instance)
continue

# TODO: remove on prod
Expand Down
27 changes: 27 additions & 0 deletions src/dstack/_internal/utils/gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import re


def convert_gpu_name(name: str) -> str:
    """Convert a GPU name as reported by nvidia-smi to its short form.

    Vendor and product-line prefixes ("NVIDIA ", "Tesla ", "Quadro ",
    "GeForce ") are removed first, then the model identifier is extracted.
    Anything following the model identifier (memory size, form factor,
    variant letters) is dropped.

    Examples:
        "NVIDIA GeForce RTX 4060 Ti" -> "RTX4060Ti"
        "NVIDIA A100-SXM4-80GB"      -> "A100"
        "NVIDIA A10G"                -> "A10"
        "NVIDIA GH200 120GB"         -> "GH200"

    Args:
        name: Full GPU name.

    Returns:
        The short model name. When no known pattern matches, the name with
        the prefixes stripped and all spaces removed.
    """
    # https://github.com/NVIDIA/open-gpu-kernel-modules/
    for prefix in ("NVIDIA ", "Tesla ", "Quadro ", "GeForce "):
        name = name.replace(prefix, "")

    # Must be checked before the generic pattern below, which would
    # otherwise extract just "H200" from "GH200".
    if "GH200" in name:
        return "GH200"

    # Workstation cards such as "RTX A6000" are reported as plain "A6000".
    if "RTX A" in name:
        name = name.replace("RTX A", "A")
        m = re.search(r"(A\d+)", name)
        if m is not None:
            return m.group(0)
        return name.replace(" ", "")

    # Glue the pieces together so that e.g. "RTX 4060 Ti" becomes "RTX4060Ti".
    name = name.replace(" Ti", "Ti")
    name = name.replace("RTX ", "RTX")
    # One series letter followed by digits and an optional "Ti" suffix.
    # A plain character class is used for the letter ("|" inside [...] is a
    # literal pipe), and "Ti" is an optional group rather than the class
    # "[Ti]", which would match a single "T" or "i" and truncate the suffix.
    m = re.search(r"[AHLPTV]\d+(?:Ti)?", name)
    if m is not None:
        return m.group(0)
    return name.replace(" ", "")
19 changes: 19 additions & 0 deletions src/tests/_internal/utils/test_gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest

from dstack._internal.utils.gpu import convert_gpu_name

# Parametrization table for convert_gpu_name: each entry is a pair of
# (full GPU name passed to the function, expected short name).
TESTS = [
("NVIDIA GeForce RTX 4060 Ti", "RTX4060Ti"),
("NVIDIA GeForce RTX 4060", "RTX4060"),
("NVIDIA L4", "L4"),
("NVIDIA GH200 120GB", "GH200"),
("NVIDIA A100-SXM4-80GB", "A100"),
("NVIDIA A10G", "A10"),
("Tesla T4", "T4"),
]


class TestConvertGpuName:
    """Table-driven checks for convert_gpu_name."""

    @pytest.mark.parametrize(("test_input", "expected"), TESTS)
    def test_convert_gpu_name(self, test_input, expected):
        """The full GPU name must convert to the expected short name."""
        short_name = convert_gpu_name(test_input)
        assert short_name == expected

0 comments on commit b615a12

Please sign in to comment.