Skip to content

Commit

Permalink
Allow specifying Azure resource_group (#2288)
Browse files Browse the repository at this point in the history
  • Loading branch information
r4victor authored Feb 11, 2025
1 parent 34ac563 commit a495d67
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 9 deletions.
3 changes: 3 additions & 0 deletions docs/docs/concepts/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,9 @@ There are two ways to configure Azure: using a client secret or using the defaul
}
```

The `"Microsoft.Resources/subscriptions/resourceGroups/write"` permission is not required
if [`resource_group`](/docs/reference/server/config.yml/#azure) is specified.

??? info "VPC"
By default, `dstack` creates new Azure networks and subnets for every configured region.
It's possible to use custom networks by specifying `vpc_ids`:
Expand Down
4 changes: 3 additions & 1 deletion src/dstack/_internal/core/models/backends/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class AzureConfigInfo(CoreModel):
type: Literal["azure"] = "azure"
tenant_id: str
subscription_id: str
resource_group: Optional[str] = None
locations: Optional[List[str]] = None
vpc_ids: Optional[Dict[str, str]] = None
public_ips: Optional[bool] = None
Expand Down Expand Up @@ -48,6 +49,7 @@ class AzureConfigInfoWithCredsPartial(CoreModel):
creds: Optional[AnyAzureCreds]
tenant_id: Optional[str]
subscription_id: Optional[str]
resource_group: Optional[str]
locations: Optional[List[str]]
vpc_ids: Optional[Dict[str, str]]
public_ips: Optional[bool]
Expand All @@ -63,4 +65,4 @@ class AzureConfigValues(CoreModel):


class AzureStoredConfig(AzureConfigInfo):
resource_group: str
resource_group: str = ""
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional, Tuple

import azure.core.exceptions
from azure.core.credentials import TokenCredential
from azure.mgmt import network as network_mgmt
from azure.mgmt import resource as resource_mgmt
Expand Down Expand Up @@ -154,16 +155,17 @@ def create_backend(
if is_core_model_instance(config.creds, AzureClientCreds):
self._set_client_creds_tenant_id(config.creds, config.tenant_id)
credential, _ = auth.authenticate(config.creds)
resource_group = self._create_resource_group(
credential=credential,
subscription_id=config.subscription_id,
location=MAIN_LOCATION,
project_name=project.name,
)
if config.resource_group is None:
config.resource_group = self._create_resource_group(
credential=credential,
subscription_id=config.subscription_id,
location=MAIN_LOCATION,
project_name=project.name,
)
self._create_network_resources(
credential=credential,
subscription_id=config.subscription_id,
resource_group=resource_group,
resource_group=config.resource_group,
locations=config.locations,
create_default_network=config.vpc_ids is None,
)
Expand All @@ -172,7 +174,6 @@ def create_backend(
type=self.TYPE.value,
config=AzureStoredConfig(
**AzureConfigInfo.__response__.parse_obj(config).dict(),
resource_group=resource_group,
).json(),
auth=DecryptedString(plaintext=AzureCreds.parse_obj(config.creds).__root__.json()),
)
Expand Down Expand Up @@ -322,6 +323,7 @@ def _check_config(
self, config: AzureConfigInfoWithCredsPartial, credential: auth.AzureCredential
):
self._check_tags_config(config)
self._check_resource_group(config=config, credential=credential)
self._check_vpc_config(config=config, credential=credential)

def _check_tags_config(self, config: AzureConfigInfoWithCredsPartial):
Expand All @@ -336,6 +338,18 @@ def _check_tags_config(self, config: AzureConfigInfoWithCredsPartial):
except BackendError as e:
raise ServerClientError(e.args[0])

def _check_resource_group(
self, config: AzureConfigInfoWithCredsPartial, credential: auth.AzureCredential
):
if config.resource_group is None:
return
resource_manager = ResourceManager(
credential=credential,
subscription_id=config.subscription_id,
)
if not resource_manager.resource_group_exists(config.resource_group):
raise ServerClientError(f"Resource group {config.resource_group} not found")

def _check_vpc_config(
self, config: AzureConfigInfoWithCredsPartial, credential: auth.AzureCredential
):
Expand Down Expand Up @@ -406,6 +420,18 @@ def create_resource_group(
)
return resource_group.name

def resource_group_exists(
self,
name: str,
) -> bool:
try:
self.resource_client.resource_groups.get(
resource_group_name=name,
)
except azure.core.exceptions.ResourceNotFoundError:
return False
return True


class NetworkManager:
def __init__(self, credential: TokenCredential, subscription_id: str):
Expand Down
9 changes: 9 additions & 0 deletions src/dstack/_internal/server/services/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,15 @@ class AzureConfig(CoreModel):
type: Annotated[Literal["azure"], Field(description="The type of the backend")] = "azure"
tenant_id: Annotated[str, Field(description="The tenant ID")]
subscription_id: Annotated[str, Field(description="The subscription ID")]
resource_group: Annotated[
Optional[str],
Field(
description=(
"The resource group for resources created by `dstack`."
" If not specified, `dstack` will create a new resource group"
)
),
] = None
regions: Annotated[
Optional[List[str]],
Field(description="The list of Azure regions (locations). Omit to use all regions"),
Expand Down

0 comments on commit a495d67

Please sign in to comment.