Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] GPU detection in multigpu env and OEM A100 not matching TVM #58

Merged
merged 1 commit into from
Jun 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions python/bitblas/utils/target_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`."
)

def get_gpu_model_from_nvidia_smi():
def get_gpu_model_from_nvidia_smi(gpu_id: int = 0):
"""
Executes the 'nvidia-smi' command to fetch the name of the NVIDIA GPU selected by `gpu_id` (default 0).

Expand All @@ -26,7 +26,7 @@ def get_gpu_model_from_nvidia_smi():
try:
# Execute nvidia-smi command to get the GPU name
output = subprocess.check_output(
["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"],
["nvidia-smi", f"--id={gpu_id}", "--query-gpu=gpu_name", "--format=csv,noheader"],
encoding="utf-8",
).strip()
except subprocess.CalledProcessError as e:
Expand Down Expand Up @@ -62,7 +62,7 @@ def get_all_nvidia_targets() -> List[str]:
return [tag for tag in all_tags if "nvidia" in tag]


def auto_detect_nvidia_target() -> str:
def auto_detect_nvidia_target(gpu_id: int = 0) -> str:
"""
Automatically detects the NVIDIA GPU architecture to set the appropriate TVM target.

Expand All @@ -78,6 +78,13 @@ def auto_detect_nvidia_target() -> str:
nvidia_tags = [tag for tag in all_tags if "nvidia" in tag]

# Get the current GPU model and find the best matching target
gpu_model = get_gpu_model_from_nvidia_smi()
gpu_model = get_gpu_model_from_nvidia_smi(gpu_id=gpu_id)

# TODO: move to a more reusable device remapping util method
# compat: NVIDIA makes several OEM (non-public) versions of the A100, and perhaps other models, that
# do not have a clearly defined matching TVM target, so we need to manually map them to the correct one.
if gpu_model == "NVIDIA PG506-230":
gpu_model = "NVIDIA A100"

target = find_best_match(nvidia_tags, gpu_model) if gpu_model else "cuda"
return target