Support num_qo_heads values that are not an integral multiple of qo_heads_per_block, to accommodate MiniCPM3-4B, which has 40 num_qo_heads.
tsu-bin committed Oct 30, 2024
1 parent b14f7a1 commit 478c835
Showing 3 changed files with 16 additions and 24 deletions.
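The idea behind the change: instead of rejecting head counts that are not a multiple of qo_heads_per_block, the grid is padded up with ceil_div and each padded head slot is skipped by a bounds check. A minimal standalone sketch of that pattern (plain C++, not the FlashInfer API; qo_heads_per_block = 16 is a hypothetical value, the real one is bdy * tile_size_qo_heads from the dispatch template):

#include <cstdint>
#include <cstdio>

// Same rounding the dispatcher uses: ceil_div(40, 16) == 3.
constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
  constexpr uint32_t num_qo_heads = 40;        // MiniCPM3-4B
  constexpr uint32_t qo_heads_per_block = 16;  // hypothetical value
  const uint32_t gdy = ceil_div(num_qo_heads, qo_heads_per_block);

  uint32_t active = 0, idle = 0;
  for (uint32_t block = 0; block < gdy; ++block) {
    for (uint32_t slot = 0; slot < qo_heads_per_block; ++slot) {
      const uint32_t qo_head_idx = block * qo_heads_per_block + slot;
      if (qo_head_idx < num_qo_heads) {
        ++active;  // real head: the kernel loads, computes, and stores
      } else {
        ++idle;    // padded tail: the kernel's new guard skips it
      }
    }
  }
  printf("gdy=%u active=%u idle=%u\n", gdy, active, idle);  // gdy=3 active=40 idle=8
  return 0;
}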
32 changes: 15 additions & 17 deletions include/flashinfer/attention/decode.cuh
@@ -907,6 +907,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernelMLA(typename AttentionVariant::
   vec_t<float, vec_size_ckv> q_nope_vec[tile_size_qo_heads];
   vec_t<float, vec_size_kpe> q_pe_vec[tile_size_qo_heads];
   state_t<vec_size_ckv> st[tile_size_qo_heads];
+  uint32_t qo_head_idx[tile_size_qo_heads];
 
   vec_t<float, vec_size_kpe> freq;
 #pragma unroll
@@ -918,10 +919,12 @@ __global__ void BatchDecodeWithPagedKVCacheKernelMLA(typename AttentionVariant::
   // load q_nope and q_pe tile
 #pragma unroll
   for (int i = 0; i < tile_size_qo_heads; ++i) {
-    const uint32_t qo_head_idx = dim3_offset(bdy, tile_size_qo_heads, blockIdx.y, threadIdx.y, i);
-    q_nope_vec[i].cast_load(q_nope + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim_ckv + tx * vec_size_ckv);
-    q_pe_vec[i] = vec_apply_llama_rope_interleave<vec_size_kpe, bdx>(
-        q_pe + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim_kpe, freq, q_offset_val);
+    qo_head_idx[i] = dim3_offset(bdy, tile_size_qo_heads, blockIdx.y, threadIdx.y, i);
+    if (qo_head_idx[i] < num_qo_heads) {
+      q_nope_vec[i].cast_load(q_nope + (mapped_batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_ckv + tx * vec_size_ckv);
+      q_pe_vec[i] = vec_apply_llama_rope_interleave<vec_size_kpe, bdx>(
+          q_pe + (mapped_batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_kpe, freq, q_offset_val);
+    }
   }
 
   // init paged-cache read offset to be used
@@ -1022,20 +1025,21 @@ __global__ void BatchDecodeWithPagedKVCacheKernelMLA(typename AttentionVariant::
   if (bdz != 1) {
 #pragma unroll
     for (int i = 0; i < tile_size_qo_heads; ++i) {
-      sync_state<vec_size_ckv, bdx, bdy, bdz>(variant, st[i], (float*)smem, smem_md);
+      if (qo_head_idx[i] < num_qo_heads)
+        sync_state<vec_size_ckv, bdx, bdy, bdz>(variant, st[i], (float*)smem, smem_md);
     }
   }
 
   if (tz == 0) {
 #pragma unroll
     for (int i = 0; i < tile_size_qo_heads; ++i) {
-      const uint32_t qo_head_idx = dim3_offset(bdy, tile_size_qo_heads, blockIdx.y, threadIdx.y, i);
-
-      st[i].normalize();
-      st[i].o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx) * head_dim_ckv + tx * vec_size_ckv);
-
-      if (lse != nullptr) {
-        lse[batch_idx * num_qo_heads + qo_head_idx] = st[i].get_lse();
-      }
+      if (qo_head_idx[i] < num_qo_heads) {
+        st[i].normalize();
+        st[i].o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_ckv + tx * vec_size_ckv);
+
+        if (lse != nullptr) {
+          lse[batch_idx * num_qo_heads + qo_head_idx[i]] = st[i].get_lse();
+        }
+      }
     }
   }
@@ -1062,12 +1066,6 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatchedMLA(typename AttentionVariant::
   constexpr uint32_t qo_heads_per_block = bdy * tile_size_qo_heads;
   constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
   constexpr uint32_t bdz = num_threads / (bdx * bdy);
-  if (num_qo_heads % qo_heads_per_block != 0) {
-    std::ostringstream err_msg;
-    err_msg << "num_qo_heads " << num_qo_heads << " is not a multiple of qo_heads_per_block "
-            << qo_heads_per_block;
-    throw std::invalid_argument(err_msg.str());
-  }
   const uint32_t gdy = ceil_div(num_qo_heads, qo_heads_per_block);
 
   auto compute_capacity = GetCudaComputeCapability();
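For reference, each thread's head indices in the kernel come from dim3_offset. A hedged sketch of that mapping (the actual definition lives elsewhere in FlashInfer; this assumes a row-major flattening of (blockIdx.y, threadIdx.y, i), which is consistent with the guard leaving only the tail slots idle):

#include <cstdint>

// Assumed behavior of dim3_offset(bdy, tile_size, blockIdx.y, threadIdx.y, i).
constexpr uint32_t dim3_offset_sketch(uint32_t bdy, uint32_t tile_size,
                                      uint32_t block_y, uint32_t thread_y,
                                      uint32_t i) {
  return (block_y * bdy + thread_y) * tile_size + i;
}

// With hypothetical bdy = 8, tile_size_qo_heads = 2, and gdy = 3 blocks,
// slots run 0..47; for num_qo_heads = 40 the last 8 fail the bounds check.
static_assert(dim3_offset_sketch(8, 2, 0, 0, 0) == 0, "first head");
static_assert(dim3_offset_sketch(8, 2, 2, 3, 1) == 39, "last real head");
static_assert(dim3_offset_sketch(8, 2, 2, 7, 1) == 47, "last padded slot");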
6 changes: 0 additions & 6 deletions include/flashinfer/attention/scheduler.cuh
@@ -218,12 +218,6 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMLA(
   constexpr uint32_t qo_heads_per_block = bdy * tile_size_qo_heads;
   constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
   constexpr uint32_t bdz = num_threads / (bdx * bdy);
-  if (num_qo_heads % qo_heads_per_block != 0) {
-    std::ostringstream err_msg;
-    err_msg << "num_qo_heads " << num_qo_heads << " is not a multiple of qo_heads_per_block "
-            << qo_heads_per_block;
-    throw std::invalid_argument(err_msg.str());
-  }
   const uint32_t gdy = ceil_div(num_qo_heads, qo_heads_per_block);
 
   const uint32_t smem_size =
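The work estimator drops the same check so it sizes the grid exactly as the dispatcher does. A worked example of the constants involved, under assumed template values (bdx = 32, bdy = 8, tile_size_qo_heads = 2 are illustrative, not the library's actual choices):

#include <algorithm>
#include <cstdint>

constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

constexpr uint32_t bdx = 32, bdy = 8, tile_size_qo_heads = 2;      // assumed
constexpr uint32_t qo_heads_per_block = bdy * tile_size_qo_heads;  // 16
constexpr uint32_t num_threads = std::max(128U, bdx * bdy);        // 256
constexpr uint32_t bdz = num_threads / (bdx * bdy);                // 1

static_assert(bdx * bdy * bdz == num_threads, "full thread block");
// Before this commit, num_qo_heads = 40 threw std::invalid_argument here;
// now it simply pads: gdy = ceil_div(40, qo_heads_per_block) == 3.
static_assert(ceil_div(40, qo_heads_per_block) == 3, "40 heads in 3 blocks");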
2 changes: 1 addition & 1 deletion src/bench_batch_decode_mla.cu
@@ -124,6 +124,6 @@ void bench_flashinfer_batch_decode_mla(nvbench::state& state) {
       .add_int64_axis("page_size", {64})                   \
       .add_int64_axis("batch_size", {16, 256})             \
       .add_int64_axis("seqlen", {1024, 16384})             \
-      .add_int64_axis("num_qo_heads", {8, 16, 32, 64, 128})
+      .add_int64_axis("num_qo_heads", {8, 16, 32, 40, 64, 128})
 
 BENCH_FLASHINFER_BATCH_DECODE(half);
