From e1fa655c4a3d051c27e45488d99a19ee5dda2d38 Mon Sep 17 00:00:00 2001 From: Qubitium-ModelCloud Date: Wed, 26 Jun 2024 16:45:30 +0800 Subject: [PATCH] fix gpu model missing from tvm target remap (#61) --- python/bitblas/utils/target_detector.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/python/bitblas/utils/target_detector.py b/python/bitblas/utils/target_detector.py index 7dc220f6a6ee..d4baafdeeaf1 100644 --- a/python/bitblas/utils/target_detector.py +++ b/python/bitblas/utils/target_detector.py @@ -16,6 +16,13 @@ "where is one of the available targets can be found in the output of `tools/get_available_targets.py`." ) +# NVIDIA produces non-public OEM GPU models that are known to the driver but are not mapped to the correct TVM target. +# This table remaps each OEM model name to the closest public model name. +NVIDIA_GPU_REMAP = { + "NVIDIA PG506-230": "NVIDIA A100", + "NVIDIA PG506-232": "NVIDIA A100", +} + def get_gpu_model_from_nvidia_smi(gpu_id: int = 0): """ Executes the 'nvidia-smi' command to fetch the name of the first available NVIDIA GPU. @@ -88,11 +95,9 @@ def auto_detect_nvidia_target(gpu_id: int = 0) -> str: # Get the current GPU model and find the best matching target gpu_model = get_gpu_model_from_nvidia_smi(gpu_id=gpu_id) - # TODO: move to a more res-usable device remapping util method - # compat: Nvidia makes several oem (non-public) versions of A100 and perhaps other models that - # do not have clearly defined TVM matching target so we need to manually map them to the correct one. - if gpu_model == "NVIDIA PG506-230": - gpu_model = "NVIDIA A100" + # Compat: remap OEM device names to their public (non-OEM) model names for TVM target matching + if gpu_model in NVIDIA_GPU_REMAP: + gpu_model = NVIDIA_GPU_REMAP[gpu_model] target = find_best_match(nvidia_tags, gpu_model) if gpu_model else "cuda" return target