Commit

upd
yzh119 committed Aug 26, 2024
1 parent 4c04d62 commit 7b7c7f4
Showing 4 changed files with 5 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release_wheel.yml
@@ -18,7 +18,7 @@ on:
# required: true

env:
TORCH_CUDA_ARCH_LIST: "8.0 8.9 9.0+PTX"
TORCH_CUDA_ARCH_LIST: "7.5 8.0 8.9 9.0+PTX"

jobs:
build:
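For context, a minimal sketch (not part of this commit) of how the TORCH_CUDA_ARCH_LIST value above is consumed: torch.utils.cpp_extension turns it into one -gencode flag per listed architecture, and _get_cuda_arch_flags is the same private helper that python/setup.py (further down in this diff) iterates over. Assumes a local PyTorch installation.

    import os
    from torch.utils import cpp_extension as torch_cpp_ext

    # Mirror the workflow setting; _get_cuda_arch_flags reads this env var
    # when no explicit flags are passed.
    os.environ["TORCH_CUDA_ARCH_LIST"] = "7.5 8.0 8.9 9.0+PTX"
    for flag in torch_cpp_ext._get_cuda_arch_flags():
        print(flag)  # e.g. -gencode=arch=compute_75,code=sm_75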
2 changes: 1 addition & 1 deletion docs/installation.rst
@@ -19,7 +19,7 @@ Prerequisites

- Use ``python -c "import torch; print(torch.version.cuda)"`` to check your PyTorch CUDA version.

-- Supported GPU architectures: ``sm80``, ``sm86``, ``sm89``, ``sm90`` (``sm75`` / ``sm70`` support is working in progress).
+- Supported GPU architectures: ``sm75``, ``sm80``, ``sm86``, ``sm89``, ``sm90``.

Quick Start
^^^^^^^^^^^
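A quick way to check the prerequisites mentioned above, as a minimal sketch not taken from the docs: print the PyTorch CUDA version (the same check the docs suggest) and the local GPU's compute capability, which after this commit should fall in the sm75-sm90 range.

    import torch

    print("PyTorch CUDA version:", torch.version.cuda)
    major, minor = torch.cuda.get_device_capability(0)  # requires a visible GPU
    sm = major * 10 + minor
    print(f"GPU architecture: sm{sm}")
    if sm < 75:
        print("This GPU is older than the sm75 minimum targeted by this commit.")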
4 changes: 0 additions & 4 deletions include/flashinfer/attention/decode.cuh
@@ -594,11 +594,7 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo
      return 512U;
    }
  } else {
-#ifdef FLASHINFER_ENABLE_BF16
    return 128U;
-#else
-    return 64U;
-#endif
  }
}

9 changes: 3 additions & 6 deletions python/setup.py
@@ -32,17 +32,14 @@

root = pathlib.Path(__name__).parent

-enable_bf16 = True
# NOTE(Zihao): we haven't utilized fp8 tensor cores yet, so there is no
# cuda arch check for fp8 at the moment.
-enable_fp8 = True
for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags():
    arch = int(re.search("compute_\d+", cuda_arch_flags).group()[-2:])
    if arch < 75:
        raise RuntimeError("FlashInfer requires sm75+")
-    elif arch == 75:
-        # disable bf16 for sm75
-        enable_bf16 = False

+enable_bf16 = os.environ.get("FLASHINFER_ENABLE_BF16", "1") == "1"
+enable_fp8 = os.environ.get("FLASHINFER_ENABLE_FP8", "1") == "1"

if enable_bf16:
    torch_cpp_ext.COMMON_NVCC_FLAGS.append("-DFLASHINFER_ENABLE_BF16")
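A minimal usage sketch (not part of the commit) of the two new switches: both default to "1", and setting one to "0" before building from source skips the corresponding kernels (for bf16 this drops the -DFLASHINFER_ENABLE_BF16 define appended above). The `pip install python/` invocation from the repository root is an assumption about how the source tree is built, not something this diff shows.

    import os
    import subprocess

    env = dict(os.environ)
    env["FLASHINFER_ENABLE_BF16"] = "0"                  # opt out of bf16 kernels
    env["FLASHINFER_ENABLE_FP8"] = "1"                   # keep fp8 kernels (the default)
    env["TORCH_CUDA_ARCH_LIST"] = "7.5 8.0 8.9 9.0+PTX"  # match the release workflow
    subprocess.run(["pip", "install", "-v", "python/"], env=env, check=True)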
