Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix DSTACK_NODE_RANK #1189

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion runner/internal/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func (ex *RunExecutor) SetRunnerState(state string) {
}

func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error {
node_rank := ex.clusterInfo.GPUSPerJob
node_rank := ex.jobSpec.JobNum
nodes_num := ex.jobSpec.JobsPerReplica
gpus_per_node_num := ex.clusterInfo.GPUSPerJob
gpus_num := nodes_num * gpus_per_node_num
Expand Down
19 changes: 11 additions & 8 deletions src/dstack/_internal/core/backends/remote/provisioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def sftp_upload(client: paramiko.SSHClient, path: str, body: str) -> None:
sftp.putfo(io.BytesIO(body.encode()), path)
sftp.close()
except (paramiko.SSHException, OSError) as e:
raise ProvisioningError() from e
raise ProvisioningError(f"sft_upload failed: {e}") from e


def upload_envs(client: paramiko.SSHClient, working_dir: str, envs: Dict[str, str]) -> None:
Expand All @@ -50,7 +50,7 @@ def upload_envs(client: paramiko.SSHClient, working_dir: str, envs: Dict[str, st
f"The command 'upload_envs' didn't work. stdout: {out}, stderr: {err}"
)
except (paramiko.SSHException, OSError) as e:
raise ProvisioningError() from e
raise ProvisioningError(f"upload_envs failed: {e}") from e


def run_pre_start_commands(
Expand All @@ -68,7 +68,7 @@ def run_pre_start_commands(
f"The command 'authorized_keys' didn't work. stdout: {out}, stderr: {err}"
)
except (paramiko.SSHException, OSError) as e:
raise ProvisioningError() from e
raise ProvisioningError(f"upload authorized_keys failed: {e}") from e

script = " && ".join(shim_pre_start_commands)
try:
Expand All @@ -80,7 +80,7 @@ def run_pre_start_commands(
f"The command 'run_pre_start_commands' didn't work. stdout: {out}, stderr: {err}"
)
except (paramiko.SSHException, OSError) as e:
raise ProvisioningError() from e
raise ProvisioningError(f"run_pre-start_commands failed: {e}") from e


def run_shim_as_systemd_service(client: paramiko.SSHClient, working_dir: str, dev: bool) -> None:
Expand Down Expand Up @@ -122,7 +122,7 @@ def run_shim_as_systemd_service(client: paramiko.SSHClient, working_dir: str, de
f"The command 'run_shim_as_systemd_service' didn't work. stdout: {out}, stderr: {err}"
)
except (paramiko.SSHException, OSError) as e:
raise ProvisioningError() from e
raise ProvisioningError(f"run_shim_as_systemd failed: {e}") from e


def check_dstack_shim_service(client: paramiko.SSHClient):
Expand All @@ -145,7 +145,7 @@ def get_host_info(client: paramiko.SSHClient, working_dir: str) -> Dict[str, Any
)
err = stderr.read().decode().strip()
if err:
logger.debug("Cannot read `host_info.json`: %s", err)
logger.debug("Retry after error: %s", err)
time.sleep(iter_delay)
continue
except (paramiko.SSHException, OSError) as e:
Expand Down Expand Up @@ -184,7 +184,7 @@ def get_shim_healthcheck(client: paramiko.SSHClient) -> str:
continue
return out
except (paramiko.SSHException, OSError) as e:
raise ProvisioningError() from e
raise ProvisioningError(f"get_shim_healthcheck failed: {e}") from e


def host_info_to_instance_type(host_info: Dict[str, Any]) -> InstanceType:
Expand Down Expand Up @@ -226,9 +226,12 @@ def get_paramiko_connection(
timeout=SSH_CONNECT_TIMEOUT,
)
except paramiko.AuthenticationException:
logger.debug(
f'Authentication faild to connect to "{conn_url}" and {pkey.fingerprint}'
)
continue # try next key
except (paramiko.SSHException, OSError) as e:
raise ProvisioningError() from e
raise ProvisioningError(f"Connect failed: {e}") from e
else:
yield client
return
Expand Down
6 changes: 3 additions & 3 deletions src/dstack/_internal/core/models/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ class Resources(CoreModel):
def pretty_format(self) -> str:
resources = {
"cpus": self.cpus,
"memory": f"{self.memory_mib / 1024:g}GB",
"disk_size": f"{self.disk.size_mib / 1024:g}GB",
"memory": f"{self.memory_mib / 1024:.0f}GB",
"disk_size": f"{self.disk.size_mib / 1024:.1f}GB",
}
if self.gpus:
gpu = self.gpus[0]
resources.update(
gpu_name=gpu.name,
gpu_count=len(self.gpus),
gpu_memory=f"{gpu.memory_mib / 1024:g}GB",
gpu_memory=f"{gpu.memory_mib / 1024:.0f}GB",
)
return pretty_resources(**resources)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,10 @@ async def add_remote(instance_id: UUID) -> None:
result = await asyncio.wait_for(future, timeout=deploy_timeout)
health, host_info = result
except (asyncio.TimeoutError, TimeoutError) as e:
raise ProvisioningError() from e
raise ProvisioningError(f"Deploy timeout {e}") from e
except Exception as e:
logger.debug("deploy_instance raise an error: %s", e)
raise ProvisioningError() from e
raise ProvisioningError(f"Deploy instance raise an error {e}") from e
else:
logger.info(
"The instance %s (%s) was successfully added",
Expand All @@ -232,7 +232,11 @@ async def add_remote(instance_id: UUID) -> None:
)

except ProvisioningError as e:
logger.warning("Provisioning could not be completed because of the error: %s", e)
logger.warning(
"Provisioning the instance '%s' could not be completed because of the error: %s",
instance.name,
e,
)
instance.status = InstanceStatus.PENDING
instance.last_retry_at = get_current_datetime()
await session.commit()
Expand Down
6 changes: 5 additions & 1 deletion src/dstack/_internal/server/routers/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dstack._internal.core.models.pools as models
import dstack._internal.server.schemas.pools as schemas
import dstack._internal.server.services.pools as pools
from dstack._internal.core.errors import ConfigurationError
from dstack._internal.core.models.pools import Instance
from dstack._internal.server.db import get_session
from dstack._internal.server.models import ProjectModel, UserModel
Expand Down Expand Up @@ -82,6 +83,9 @@ async def add_instance(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
) -> Instance:
if not body.host.strip() or not body.ssh_user.strip() or not body.ssh_keys:
raise ConfigurationError("Host, user or ssh keys are empty")

_, project = user_project
result = await pools.add_remote(
session,
Expand All @@ -90,7 +94,7 @@ async def add_instance(
instance_name=body.instance_name,
region=body.region,
host=body.host,
port=body.port,
port=body.port or 22,
ssh_user=body.ssh_user,
ssh_keys=body.ssh_keys,
)
Expand Down
17 changes: 12 additions & 5 deletions src/dstack/_internal/server/services/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,15 @@ async def add_remote(
ssh_user: str,
ssh_keys: List[SSHKey],
) -> Instance:
# Check instance in all instances
pools = await list_project_pool_models(session, project)
for pool in pools:
for instance in pool.instances:
if instance.remote_connection_info is not None:
rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info)
if rci.host == host and rci.port == port and rci.ssh_user == ssh_user:
return instance_model_to_instance(instance)

pool_model = await get_or_create_pool_by_name(session, project, pool_name)
pool_model_name = pool_model.name
if instance_name is None:
Expand Down Expand Up @@ -288,11 +297,9 @@ async def add_remote(
availability=InstanceAvailability.AVAILABLE,
)

ssh_connection_info = None
if ssh_user and ssh_keys:
ssh_connection_info = RemoteConnectionInfo(
host=host, port=port, ssh_user=ssh_user, ssh_keys=ssh_keys
).json()
ssh_connection_info = RemoteConnectionInfo(
host=host, port=port, ssh_user=ssh_user, ssh_keys=ssh_keys
).json()

im = InstanceModel(
name=instance_name,
Expand Down
Loading