Skip to content

Commit

Permalink
update cluster, node pool patch methods
Browse files Browse the repository at this point in the history
  • Loading branch information
zubenkoivan committed Dec 11, 2024
1 parent a064c00 commit 5076193
Show file tree
Hide file tree
Showing 5 changed files with 482 additions and 284 deletions.
16 changes: 8 additions & 8 deletions neuro_config_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@
AWSCloudProvider,
AWSCredentials,
AWSStorage,
AWSStorageOptions,
AzureCloudProvider,
AzureCredentials,
AzureReplicationType,
AzureStorage,
AzureStorageOptions,
AzureStorageTier,
BucketsConfig,
CloudProvider,
Expand All @@ -36,7 +34,6 @@
GoogleCloudProvider,
GoogleFilestoreTier,
GoogleStorage,
GoogleStorageOptions,
GrafanaCredentials,
HelmRegistryConfig,
IdleJobConfig,
Expand All @@ -52,6 +49,10 @@
OnPremCloudProvider,
OpenStackCredentials,
OrchestratorConfig,
PatchClusterRequest,
PatchNodePoolResourcesRequest,
PatchNodePoolSizeRequest,
PatchOrchestratorConfigRequest,
RegistryConfig,
ResourcePoolType,
ResourcePreset,
Expand All @@ -61,7 +62,6 @@
Storage,
StorageConfig,
StorageInstance,
StorageOptions,
TPUPreset,
TPUResource,
VCDCloudProvider,
Expand All @@ -79,12 +79,10 @@
"AWSCloudProvider",
"AWSCredentials",
"AWSStorage",
"AWSStorageOptions",
"AzureCloudProvider",
"AzureCredentials",
"AzureReplicationType",
"AzureStorage",
"AzureStorageOptions",
"AzureStorageTier",
"BucketsConfig",
"CloudProvider",
Expand All @@ -106,7 +104,6 @@
"GoogleCloudProvider",
"GoogleFilestoreTier",
"GoogleStorage",
"GoogleStorageOptions",
"GrafanaCredentials",
"HelmRegistryConfig",
"IdleJobConfig",
Expand All @@ -122,6 +119,10 @@
"OnPremCloudProvider",
"OpenStackCredentials",
"OrchestratorConfig",
"PatchClusterRequest",
"PatchNodePoolResourcesRequest",
"PatchNodePoolSizeRequest",
"PatchOrchestratorConfigRequest",
"RegistryConfig",
"ResourcePoolType",
"ResourcePreset",
Expand All @@ -131,7 +132,6 @@
"Storage",
"StorageConfig",
"StorageInstance",
"StorageOptions",
"TPUPreset",
"TPUResource",
"VCDCloudProvider",
Expand Down
85 changes: 13 additions & 72 deletions neuro_config_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import abc
import logging
import sys
from collections.abc import AsyncIterator, Mapping, Sequence
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from dataclasses import dataclass
Expand All @@ -14,33 +13,18 @@
from yarl import URL

from .entities import (
BucketsConfig,
CloudProviderOptions,
CloudProviderType,
Cluster,
CredentialsConfig,
DisksConfig,
DNSConfig,
EnergyConfig,
IngressConfig,
MetricsConfig,
MonitoringConfig,
NodePool,
NotificationType,
OrchestratorConfig,
RegistryConfig,
PatchClusterRequest,
PatchNodePoolResourcesRequest,
PatchNodePoolSizeRequest,
ResourcePreset,
SecretsConfig,
StorageConfig,
)
from .factories import EntityFactory, PayloadFactory

if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
else:
# why not backports.zoneinfo: https://github.com/pganssle/zoneinfo/issues/125
from backports.zoneinfo._zoneinfo import ZoneInfo

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -170,56 +154,10 @@ async def create_blank_cluster(
return await self.get_cluster(name)

async def patch_cluster(
self,
name: str,
*,
credentials: CredentialsConfig | None = None,
storage: StorageConfig | None = None,
registry: RegistryConfig | None = None,
orchestrator: OrchestratorConfig | None = None,
monitoring: MonitoringConfig | None = None,
secrets: SecretsConfig | None = None,
metrics: MetricsConfig | None = None,
disks: DisksConfig | None = None,
buckets: BucketsConfig | None = None,
ingress: IngressConfig | None = None,
dns: DNSConfig | None = None,
timezone: ZoneInfo | None = None,
energy: EnergyConfig | None = None,
token: str | None = None,
self, name: str, request: PatchClusterRequest, *, token: str | None = None
) -> Cluster:
path = self._endpoints.cluster(name)
payload: dict[str, Any] = {}
if credentials:
payload["credentials"] = self._payload_factory.create_credentials(
credentials
)
if storage:
payload["storage"] = self._payload_factory.create_storage(storage)
if registry:
payload["registry"] = self._payload_factory.create_registry(registry)
if orchestrator:
payload["orchestrator"] = self._payload_factory.create_orchestrator(
orchestrator
)
if monitoring:
payload["monitoring"] = self._payload_factory.create_monitoring(monitoring)
if secrets:
payload["secrets"] = self._payload_factory.create_secrets(secrets)
if metrics:
payload["metrics"] = self._payload_factory.create_metrics(metrics)
if disks:
payload["disks"] = self._payload_factory.create_disks(disks)
if buckets:
payload["buckets"] = self._payload_factory.create_buckets(buckets)
if ingress:
payload["ingress"] = self._payload_factory.create_ingress(ingress)
if dns:
payload["dns"] = self._payload_factory.create_dns(dns)
if timezone:
payload["timezone"] = str(timezone)
if energy:
payload["energy"] = self._payload_factory.create_energy(energy)
payload = self._payload_factory.create_patch_cluster_request(request)
async with self._request(
"PATCH", path, headers=self._create_headers(token=token), json=payload
) as resp:
Expand Down Expand Up @@ -393,16 +331,19 @@ async def patch_node_pool(
self,
cluster_name: str,
node_pool_name: str,
request: PatchNodePoolSizeRequest | PatchNodePoolResourcesRequest,
*,
idle_size: int | None = None,
start_deployment: bool = True,
token: str | None = None,
) -> Cluster:
path = self._endpoints.node_pool(cluster_name, node_pool_name)
payload: dict[str, Any] = {}
if idle_size is not None:
payload["idle_size"] = idle_size
payload = self._payload_factory.create_patch_node_pool_request(request)
async with self._request(
"PATCH", path, headers=self._create_headers(token=token), json=payload
"PATCH",
path,
params={"start_deployment": str(start_deployment).lower()},
headers=self._create_headers(token=token),
json=payload,
) as response:
resp_payload = await response.json()
return self._entity_factory.create_cluster(resp_payload)
Expand Down
110 changes: 71 additions & 39 deletions neuro_config_client/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def is_vcd(self) -> bool:
class CloudProviderOptions:
type: CloudProviderType
node_pools: list[NodePoolOptions]
storages: list[StorageOptions]


@dataclass(frozen=True)
Expand All @@ -59,41 +58,13 @@ class VCDCloudProviderOptions(CloudProviderOptions):

@dataclass(frozen=True)
class NodePoolOptions:
id: str
machine_type: str
cpu: float
available_cpu: float
memory: int
available_memory: int
gpu: int | None = None
gpu_model: str | None = None


@dataclass(frozen=True)
class StorageOptions:
id: str


@dataclass(frozen=True)
class GoogleStorageOptions(StorageOptions):
tier: GoogleFilestoreTier
min_capacity: int
max_capacity: int


@dataclass(frozen=True)
class AWSStorageOptions(StorageOptions):
performance_mode: EFSPerformanceMode
throughput_mode: EFSThroughputMode
provisioned_throughput_mibps: int | None = None


@dataclass(frozen=True)
class AzureStorageOptions(StorageOptions):
tier: AzureStorageTier
replication_type: AzureReplicationType
min_file_share_size: int
max_file_share_size: int
available_cpu: float | None = None
available_memory: int | None = None
nvidia_gpu: int | None = None
nvidia_gpu_model: str | None = None


class NodeRole(str, enum.Enum):
Expand All @@ -105,7 +76,6 @@ class NodeRole(str, enum.Enum):
@dataclass(frozen=True)
class NodePool:
name: str
id: str | None = None
role: NodeRole = NodeRole.PLATFORM_JOB

min_size: int = 0
Expand All @@ -119,6 +89,7 @@ class NodePool:
available_memory: int | None = None

disk_size: int | None = None
available_disk_size: int | None = None
disk_type: str | None = None

nvidia_gpu: int | None = None
Expand All @@ -142,6 +113,35 @@ class NodePool:
cpu_max_watts: float = 0.0


@dataclass(frozen=True)
class PatchNodePoolSizeRequest:
    """Immutable request object carrying node pool size values to patch.

    This is one of the two request types accepted by the client's
    ``patch_node_pool`` method, which hands it to the payload factory to
    build the PATCH body.

    Every field is optional and defaults to ``None``.
    NOTE(review): presumably a ``None`` field means "leave this value
    unchanged" -- confirm against the payload factory.
    """

    # Field order is part of the positional constructor interface.
    min_size: int | None = None
    max_size: int | None = None
    idle_size: int | None = None


@dataclass(frozen=True)
class PatchNodePoolResourcesRequest:
    """Immutable request object carrying node pool resource values to patch.

    This is one of the two request types accepted by the client's
    ``patch_node_pool`` method, which hands it to the payload factory to
    build the PATCH body.

    The CPU, memory and disk fields (total and ``available_*``) are
    required. GPU, machine type and size fields are optional and default
    to ``None``.
    NOTE(review): units for cpu/memory/disk are not stated here -- confirm
    against the service API before documenting them.
    """

    # Required resources; field order is part of the positional interface.
    cpu: float
    available_cpu: float
    memory: int
    available_memory: int
    disk_size: int
    available_disk_size: int

    # Optional GPU count and model, one pair per vendor.
    nvidia_gpu: int | None = None
    nvidia_gpu_model: str | None = None
    amd_gpu: int | None = None
    amd_gpu_model: str | None = None
    intel_gpu: int | None = None
    intel_gpu_model: str | None = None

    machine_type: str | None = None

    # Optional size bounds, named like those on PatchNodePoolSizeRequest.
    min_size: int | None = None
    max_size: int | None = None


@dataclass(frozen=True)
class StorageInstance:
name: str
Expand Down Expand Up @@ -183,7 +183,6 @@ class EFSThroughputMode(str, enum.Enum):

@dataclass(frozen=True)
class AWSStorage(Storage):
id: str
description: str
performance_mode: EFSPerformanceMode
throughput_mode: EFSThroughputMode
Expand Down Expand Up @@ -215,7 +214,6 @@ class GoogleFilestoreTier(str, enum.Enum):

@dataclass(frozen=True)
class GoogleStorage(Storage):
id: str
description: str
tier: GoogleFilestoreTier

Expand Down Expand Up @@ -254,7 +252,6 @@ class AzureReplicationType(str, enum.Enum):

@dataclass(frozen=True)
class AzureStorage(Storage):
id: str
description: str
tier: AzureStorageTier
replication_type: AzureReplicationType
Expand Down Expand Up @@ -518,9 +515,11 @@ class ResourcePoolType:

@dataclass(frozen=True)
class Resources:
cpu_m: int
cpu: float
memory: int
gpu: int = 0
nvidia_gpu: int = 0
amd_gpu: int = 0
intel_gpu: int = 0


@dataclass(frozen=True)
Expand Down Expand Up @@ -552,6 +551,22 @@ class OrchestratorConfig:
idle_jobs: Sequence[IdleJobConfig] = ()


@dataclass
class PatchOrchestratorConfigRequest:
    """Orchestrator values to patch, all optional.

    Used as the type of ``PatchClusterRequest.orchestrator``. Every field
    defaults to ``None``.
    NOTE(review): presumably a ``None`` field means "leave this value
    unchanged" -- confirm against the payload factory.
    NOTE(review): unlike the other ``Patch*Request`` dataclasses in this
    file, this one is not declared ``frozen=True`` -- confirm that the
    mutability is intentional.
    """

    # Field order is part of the positional constructor interface.
    job_hostname_template: str | None = None
    job_internal_hostname_template: str | None = None
    job_fallback_hostname: str | None = None
    job_schedule_timeout_s: float | None = None
    job_schedule_scale_up_timeout_s: float | None = None
    is_http_ingress_secure: bool | None = None
    resource_pool_types: Sequence[ResourcePoolType] | None = None
    resource_presets: Sequence[ResourcePreset] | None = None
    allow_privileged_mode: bool | None = None
    allow_job_priority: bool | None = None
    pre_pull_images: Sequence[str] | None = None
    idle_jobs: Sequence[IdleJobConfig] | None = None


@dataclass
class ARecord:
name: str
Expand Down Expand Up @@ -618,3 +633,20 @@ class Cluster:
buckets: BucketsConfig | None = None
ingress: IngressConfig | None = None
energy: EnergyConfig | None = None


@dataclass(frozen=True)
class PatchClusterRequest:
    """Immutable request object carrying cluster sections to patch.

    The client's ``patch_cluster`` method accepts this object and hands it
    to the payload factory to build the PATCH body.

    Every field is optional and defaults to ``None``.
    NOTE(review): presumably a ``None`` section is left out of the payload
    -- confirm against the payload factory.
    """

    # Field order is part of the positional constructor interface.
    credentials: CredentialsConfig | None = None
    storage: StorageConfig | None = None
    registry: RegistryConfig | None = None
    # Takes the patch-specific orchestrator type, not OrchestratorConfig.
    orchestrator: PatchOrchestratorConfigRequest | None = None
    monitoring: MonitoringConfig | None = None
    secrets: SecretsConfig | None = None
    metrics: MetricsConfig | None = None
    disks: DisksConfig | None = None
    buckets: BucketsConfig | None = None
    ingress: IngressConfig | None = None
    dns: DNSConfig | None = None
    timezone: ZoneInfo | None = None
    energy: EnergyConfig | None = None
Loading

0 comments on commit 5076193

Please sign in to comment.