Skip to content

Commit

Permalink
fix GCP catalog
Browse files Browse the repository at this point in the history
  • Loading branch information
Michaelvll committed Oct 9, 2022
1 parent fdb56c8 commit 4704903
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions sky/clouds/service_catalog/gcp_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
For now this service catalog is manually coded. In the future it should be
queried from GCP API.
"""
from collections import defaultdict
import typing
from typing import Dict, List, Optional, Tuple

Expand Down Expand Up @@ -30,12 +31,20 @@
# TODO(zongheng): fix A100 info directly in catalog.
# https://cloud.google.com/blog/products/compute/a2-vms-with-nvidia-a100-gpus-are-ga
# accelerator name -> {count -> vm type}
_A100_INSTANCE_TYPES = {
_A100_INSTANCE_TYPE_DICTS = {
'A100': {
1: 'a2-highgpu-1g',
2: 'a2-highgpu-2g',
4: 'a2-highgpu-4g',
8: 'a2-highgpu-8g',
16: 'a2-megagpu-16g',
},
'A100-80GB': {
1: 'a2-ultragpu-1g',
2: 'a2-ultragpu-2g',
4: 'a2-ultragpu-4g',
8: 'a2-ultragpu-8g',
}
}

# Number of CPU cores per GPU based on the AWS setting.
Expand Down Expand Up @@ -167,10 +176,10 @@ def get_instance_type_for_accelerator(
if instance_list is None:
return None, fuzzy_candidate_list

if acc_name == 'A100':
if acc_name in _A100_INSTANCE_TYPE_DICTS:
# If A100 is used, host VM type must be A2.
# https://cloud.google.com/compute/docs/gpus#a100-gpus
return [_A100_INSTANCE_TYPES[acc_count]], []
return [_A100_INSTANCE_TYPE_DICTS[acc_name][acc_count]], []
if acc_name not in _NUM_ACC_TO_NUM_CPU:
acc_name = 'DEFAULT'

Expand Down Expand Up @@ -258,17 +267,17 @@ def list_accelerators(
results = common.list_accelerators_impl('GCP', _df, gpus_only, name_filter,
case_sensitive)

a100_infos = results.get('A100', None)
if a100_infos is None:
a100_infos = results.get('A100', []) + results.get('A100-80GB', [])
if not a100_infos:
return results

# Unlike other GPUs that can be attached to different sizes of N1 VMs,
# A100 GPUs can only be attached to fixed-size A2 VMs.
# Thus, we can show their exact cost including the host VM prices.
new_infos = []
new_infos = defaultdict(list)
for info in a100_infos:
assert pd.isna(info.instance_type) and pd.isna(info.memory), a100_infos
a100_host_vm_type = _A100_INSTANCE_TYPES[info.accelerator_count]
a100_host_vm_type = _A100_INSTANCE_TYPE_DICTS[info.accelerator_name][info.accelerator_count]
df = _df[_df['InstanceType'] == a100_host_vm_type]
cpu_count = df['vCPUs'].iloc[0]
memory = df['MemoryGiB'].iloc[0]
Expand All @@ -280,7 +289,7 @@ def list_accelerators(
a100_host_vm_type,
None,
use_spot=True)
new_infos.append(
new_infos[info.accelerator_name].append(
info._replace(
instance_type=a100_host_vm_type,
cpu_count=cpu_count,
Expand All @@ -289,7 +298,7 @@ def list_accelerators(
price=info.price + vm_price,
spot_price=info.spot_price + vm_spot_price,
))
results['A100'] = new_infos
results.update(new_infos)
return results


Expand Down Expand Up @@ -382,8 +391,8 @@ def check_accelerator_attachable_to_host(instance_type: str,
assert instance_type == 'TPU-VM' or instance_type.startswith('n1-')
return

if acc_name == 'A100':
valid_counts = list(_A100_INSTANCE_TYPES.keys())
if acc_name in _A100_INSTANCE_TYPE_DICTS:
valid_counts = list(_A100_INSTANCE_TYPE_DICTS[acc_name].keys())
else:
valid_counts = list(_NUM_ACC_TO_MAX_CPU_AND_MEMORY[acc_name].keys())
if acc_count not in valid_counts:
Expand All @@ -392,8 +401,8 @@ def check_accelerator_attachable_to_host(instance_type: str,
f'{acc_name}:{acc_count} is not launchable on GCP. '
f'The valid {acc_name} counts are {valid_counts}.')

if acc_name == 'A100':
a100_instance_type = _A100_INSTANCE_TYPES[acc_count]
if acc_name in _A100_INSTANCE_TYPE_DICTS:
a100_instance_type = _A100_INSTANCE_TYPE_DICTS[acc_name][acc_count]
if instance_type != a100_instance_type:
with ux_utils.print_exception_no_traceback():
raise exceptions.ResourcesMismatchError(
Expand Down

0 comments on commit 4704903

Please sign in to comment.