Skip to content

Commit

Permalink
[ROCm] Fixup arch checks for ROCM (vllm-project#2627)
Browse files Browse the repository at this point in the history
  • Loading branch information
dllehr-amd authored and jimpang committed Mar 4, 2024
1 parent 464b858 commit 068e4e6
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 37 deletions.
3 changes: 0 additions & 3 deletions Dockerfile.rocm
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ RUN echo "Base image is $BASE_IMAGE"
# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

# this does not always work for all rocm versions
RUN LLVM_GFX_ARCH=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) && \
echo "LLVM_GFX_ARCH is $LLVM_GFX_ARCH"

ARG FA_GFX_ARCHS="gfx90a;gfx942"
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
Expand Down
90 changes: 56 additions & 34 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

# Supported NVIDIA GPU architectures.
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942"}
# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)


Expand Down Expand Up @@ -63,22 +63,6 @@ def _is_cuda() -> bool:
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]


def get_amdgpu_offload_arch():
command = "/opt/rocm/llvm/bin/amdgpu-offload-arch"
try:
output = subprocess.check_output([command])
return output.decode('utf-8').strip()
except subprocess.CalledProcessError as e:
error_message = f"Error: {e}"
raise RuntimeError(error_message) from e
except FileNotFoundError as e:
# If the command is not found, print an error message
error_message = f"The command {command} was not found."
raise RuntimeError(error_message) from e

return None


def get_hipcc_rocm_version():
# Run the hipcc --version command
result = subprocess.run(['hipcc', '--version'],
Expand Down Expand Up @@ -138,6 +122,50 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
return nvcc_cuda_version


def get_pytorch_rocm_arch() -> Set[str]:
"""Get the cross section of Pytorch,and vllm supported gfx arches
ROCM can get the supported gfx architectures in one of two ways
Either through the PYTORCH_ROCM_ARCH env var, or output from
rocm_agent_enumerator.
In either case we can generate a list of supported arch's and
cross reference with VLLM's own ROCM_SUPPORTED_ARCHs.
"""
env_arch_list = os.environ.get("PYTORCH_ROCM_ARCH", None)

# If we don't have PYTORCH_ROCM_ARCH specified pull the list from rocm_agent_enumerator
if env_arch_list is None:
command = "rocm_agent_enumerator"
env_arch_list = subprocess.check_output([command]).decode('utf-8')\
.strip().replace("\n", ";")
arch_source_str = "rocm_agent_enumerator"
else:
arch_source_str = "PYTORCH_ROCM_ARCH env variable"

# List are separated by ; or space.
pytorch_rocm_arch = set(env_arch_list.replace(" ", ";").split(";"))

# Filter out the invalid architectures and print a warning.
arch_list = pytorch_rocm_arch.intersection(ROCM_SUPPORTED_ARCHS)

# If none of the specified architectures are valid, raise an error.
if not arch_list:
raise RuntimeError(
f"None of the ROCM architectures in {arch_source_str} "
f"({env_arch_list}) is supported. "
f"Supported ROCM architectures are: {ROCM_SUPPORTED_ARCHS}.")
invalid_arch_list = pytorch_rocm_arch - ROCM_SUPPORTED_ARCHS
if invalid_arch_list:
warnings.warn(
f"Unsupported ROCM architectures ({invalid_arch_list}) are "
f"excluded from the {arch_source_str} output "
f"({env_arch_list}). Supported ROCM architectures are: "
f"{ROCM_SUPPORTED_ARCHS}.",
stacklevel=2)
return arch_list


def get_torch_arch_list() -> Set[str]:
# TORCH_CUDA_ARCH_LIST can have one or more architectures,
# e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
Expand All @@ -162,22 +190,27 @@ def get_torch_arch_list() -> Set[str]:
# If none of the specified architectures are valid, raise an error.
if not arch_list:
raise RuntimeError(
"None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env "
"None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
f"variable ({env_arch_list}) is supported. "
f"Supported CUDA/ROCM architectures are: {valid_archs}.")
f"Supported CUDA architectures are: {valid_archs}.")
invalid_arch_list = torch_arch_list - valid_archs
if invalid_arch_list:
warnings.warn(
f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are "
f"Unsupported CUDA architectures ({invalid_arch_list}) are "
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
f"({env_arch_list}). Supported CUDA/ROCM architectures are: "
f"({env_arch_list}). Supported CUDA architectures are: "
f"{valid_archs}.",
stacklevel=2)
return arch_list


# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
if _is_hip():
rocm_arches = get_pytorch_rocm_arch()
NVCC_FLAGS += ["--offload-arch=" + arch for arch in rocm_arches]
else:
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()

if _is_cuda() and not compute_capabilities:
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
# GPUs on the current machine.
Expand Down Expand Up @@ -286,17 +319,6 @@ def get_torch_arch_list() -> Set[str]:
"nvcc": NVCC_FLAGS_PUNICA,
},
))
elif _is_hip():
amd_archs = os.getenv("GPU_ARCHS")
if amd_archs is None:
amd_archs = get_amdgpu_offload_arch()
for arch in amd_archs.split(";"):
if arch not in ROCM_SUPPORTED_ARCHS:
raise RuntimeError(
f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
f"amdgpu_arch_found: {arch}")
NVCC_FLAGS += [f"--offload-arch={arch}"]

elif _is_neuron():
neuronxcc_version = get_neuronxcc_version()

Expand Down

0 comments on commit 068e4e6

Please sign in to comment.