Skip to content

Commit

Permalink
Support provisioning instances without public IPs on AWS (#1203)
Browse files Browse the repository at this point in the history
* Support provisioning instances without public IPs on AWS

* Fix tests
  • Loading branch information
r4victor authored May 8, 2024
1 parent 0ae6b9e commit c6315ab
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 12 deletions.
29 changes: 26 additions & 3 deletions src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def create_instance(
ec2 = self.session.resource("ec2", region_name=instance_offer.region)
ec2_client = self.session.client("ec2", region_name=instance_offer.region)
iam_client = self.session.client("iam", region_name=instance_offer.region)
allocate_public_ip = self.config.allocate_public_ips

tags = [
{"Key": "Name", "Value": instance_config.instance_name},
Expand All @@ -117,6 +118,7 @@ def create_instance(
ec2_client=ec2_client,
config=self.config,
region=instance_offer.region,
allocate_public_ip=allocate_public_ip,
)
disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
response = ec2.create_instances(
Expand All @@ -140,6 +142,7 @@ def create_instance(
),
spot=instance_offer.instance.resources.spot,
subnet_id=subnet_id,
allocate_public_ip=allocate_public_ip,
)
)
instance = response[0]
Expand All @@ -149,11 +152,16 @@ def create_instance(
ec2_client.cancel_spot_instance_requests(
SpotInstanceRequestIds=[instance.spot_instance_request_id]
)
if allocate_public_ip:
hostname = instance.public_ip_address
else:
hostname = instance.private_ip_address
return JobProvisioningData(
backend=instance_offer.backend,
instance_type=instance_offer.instance,
instance_id=instance.instance_id,
hostname=instance.public_ip_address,
public_ip_enabled=allocate_public_ip,
hostname=hostname,
internal_ip=instance.private_ip_address,
region=instance_offer.region,
price=instance_offer.price,
Expand Down Expand Up @@ -247,6 +255,7 @@ def get_vpc_id_subnet_id_or_error(
ec2_client: botocore.client.BaseClient,
config: AWSConfig,
region: str,
allocate_public_ip: bool,
) -> Tuple[str, str]:
if config.vpc_ids is not None:
vpc_id = config.vpc_ids.get(region)
Expand All @@ -259,6 +268,7 @@ def get_vpc_id_subnet_id_or_error(
subnet_id = aws_resources.get_subnet_id_for_vpc(
ec2_client=ec2_client,
vpc_id=vpc_id,
allocate_public_ip=allocate_public_ip,
)
if subnet_id is not None:
return vpc_id, subnet_id
Expand All @@ -268,13 +278,15 @@ def get_vpc_id_subnet_id_or_error(
ec2_client=ec2_client,
vpc_name=config.vpc_name,
region=region,
allocate_public_ip=allocate_public_ip,
)


def _get_vpc_id_subnet_id_by_vpc_name_or_error(
ec2_client: botocore.client.BaseClient,
vpc_name: Optional[str],
region: str,
allocate_public_ip: bool,
) -> Tuple[str, str]:
if vpc_name is not None:
vpc_id = aws_resources.get_vpc_id_by_name(
Expand All @@ -290,9 +302,20 @@ def _get_vpc_id_subnet_id_by_vpc_name_or_error(
subnet_id = aws_resources.get_subnet_id_for_vpc(
ec2_client=ec2_client,
vpc_id=vpc_id,
allocate_public_ip=allocate_public_ip,
)
if subnet_id is not None:
return vpc_id, subnet_id
if vpc_name is not None:
raise ComputeError(f"Failed to find public subnet for VPC {vpc_name} in region {region}")
raise ComputeError(f"Failed to find public subnet for default VPC in region {region}")
if allocate_public_ip:
raise ComputeError(
f"Failed to find public subnet for VPC {vpc_name} in region {region}"
)
raise ComputeError(
f"Failed to find private subnet with NAT for VPC {vpc_name} in region {region}"
)
if allocate_public_ip:
raise ComputeError(f"Failed to find public subnet for default VPC in region {region}")
raise ComputeError(
f"Failed to find private subnet with NAT for default VPC in region {region}"
)
6 changes: 6 additions & 0 deletions src/dstack/_internal/core/backends/aws/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@

# AWS backend configuration: everything inherited from the stored config
# and the generic backend config, plus the credentials.
class AWSConfig(AWSStoredConfig, BackendConfig):
    creds: AnyAWSCreds

    @property
    def allocate_public_ips(self) -> bool:
        """Tell whether instances should get public IPs.

        An unset (`None`) `public_ips` option is treated as enabled.
        """
        configured = self.public_ips
        return True if configured is None else configured
62 changes: 55 additions & 7 deletions src/dstack/_internal/core/backends/aws/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def create_instances_struct(
security_group_id: str,
spot: bool,
subnet_id: Optional[str] = None,
allocate_public_ip: bool = True,
) -> Dict[str, Any]:
struct = dict(
BlockDeviceMappings=[
Expand Down Expand Up @@ -230,7 +231,7 @@ def create_instances_struct(
if subnet_id is not None:
struct["NetworkInterfaces"] = [
{
"AssociatePublicIpAddress": True,
"AssociatePublicIpAddress": allocate_public_ip,
"DeviceIndex": 0,
"SubnetId": subnet_id,
"Groups": [security_group_id],
Expand Down Expand Up @@ -334,18 +335,31 @@ def get_vpc_by_vpc_id(ec2_client: botocore.client.BaseClient, vpc_id: str) -> Op
def get_subnet_id_for_vpc(
    ec2_client: botocore.client.BaseClient,
    vpc_id: str,
    allocate_public_ip: bool,
) -> Optional[str]:
    """
    Pick a subnet of the VPC that fits the requested networking mode.

    If `allocate_public_ip` is True, returns the first public subnet found in the VPC.
    If `allocate_public_ip` is False, returns the first subnet with NAT found in the VPC.
    Returns None if the VPC has no subnets or none of them qualifies.
    """
    subnets = _get_subnets_by_vpc_id(ec2_client=ec2_client, vpc_id=vpc_id)
    # Select the check once: a public subnet must never be returned when
    # public IPs are disabled, so only one of the two checks may apply.
    # Both helpers accept the same keyword arguments.
    subnet_qualifies = _is_public_subnet if allocate_public_ip else _is_subnet_behind_nat
    for subnet in subnets:
        subnet_id = subnet["SubnetId"]
        if subnet_qualifies(ec2_client=ec2_client, vpc_id=vpc_id, subnet_id=subnet_id):
            return subnet_id
    return None


Expand Down Expand Up @@ -440,3 +454,37 @@ def _is_public_subnet(
return True

return False


def _has_nat_gateway_route(route_tables: list) -> bool:
    """Return True if any route of any given route table targets a NAT gateway."""
    return any(
        route.get("NatGatewayId", "").startswith("nat-")
        for route_table in route_tables
        for route in route_table["Routes"]
    )


def _is_subnet_behind_nat(
    ec2_client: botocore.client.BaseClient,
    vpc_id: str,
    subnet_id: str,
) -> bool:
    """
    Check whether the subnet's effective route table has a route via a NAT gateway.

    Route tables explicitly associated with the subnet are inspected first.
    The VPC's main route table is inspected only when the subnet has no
    explicit association.

    Args:
        ec2_client: EC2 client used to call `describe_route_tables`.
        vpc_id: ID of the VPC the subnet belongs to.
        subnet_id: ID of the subnet to check.

    Returns:
        True if a route with a `nat-` gateway ID is found, False otherwise.
    """
    # Check explicitly associated route tables
    explicit_tables = ec2_client.describe_route_tables(
        Filters=[{"Name": "association.subnet-id", "Values": [subnet_id]}]
    )["RouteTables"]
    if explicit_tables:
        # The main route table controls the routing of all subnets
        # that are not explicitly associated with any other route table,
        # so it must not be consulted when an explicit association exists.
        return _has_nat_gateway_route(explicit_tables)

    # Check implicitly associated main route table
    main_tables = ec2_client.describe_route_tables(
        Filters=[
            {"Name": "association.main", "Values": ["true"]},
            {"Name": "vpc-id", "Values": [vpc_id]},
        ]
    )["RouteTables"]
    return _has_nat_gateway_route(main_tables)
2 changes: 2 additions & 0 deletions src/dstack/_internal/core/models/backends/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class AWSConfigInfo(CoreModel):
regions: Optional[List[str]] = None
vpc_name: Optional[str] = None
vpc_ids: Optional[Dict[str, str]] = None
public_ips: Optional[bool] = None


class AWSAccessKeyCreds(CoreModel):
Expand Down Expand Up @@ -46,6 +47,7 @@ class AWSConfigInfoWithCredsPartial(CoreModel):
regions: Optional[List[str]]
vpc_name: Optional[str]
vpc_ids: Optional[Dict[str, str]]
public_ips: Optional[bool]


class AWSConfigValues(CoreModel):
Expand Down
6 changes: 5 additions & 1 deletion src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,13 @@ class JobProvisioningData(CoreModel):
backend: BackendType
instance_type: InstanceType
instance_id: str
# hostname may not be set immediately after instance provisioning
# hostname may not be set immediately after instance provisioning.
# It is set to a public IP or, if public IPs are disabled, to a private IP.
hostname: Optional[str]
internal_ip: Optional[str]
# public_ip_enabled can be used to distinguish instances with and without public IPs.
# hostname being None is not enough since it can be filled after provisioning.
public_ip_enabled: bool = True
region: str
price: float
username: str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def _check_vpc_config(self, session: Session, config: AWSConfigInfoWithCredsPart
regions = config.regions
if regions is None:
regions = DEFAULT_REGIONS
allocate_public_ip = config.public_ips if config.public_ips is not None else True
# The number of workers should be >= the number of regions
with concurrent.futures.ThreadPoolExecutor(max_workers=12) as executor:
futures = []
Expand All @@ -149,6 +150,7 @@ def _check_vpc_config(self, session: Session, config: AWSConfigInfoWithCredsPart
ec2_client=ec2_client,
config=AWSConfig.parse_obj(config),
region=region,
allocate_public_ip=allocate_public_ip,
)
futures.append(future)
for future in concurrent.futures.as_completed(futures):
Expand Down
8 changes: 7 additions & 1 deletion src/dstack/_internal/server/services/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ class AWSConfig(CoreModel):
vpc_ids: Annotated[
Optional[Dict[str, str]], Field(description="The mapping from AWS regions to VPC IDs")
] = None
public_ips: Annotated[
Optional[bool],
Field(
description="A flag to enable/disable public IP assigning on instances. Defaults to `true`."
),
] = None
creds: AnyAWSCreds = Field(..., description="The credentials", discriminator="type")


Expand All @@ -76,8 +82,8 @@ class AzureConfig(CoreModel):
# Server-side configuration entry for the Cudo backend.
class CudoConfig(CoreModel):
    type: Annotated[Literal["cudo"], Field(description="The type of backend")] = "cudo"
    regions: Optional[List[str]] = None
    project_id: Annotated[str, Field(description="The project ID")]
    # `creds` is declared exactly once and last; a second declaration would
    # only obscure the intended field order.
    creds: Annotated[AnyCudoCreds, Field(description="The credentials")]


class DataCrunchConfig(CoreModel):
Expand Down
1 change: 1 addition & 0 deletions src/tests/_internal/server/routers/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,6 +1118,7 @@ async def test_returns_config_info(self, test_db, session: AsyncSession):
"regions": json.loads(backend.config)["regions"],
"vpc_name": None,
"vpc_ids": None,
"public_ips": None,
"creds": json.loads(backend.auth),
}

Expand Down
1 change: 1 addition & 0 deletions src/tests/_internal/server/routers/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ async def test_returns_projects(self, test_db, session: AsyncSession):
"regions": json.loads(backend.config)["regions"],
"vpc_name": None,
"vpc_ids": None,
"public_ips": None,
},
}
],
Expand Down

0 comments on commit c6315ab

Please sign in to comment.