From 2d49f438f5aeb44cf359d5d5932fde66f8b47ec5 Mon Sep 17 00:00:00 2001
From: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
Date: Mon, 5 Feb 2024 16:59:09 -0600
Subject: [PATCH] [ROCm] Fixup arch checks for ROCM (#2627)

---
 Dockerfile.rocm |  3 --
 setup.py        | 90 ++++++++++++++++++++++++++++++-------------------
 2 files changed, 56 insertions(+), 37 deletions(-)

diff --git a/Dockerfile.rocm b/Dockerfile.rocm
index 88172fb73b937..3c76305303037 100644
--- a/Dockerfile.rocm
+++ b/Dockerfile.rocm
@@ -10,9 +10,6 @@ RUN echo "Base image is $BASE_IMAGE"
 # BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
 # BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
 
-# this does not always work for all rocm versions
-RUN LLVM_GFX_ARCH=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) && \
-    echo "LLVM_GFX_ARCH is $LLVM_GFX_ARCH"
 ARG FA_GFX_ARCHS="gfx90a;gfx942"
 RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
 
diff --git a/setup.py b/setup.py
index 3e2127855a755..0c4937da210ef 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@
 
 # Supported NVIDIA GPU architectures.
 NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
-ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
+ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942"}
 # SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
 
 
@@ -63,22 +63,6 @@ def _is_cuda() -> bool:
     NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
 
 
-def get_amdgpu_offload_arch():
-    command = "/opt/rocm/llvm/bin/amdgpu-offload-arch"
-    try:
-        output = subprocess.check_output([command])
-        return output.decode('utf-8').strip()
-    except subprocess.CalledProcessError as e:
-        error_message = f"Error: {e}"
-        raise RuntimeError(error_message) from e
-    except FileNotFoundError as e:
-        # If the command is not found, print an error message
-        error_message = f"The command {command} was not found."
-        raise RuntimeError(error_message) from e
-
-    return None
-
-
 def get_hipcc_rocm_version():
     # Run the hipcc --version command
     result = subprocess.run(['hipcc', '--version'],
@@ -138,6 +122,50 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     return nvcc_cuda_version
 
 
+def get_pytorch_rocm_arch() -> Set[str]:
+    """Get the intersection of PyTorch- and vLLM-supported gfx arches.
+
+    ROCm can discover the supported gfx architectures in one of two ways:
+    either through the PYTORCH_ROCM_ARCH env var, or from the output of
+    rocm_agent_enumerator.
+
+    In either case we can generate a list of supported archs and
+    cross-reference it with vLLM's own ROCM_SUPPORTED_ARCHS.
+    """
+    env_arch_list = os.environ.get("PYTORCH_ROCM_ARCH", None)
+
+    # If PYTORCH_ROCM_ARCH is not specified, pull the list from rocm_agent_enumerator.
+    if env_arch_list is None:
+        command = "rocm_agent_enumerator"
+        env_arch_list = subprocess.check_output([command]).decode('utf-8')\
+            .strip().replace("\n", ";")
+        arch_source_str = "rocm_agent_enumerator"
+    else:
+        arch_source_str = "PYTORCH_ROCM_ARCH env variable"
+
+    # Lists are separated by ';' or spaces.
+    pytorch_rocm_arch = set(env_arch_list.replace(" ", ";").split(";"))
+
+    # Filter out the invalid architectures and print a warning.
+    arch_list = pytorch_rocm_arch.intersection(ROCM_SUPPORTED_ARCHS)
+
+    # If none of the specified architectures are valid, raise an error.
+    if not arch_list:
+        raise RuntimeError(
+            f"None of the ROCM architectures in {arch_source_str} "
+            f"({env_arch_list}) is supported. "
+            f"Supported ROCM architectures are: {ROCM_SUPPORTED_ARCHS}.")
+    invalid_arch_list = pytorch_rocm_arch - ROCM_SUPPORTED_ARCHS
+    if invalid_arch_list:
+        warnings.warn(
+            f"Unsupported ROCM architectures ({invalid_arch_list}) are "
+            f"excluded from the {arch_source_str} output "
+            f"({env_arch_list}). Supported ROCM architectures are: "
+            f"{ROCM_SUPPORTED_ARCHS}.",
+            stacklevel=2)
+    return arch_list
+
+
 def get_torch_arch_list() -> Set[str]:
     # TORCH_CUDA_ARCH_LIST can have one or more architectures,
     # e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
@@ -162,22 +190,27 @@ def get_torch_arch_list() -> Set[str]:
     # If none of the specified architectures are valid, raise an error.
     if not arch_list:
         raise RuntimeError(
-            "None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env "
+            "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
             f"variable ({env_arch_list}) is supported. "
-            f"Supported CUDA/ROCM architectures are: {valid_archs}.")
+            f"Supported CUDA architectures are: {valid_archs}.")
     invalid_arch_list = torch_arch_list - valid_archs
     if invalid_arch_list:
         warnings.warn(
-            f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are "
+            f"Unsupported CUDA architectures ({invalid_arch_list}) are "
             "excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
-            f"({env_arch_list}). Supported CUDA/ROCM architectures are: "
+            f"({env_arch_list}). Supported CUDA architectures are: "
             f"{valid_archs}.",
             stacklevel=2)
     return arch_list
 
 
-# First, check the TORCH_CUDA_ARCH_LIST environment variable.
-compute_capabilities = get_torch_arch_list()
+if _is_hip():
+    rocm_arches = get_pytorch_rocm_arch()
+    NVCC_FLAGS += ["--offload-arch=" + arch for arch in rocm_arches]
+else:
+    # First, check the TORCH_CUDA_ARCH_LIST environment variable.
+    compute_capabilities = get_torch_arch_list()
+
 if _is_cuda() and not compute_capabilities:
     # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
     # GPUs on the current machine.
@@ -286,17 +319,6 @@ def get_torch_arch_list() -> Set[str]:
             "nvcc": NVCC_FLAGS_PUNICA,
         },
     ))
-elif _is_hip():
-    amd_archs = os.getenv("GPU_ARCHS")
-    if amd_archs is None:
-        amd_archs = get_amdgpu_offload_arch()
-    for arch in amd_archs.split(";"):
-        if arch not in ROCM_SUPPORTED_ARCHS:
-            raise RuntimeError(
-                f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
-                f"amdgpu_arch_found: {arch}")
-        NVCC_FLAGS += [f"--offload-arch={arch}"]
-
 elif _is_neuron():
     neuronxcc_version = get_neuronxcc_version()