Skip to content

Commit

Permalink
fix 1) gpu detection in multi-gpu setup 2) oem A100 cannot be matched…
Browse files Browse the repository at this point in the history
… to correct TVM target (#58)
  • Loading branch information
Qubitium authored Jun 21, 2024
1 parent d589a79 commit cca477e
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions python/bitblas/utils/target_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`."
)

def get_gpu_model_from_nvidia_smi():
def get_gpu_model_from_nvidia_smi(gpu_id: int = 0):
"""
Executes the 'nvidia-smi' command to fetch the name of the first available NVIDIA GPU.
Expand All @@ -26,7 +26,7 @@ def get_gpu_model_from_nvidia_smi():
try:
# Execute nvidia-smi command to get the GPU name
output = subprocess.check_output(
["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"],
["nvidia-smi", f"--id={gpu_id}", "--query-gpu=gpu_name", "--format=csv,noheader"],
encoding="utf-8",
).strip()
except subprocess.CalledProcessError as e:
Expand Down Expand Up @@ -62,7 +62,7 @@ def get_all_nvidia_targets() -> List[str]:
return [tag for tag in all_tags if "nvidia" in tag]


def auto_detect_nvidia_target() -> str:
def auto_detect_nvidia_target(gpu_id: int = 0) -> str:
"""
Automatically detects the NVIDIA GPU architecture to set the appropriate TVM target.
Expand All @@ -78,6 +78,13 @@ def auto_detect_nvidia_target() -> str:
nvidia_tags = [tag for tag in all_tags if "nvidia" in tag]

# Get the current GPU model and find the best matching target
gpu_model = get_gpu_model_from_nvidia_smi()
gpu_model = get_gpu_model_from_nvidia_smi(gpu_id=gpu_id)

# TODO: move to a more res-usable device remapping util method
# compat: Nvidia makes several oem (non-public) versions of A100 and perhaps other models that
# do not have clearly defined TVM matching target so we need to manually map them to the correct one.
if gpu_model == "NVIDIA PG506-230":
gpu_model = "NVIDIA A100"

target = find_best_match(nvidia_tags, gpu_model) if gpu_model else "cuda"
return target

0 comments on commit cca477e

Please sign in to comment.