From cd25a2692207992bff04350d549c1b30c68d8d1c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 1 Oct 2023 04:52:43 +0000 Subject: [PATCH 1/2] Fix error msg --- setup.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 8b2ad97dd5403..815d7834cf938 100644 --- a/setup.py +++ b/setup.py @@ -50,16 +50,17 @@ def get_torch_arch_list() -> Set[str]: # not give the best performance on the newer architectures, it provides # forward compatibility. valid_arch_strs = SUPPORTED_ARCHS + [s + "+PTX" for s in SUPPORTED_ARCHS] - arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) - if arch_list is None: + env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) + if env_arch_list is None: return set() # List are separated by ; or space. - arch_list = arch_list.replace(" ", ";").split(";") + arch_list = env_arch_list.replace(" ", ";").split(";") for arch in arch_list: if arch not in valid_arch_strs: raise ValueError( - f"Unsupported CUDA arch ({arch}). " + f"Unsupported CUDA arch ({arch}) is included in the " + f"`TORCH_CUDA_ARCH_LIST` env variable ({env_arch_list}). " f"Valid CUDA arch strings are: {valid_arch_strs}.") return set(arch_list) From 03e2f9ee5261c5491fae1cf2f82c5cbd2718d878 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 13 Oct 2023 20:58:35 +0000 Subject: [PATCH 2/2] Fix err msg --- setup.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/setup.py b/setup.py index 815d7834cf938..6ffc03c25386d 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ ROOT_DIR = os.path.dirname(__file__) # Supported NVIDIA GPU architectures. -SUPPORTED_ARCHS = ["7.0", "7.5", "8.0", "8.6", "8.9", "9.0"] +SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"} # Compiler flags. CXX_FLAGS = ["-g", "-O2", "-std=c++17"] @@ -49,20 +49,32 @@ def get_torch_arch_list() -> Set[str]: # and executed on the 8.6 or newer architectures. While the PTX code will # not give the best performance on the newer architectures, it provides # forward compatibility. - valid_arch_strs = SUPPORTED_ARCHS + [s + "+PTX" for s in SUPPORTED_ARCHS] env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) if env_arch_list is None: return set() # List are separated by ; or space. - arch_list = env_arch_list.replace(" ", ";").split(";") - for arch in arch_list: - if arch not in valid_arch_strs: - raise ValueError( - f"Unsupported CUDA arch ({arch}) is included in the " - f"`TORCH_CUDA_ARCH_LIST` env variable ({env_arch_list}). " - f"Valid CUDA arch strings are: {valid_arch_strs}.") - return set(arch_list) + torch_arch_list = set(env_arch_list.replace(" ", ";").split(";")) + if not torch_arch_list: + return set() + + # Filter out the invalid architectures and print a warning. + valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS}) + arch_list = torch_arch_list.intersection(valid_archs) + # If none of the specified architectures are valid, raise an error. + if not arch_list: + raise RuntimeError( + "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env " + f"variable ({env_arch_list}) is supported. " + f"Supported CUDA architectures are: {valid_archs}.") + invalid_arch_list = torch_arch_list - valid_archs + if invalid_arch_list: + warnings.warn( + f"Unsupported CUDA architectures ({invalid_arch_list}) are " + "excluded from the `TORCH_CUDA_ARCH_LIST` env variable " + f"({env_arch_list}). Supported CUDA architectures are: " + f"{valid_archs}.") + return arch_list # First, check the TORCH_CUDA_ARCH_LIST environment variable. @@ -82,7 +94,7 @@ def get_torch_arch_list() -> Set[str]: if not compute_capabilities: # If no GPU is specified nor available, add all supported architectures # based on the NVCC CUDA version. - compute_capabilities = set(SUPPORTED_ARCHS) + compute_capabilities = SUPPORTED_ARCHS.copy() if nvcc_cuda_version < Version("11.1"): compute_capabilities.remove("8.6") if nvcc_cuda_version < Version("11.8"):