From bc9b9aba7630696ddda41e170ea133dcb90c2857 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Mon, 26 Aug 2024 22:19:23 +0000
Subject: [PATCH] upd

---
 .github/workflows/release_wheel.yml     | 2 +-
 docs/installation.rst                   | 2 +-
 include/flashinfer/attention/decode.cuh | 4 ----
 python/setup.py                         | 9 +++------
 4 files changed, 5 insertions(+), 12 deletions(-)

diff --git a/.github/workflows/release_wheel.yml b/.github/workflows/release_wheel.yml
index 321d268d..aa9b1265 100644
--- a/.github/workflows/release_wheel.yml
+++ b/.github/workflows/release_wheel.yml
@@ -18,7 +18,7 @@ on:
 #         required: true
 
 env:
-  TORCH_CUDA_ARCH_LIST: "8.0 8.9 9.0+PTX"
+  TORCH_CUDA_ARCH_LIST: "7.5 8.0 8.9 9.0+PTX"
 
 jobs:
   build:
diff --git a/docs/installation.rst b/docs/installation.rst
index 95fbf84a..266ebbdb 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -19,7 +19,7 @@ Prerequisites
 
 - Use ``python -c "import torch; print(torch.version.cuda)"`` to check your PyTorch CUDA version.
 
-- Supported GPU architectures: ``sm80``, ``sm86``, ``sm89``, ``sm90`` (``sm75`` / ``sm70`` support is working in progress).
+- Supported GPU architectures: ``sm75``, ``sm80``, ``sm86``, ``sm89``, ``sm90``.
 
 Quick Start
 ^^^^^^^^^^^
diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh
index c8a7c75d..a84620a5 100644
--- a/include/flashinfer/attention/decode.cuh
+++ b/include/flashinfer/attention/decode.cuh
@@ -594,11 +594,7 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo
       return 512U;
     }
   } else {
-#ifdef FLASHINFER_ENABLE_BF16
     return 128U;
-#else
-    return 64U;
-#endif
   }
 }
 
diff --git a/python/setup.py b/python/setup.py
index 2fd605be..22d2878a 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -32,17 +32,14 @@
 
 root = pathlib.Path(__name__).parent
 
-enable_bf16 = True
-# NOTE(Zihao): we haven't utilized fp8 tensor cores yet, so there is no
 # cuda arch check for fp8 at the moment.
-enable_fp8 = True
 for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags():
     arch = int(re.search("compute_\d+", cuda_arch_flags).group()[-2:])
     if arch < 75:
         raise RuntimeError("FlashInfer requires sm75+")
-    elif arch == 75:
-        # disable bf16 for sm75
-        enable_bf16 = False
+
+enable_bf16 = os.environ.get("FLASHINFER_ENABLE_BF16", "1") == "1"
+enable_fp8 = os.environ.get("FLASHINFER_ENABLE_FP8", "1") == "1"
 
 if enable_bf16:
     torch_cpp_ext.COMMON_NVCC_FLAGS.append("-DFLASHINFER_ENABLE_BF16")
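
Note (not part of the patch): after this change, setup.py no longer hard-codes
enable_bf16/enable_fp8 and no longer forces bf16 off for sm75; both toggles are
read from the FLASHINFER_ENABLE_BF16 and FLASHINFER_ENABLE_FP8 environment
variables, defaulting to "1". Below is a minimal sketch of driving a source
build with these toggles, assuming the python/ layout shown in the diff and a
plain pip editable install; the pip invocation and the driver script itself are
assumptions for illustration, only the variable names come from the patch.

    # Hypothetical build driver -- the environment variable names come from the
    # patched setup.py above; the pip command and python/ working directory are
    # assumptions, not part of the patch.
    import os
    import subprocess

    env = dict(os.environ)
    env["TORCH_CUDA_ARCH_LIST"] = "7.5"    # build only for Turing (sm75), newly supported
    env["FLASHINFER_ENABLE_BF16"] = "0"    # opt out of bf16 kernels for this build
    env["FLASHINFER_ENABLE_FP8"] = "1"     # keep fp8 kernels (the default)

    subprocess.run(["pip", "install", "-e", "."], cwd="python", env=env, check=True)

Leaving both variables unset keeps the previous default behavior (bf16 and fp8
kernels compiled), except that listing sm75 in TORCH_CUDA_ARCH_LIST no longer
disables bf16 automatically.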