From fcfb7f915e9a49755fb470516e201bd730cd6e8e Mon Sep 17 00:00:00 2001 From: Alexandra Belousov Date: Mon, 2 Sep 2024 11:37:56 +0300 Subject: [PATCH] introduce cluster list cli cmd --- runhouse/main.py | 103 ++++++++++++++++++++++++++++++++++++++++++++++ runhouse/utils.py | 12 ++++++ 2 files changed, 115 insertions(+) diff --git a/runhouse/main.py b/runhouse/main.py index 6ff5813009..0cc345bddf 100644 --- a/runhouse/main.py +++ b/runhouse/main.py @@ -1,3 +1,4 @@ +import asyncio import datetime import importlib import logging @@ -9,6 +10,8 @@ from pathlib import Path from typing import Any, Dict, List, Optional +import httpx + import ray import requests @@ -17,6 +20,7 @@ import typer import yaml from rich.console import Console +from rich.table import Table import runhouse as rh @@ -41,9 +45,18 @@ kill_actors, ) +from runhouse.utils import get_status_color + # create an explicit Typer application app = typer.Typer(add_completion=False) +# creating a cluster app so we could create subcommands of cluster (i.e runhouse cluster list) +italic_bold_ansi = "\x1B[3m\x1B[1m" +reset_format = "\x1B[0m" +cluster_app = typer.Typer( + help=f"Cluster information commands. For more info run {italic_bold_ansi}runhouse cluster --help{reset_format}" +) + # For printing with typer console = Console() @@ -512,6 +525,9 @@ def _print_status(status_data: dict, current_cluster: Cluster) -> None: _print_envs_info(env_servlet_processes, current_cluster) +############################### +# Cluster CLI commands +############################### @app.command() def status( cluster_name: str = typer.Argument( @@ -578,6 +594,93 @@ def status( _print_status(cluster_status, current_cluster) +async def aget_clusters_from_den(): + httpx_client = httpx.AsyncClient() + + get_clusters_params = {"resource_type": "cluster", "folder": rns_client.username} + clusters_in_den_resp = await httpx_client.get( + f"{rns_client.api_server_url}/resource", + params=get_clusters_params, + headers=rns_client.request_headers(), + ) + + return clusters_in_den_resp + + +def get_clusters_from_den(): + return asyncio.run(aget_clusters_from_den()) + + +@cluster_app.command("list") +def cluster_list(): + """Load Runhouse clusters""" + import sky + + # logged out case + if not rh.configs.token: + # TODO [SB]: adjust msg formatting (coloring etc) + sky_cli_command_formatted = f"{italic_bold_ansi}sky status -r{reset_format}" # will be printed bold and italic + console.print( + f"This feature is available only for Den users. Please run {sky_cli_command_formatted} to get on-demand cluster(s) information or sign-up Den." + ) + return + + on_demand_clusters_sky = sky.status(refresh=True) + + clusters_in_den_resp = get_clusters_from_den() + + if clusters_in_den_resp.status_code != 200: + logger.error(f"Failed to load {rns_client.username}'s clusters from Den") + clusters_in_den = [] + else: + clusters_in_den = clusters_in_den_resp.json().get("data") + + clusters_in_den_names = [cluster.get("name") for cluster in clusters_in_den] + + if not on_demand_clusters_sky and not clusters_in_den: + console.print("No existing clusters.") + + if on_demand_clusters_sky: + # getting the on-demand clusters that are not saved in den. + on_demand_clusters_sky = [ + cluster + for cluster in on_demand_clusters_sky + if f'/{rns_client.username}/{cluster.get("name")}' + not in clusters_in_den_names + ] + + total_clusters = len(clusters_in_den) + table_title = f"[bold cyan]{rns_client.username}'s Clusters (Total: {total_clusters})[/bold cyan]" + + table = Table(title=table_title) + + # Add columns to the table + table.add_column("Name", justify="left", no_wrap=True) + table.add_column("Cluster Type", justify="center", no_wrap=True) + table.add_column("Status", justify="left") + + for den_cluster in clusters_in_den: + # get just name, not full rns address. reset is used so the name will be printed all in white. + cluster_name = f'[reset]{den_cluster.get("name").split("/")[-1]}' + cluster_type = den_cluster.get("data").get("resource_subtype") + cluster_status = ( + den_cluster.get("status") if den_cluster.get("status") else "unknown" + ) + cluster_status_colored = get_status_color(cluster_status) + table.add_row(cluster_name, cluster_type, cluster_status_colored) + + console.print(table) + + if len(on_demand_clusters_sky) > 0: + console.print( + f"There are {len(on_demand_clusters_sky)} live clusters that are not saved in Den. To get information about them, please run [bold italic]sky status -r[/bold italic]." + ) + + +# Register the 'cluster' command group with the main runhouse application +app.add_typer(cluster_app, name="cluster") + + def load_cluster(cluster_name: str): """Load a cluster from RNS into the local environment, e.g. to be able to ssh.""" c = cluster(name=cluster_name) diff --git a/runhouse/utils.py b/runhouse/utils.py index 833a28f4fd..761bdb8e79 100644 --- a/runhouse/utils.py +++ b/runhouse/utils.py @@ -652,3 +652,15 @@ def get_gpu_usage(collected_gpus_info: dict, servlet_type: ServletType): gpu_usage["utilization_percent"] = gpu_utilization_percent return gpu_usage + + +class StatusColors(str, Enum): + RUNNING = "[green]Running[/green]" + SERVER_DOWN = "[orange1]Runhouse server down[/orange1]" + TERMINATED = "[red]Terminated[/red]" + UNKNOWN = "Unknown" + LOCAL_CLUSTER = "[bright_yellow]Local cluster[/bright_yellow]" + + +def get_status_color(status: str): + return getattr(StatusColors, status.upper()).value