diff --git a/python/bitblas/utils/target_detector.py b/python/bitblas/utils/target_detector.py index 33bf70d0a..b92409f40 100644 --- a/python/bitblas/utils/target_detector.py +++ b/python/bitblas/utils/target_detector.py @@ -16,7 +16,7 @@ "where is one of the available targets can be found in the output of `tools/get_available_targets.py`." ) -def get_gpu_model_from_nvidia_smi(): +def get_gpu_model_from_nvidia_smi(gpu_id: int = 0): """ Executes the 'nvidia-smi' command to fetch the name of the first available NVIDIA GPU. @@ -26,7 +26,7 @@ def get_gpu_model_from_nvidia_smi(): try: # Execute nvidia-smi command to get the GPU name output = subprocess.check_output( - ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"], + ["nvidia-smi", f"--id={gpu_id}", "--query-gpu=gpu_name", "--format=csv,noheader"], encoding="utf-8", ).strip() except subprocess.CalledProcessError as e: @@ -62,7 +62,7 @@ def get_all_nvidia_targets() -> List[str]: return [tag for tag in all_tags if "nvidia" in tag] -def auto_detect_nvidia_target() -> str: +def auto_detect_nvidia_target(gpu_id: int = 0) -> str: """ Automatically detects the NVIDIA GPU architecture to set the appropriate TVM target. @@ -78,6 +78,13 @@ def auto_detect_nvidia_target() -> str: nvidia_tags = [tag for tag in all_tags if "nvidia" in tag] # Get the current GPU model and find the best matching target - gpu_model = get_gpu_model_from_nvidia_smi() + gpu_model = get_gpu_model_from_nvidia_smi(gpu_id=gpu_id) + + # TODO: move to a more res-usable device remapping util method + # compat: Nvidia makes several oem (non-public) versions of A100 and perhaps other models that + # do not have clearly defined TVM matching target so we need to manually map them to the correct one. + if gpu_model == "NVIDIA PG506-230": + gpu_model = "NVIDIA A100" + target = find_best_match(nvidia_tags, gpu_model) if gpu_model else "cuda" return target