bugfix: Fix sm75 kernel configuration (#449)
Some kernel configurations are not compatible with sm75; this PR fixes these issues.
yzh119 authored Aug 27, 2024
1 parent f1c0b68 commit 3d38d0d
Showing 10 changed files with 388 additions and 338 deletions.
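
The core of the fix, visible in the decode.cuh diff below, is to stop hard-coding num_stages_smem = 2 and instead select the number of shared-memory pipeline stages per compute capability, via GetCudaComputeCapability() and a DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM macro. A minimal sketch of what such a dispatch macro could look like follows; the exact definition and the per-architecture stage counts are assumptions for illustration, not the literal code added by this commit:

// Hypothetical sketch only: choose fewer pipeline stages on pre-sm80 GPUs, whose
// per-block shared memory budget is smaller than on Ampere and later.
// Assumes GetCudaComputeCapability() returns a (major, minor) pair.
#define DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, ...) \
  if (compute_capacity.first >= 8) {                                                        \
    constexpr uint32_t NUM_STAGES_SMEM = 2;                                                 \
    __VA_ARGS__                                                                             \
  } else {                                                                                  \
    constexpr uint32_t NUM_STAGES_SMEM = 1; /* assumed value for sm75 */                    \
    __VA_ARGS__                                                                             \
  }

At the call sites in decode.cuh the third argument is the braced body, which uses NUM_STAGES_SMEM as a compile-time constant. Alongside this, the release workflow builds wheels for 7.5 and the installation docs now list sm75 as supported.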
2 changes: 1 addition & 1 deletion .github/workflows/release_wheel.yml
@@ -18,7 +18,7 @@ on:
# required: true

env:
TORCH_CUDA_ARCH_LIST: "8.0 8.9 9.0+PTX"
TORCH_CUDA_ARCH_LIST: "7.5 8.0 8.9 9.0+PTX"

jobs:
build:
2 changes: 1 addition & 1 deletion docs/installation.rst
@@ -19,7 +19,7 @@ Prerequisites

- Use ``python -c "import torch; print(torch.version.cuda)"`` to check your PyTorch CUDA version.

- Supported GPU architectures: ``sm80``, ``sm86``, ``sm89``, ``sm90`` (``sm75`` / ``sm70`` support is working in progress).
- Supported GPU architectures: ``sm75``, ``sm80``, ``sm86``, ``sm89``, ``sm90``.

Quick Start
^^^^^^^^^^^
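
If you are not sure which architecture your GPU is, you can query its compute capability with the CUDA runtime; sm75 corresponds to major 7, minor 5 (Turing, e.g. a T4 or RTX 20-series card). A small self-contained check, independent of FlashInfer:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int dev = 0;
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, dev);
  // e.g. prints "sm75" on a Turing GPU such as a T4
  std::printf("sm%d%d\n", prop.major, prop.minor);
  return 0;
}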
252 changes: 128 additions & 124 deletions include/flashinfer/attention/decode.cuh
@@ -18,13 +18,10 @@
#include <cooperative_groups.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cstddef>
#ifdef FLASHINFER_ENABLE_FP8
#include <cuda_fp8.h>
#endif
#include <cuda_runtime.h>

#include <cstddef>
#include <cuda/pipeline>
#include <iostream>
#include <optional>
@@ -537,6 +534,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
j] +
tx * vec_size;
}

// load k tiles
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
@@ -597,11 +595,7 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo
return 512U;
}
} else {
#ifdef FLASHINFER_ENABLE_BF16
return 128U;
#else
return 64U;
#endif
}
}

@@ -639,8 +633,8 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
const float rope_rcp_scale = 1.f / rope_scale;
const float rope_rcp_theta = 1.f / rope_theta;
constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL);
constexpr uint32_t num_stages_smem = 2U;
constexpr uint32_t bdx = HEAD_DIM / vec_size;
auto compute_capacity = GetCudaComputeCapability();
static_assert(bdx <= 32U);
DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, {
constexpr uint32_t bdy = GROUP_SIZE;
@@ -649,69 +643,74 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
constexpr uint32_t bdz = num_threads / (bdx * bdy);
tensor_info_t info(1, seq_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM);
constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 8U) : 1U;
const uint32_t smem_size =
2U * num_stages_smem * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) +
2U * bdy * bdz * sizeof(float);
auto kernel = SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, POS_ENCODING_MODE,
num_stages_smem, tile_size_per_bdx, vec_size, bdx,
bdy, bdz, DTypeQ, DTypeKV, DTypeOut>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
if (seq_len <= 256 || tmp == nullptr) {
// no need to use partition-kv kernel
dim3 nblks = dim3(1, num_kv_heads);
dim3 nthrs = dim3(bdx, bdy, bdz);
float* lse = nullptr;
void* args[] = {(void*)&q,
(void*)&k,
(void*)&v,
(void*)&o,
(void*)&lse,
(void*)&info,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta,
(void*)&seq_len};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
// use partition-kv kernel
int num_blocks_per_sm = 0;
int num_sm = 0;
int dev_id = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
num_threads, smem_size));
uint32_t max_grid_size = uint32_t(num_blocks_per_sm) * uint32_t(num_sm);
uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads;
uint32_t kv_chunk_size = max(ceil_div(seq_len, max_num_kv_chunks), 256);
uint32_t num_chunks = ceil_div(seq_len, kv_chunk_size);
dim3 nblks = dim3(num_chunks, num_kv_heads);
if (nblks.x == 0 || nblks.y == 0) {
std::ostringstream err_msg;
err_msg << "Invalid kernel configuration: nblks=(" << nblks.x << "," << nblks.y << ")";
throw std::runtime_error(err_msg.str());
}
dim3 nthrs = dim3(bdx, bdy, bdz);
float* tmp_lse = (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM);
void* args[] = {(void*)&q,
(void*)&k,
(void*)&v,
(void*)&tmp,
(void*)&tmp_lse,
(void*)&info,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta,
(void*)&kv_chunk_size};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, {
const uint32_t smem_size =
2U * NUM_STAGES_SMEM * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) +
2U * bdy * bdz * sizeof(float);
auto kernel = SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, POS_ENCODING_MODE,
NUM_STAGES_SMEM, tile_size_per_bdx, vec_size, bdx,
bdy, bdz, DTypeQ, DTypeKV, DTypeOut>;
FLASHINFER_CUDA_CALL(
MergeStates(tmp, tmp_lse, o, nullptr, num_chunks, 1, num_qo_heads, HEAD_DIM, stream));
}
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
if (seq_len <= 256 || tmp == nullptr) {
// no need to use partition-kv kernel
dim3 nblks = dim3(1, num_kv_heads);
dim3 nthrs = dim3(bdx, bdy, bdz);
float* lse = nullptr;
void* args[] = {(void*)&q,
(void*)&k,
(void*)&v,
(void*)&o,
(void*)&lse,
(void*)&info,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta,
(void*)&seq_len};
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
// use partition-kv kernel
int num_blocks_per_sm = 0;
int num_sm = 0;
int dev_id = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
FLASHINFER_CUDA_CALL(
cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&num_blocks_per_sm, kernel, num_threads, smem_size));
uint32_t max_grid_size = uint32_t(num_blocks_per_sm) * uint32_t(num_sm);
uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads;
uint32_t kv_chunk_size = max(ceil_div(seq_len, max_num_kv_chunks), 256);
uint32_t num_chunks = ceil_div(seq_len, kv_chunk_size);
dim3 nblks = dim3(num_chunks, num_kv_heads);
if (nblks.x == 0 || nblks.y == 0) {
std::ostringstream err_msg;
err_msg << "Invalid kernel configuration: nblks=(" << nblks.x << "," << nblks.y << ")";
throw std::runtime_error(err_msg.str());
}
dim3 nthrs = dim3(bdx, bdy, bdz);
float* tmp_lse = (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM);
void* args[] = {(void*)&q,
(void*)&k,
(void*)&v,
(void*)&tmp,
(void*)&tmp_lse,
(void*)&info,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta,
(void*)&kv_chunk_size};
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
FLASHINFER_CUDA_CALL(
MergeStates(tmp, tmp_lse, o, nullptr, num_chunks, 1, num_qo_heads, HEAD_DIM, stream));
}
});
});
return cudaSuccess;
}
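
As a concrete illustration of the partition-kv grid sizing above, take some assumed values (not from this commit): num_sm = 68, num_blocks_per_sm = 2, num_kv_heads = 8, seq_len = 8192. The arithmetic then works out as follows:

// Illustrative numbers only; mirrors the grid-sizing arithmetic above.
uint32_t max_grid_size = num_blocks_per_sm * num_sm;        // 2 * 68  = 136 blocks
uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads;  // 136 / 8 = 17 chunks per KV head
uint32_t kv_chunk_size = max(ceil_div(seq_len, max_num_kv_chunks), 256);  // max(482, 256) = 482
uint32_t num_chunks = ceil_div(seq_len, kv_chunk_size);     // ceil_div(8192, 482) = 17
// grid is dim3(num_chunks, num_kv_heads) = (17, 8); the 17 partial outputs per
// query head are then reduced by MergeStates.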
@@ -730,66 +729,71 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
const uint32_t num_kv_heads = paged_kv.num_heads;

constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL);
constexpr uint32_t num_stages_smem = 2U;
auto compute_capacity = GetCudaComputeCapability();
constexpr uint32_t bdx = HEAD_DIM / vec_size;
static_assert(bdx <= 32);
DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, {
constexpr uint32_t bdy = GROUP_SIZE;
constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
constexpr uint32_t bdz = num_threads / (bdx * bdy);
constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U;
const uint32_t smem_size =
2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float));
auto kernel =
BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, POS_ENCODING_MODE, num_stages_smem,
tile_size_per_bdx, vec_size, bdx, bdy, bdz, page_storage,
DTypeQ, DTypeKV, DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
if (tmp_v == nullptr) {
// do not use partition-kv kernel
bool partition_kv = false;
dim3 nblks(padded_batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);

void* args[] = {(void*)&q,
(void*)&q_offset,
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&o,
(void*)&lse,
(void*)&block_valid_mask,
(void*)&partition_kv,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
// use partition-kv kernel
bool partition_kv = true;
void* args[] = {(void*)&q,
(void*)&q_offset,
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&tmp_v,
(void*)&tmp_s,
(void*)&block_valid_mask,
(void*)&partition_kv,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
dim3 nblks(padded_batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
tmp_v, tmp_s, kv_partition_info.chunk_indptr, o, lse,
kv_partition_info.batch_size_before_partition, num_qo_heads, HEAD_DIM, stream));
}
DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, {
const uint32_t smem_size =
2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*),
2 * bdy * bdz * sizeof(float));
auto kernel =
BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, POS_ENCODING_MODE, NUM_STAGES_SMEM,
tile_size_per_bdx, vec_size, bdx, bdy, bdz,
page_storage, DTypeQ, DTypeKV, DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
if (tmp_v == nullptr) {
// do not use partition-kv kernel
bool partition_kv = false;
dim3 nblks(padded_batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);

void* args[] = {(void*)&q,
(void*)&q_offset,
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&o,
(void*)&lse,
(void*)&block_valid_mask,
(void*)&partition_kv,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
// use partition-kv kernel
bool partition_kv = true;
void* args[] = {(void*)&q,
(void*)&q_offset,
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&tmp_v,
(void*)&tmp_s,
(void*)&block_valid_mask,
(void*)&partition_kv,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
dim3 nblks(padded_batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
tmp_v, tmp_s, kv_partition_info.chunk_indptr, o, lse,
kv_partition_info.batch_size_before_partition, num_qo_heads, HEAD_DIM, stream));
}
});
});
return cudaSuccess;
}
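
Why the dispatch is needed at all comes down to the shared-memory budget: Turing (sm75) allows at most about 64 KB of dynamic shared memory per block, well below sm80 and later, so the previously fixed num_stages_smem = 2 can push smem_size past what cudaFuncSetAttribute can grant on sm75. A back-of-the-envelope check for one assumed configuration (HEAD_DIM = 128, fp16 KV cache, GROUP_SIZE = 1, hence tile_size_per_bdx = 8, bdy = 1, bdz = 8 in the single-decode path; these values are for illustration only):

// smem_size formula from SingleDecodeWithKVCacheDispatched, sizeof(DTypeKV) == 2 (fp16):
//   2 * num_stages_smem * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV)
//     + 2 * bdy * bdz * sizeof(float)
constexpr uint32_t bdy = 1, bdz = 8, tile_size_per_bdx = 8, HEAD_DIM = 128;
constexpr uint32_t smem_two_stages =
    2 * 2 * bdy * tile_size_per_bdx * bdz * HEAD_DIM * 2 + 2 * bdy * bdz * 4;  // 65600 bytes
constexpr uint32_t smem_one_stage =
    2 * 1 * bdy * tile_size_per_bdx * bdz * HEAD_DIM * 2 + 2 * bdy * bdz * 4;  // 32832 bytes
// 65600 > 65536 (the ~64 KB per-block limit on sm75), while 32832 fits easily.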