diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index d54cb669..94d5d07d 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -907,6 +907,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernelMLA(typename AttentionVariant:: vec_t q_nope_vec[tile_size_qo_heads]; vec_t q_pe_vec[tile_size_qo_heads]; state_t st[tile_size_qo_heads]; + uint32_t qo_head_idx[tile_size_qo_heads]; vec_t freq; #pragma unroll @@ -918,10 +919,12 @@ __global__ void BatchDecodeWithPagedKVCacheKernelMLA(typename AttentionVariant:: // load q_nope and q_pe tile #pragma unroll for (int i = 0; i < tile_size_qo_heads; ++i) { - const uint32_t qo_head_idx = dim3_offset(bdy, tile_size_qo_heads, blockIdx.y, threadIdx.y, i); - q_nope_vec[i].cast_load(q_nope + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim_ckv + tx * vec_size_ckv); - q_pe_vec[i] = vec_apply_llama_rope_interleave( - q_pe + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim_kpe, freq, q_offset_val); + qo_head_idx[i] = dim3_offset(bdy, tile_size_qo_heads, blockIdx.y, threadIdx.y, i); + if (qo_head_idx[i] < num_qo_heads) { + q_nope_vec[i].cast_load(q_nope + (mapped_batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_ckv + tx * vec_size_ckv); + q_pe_vec[i] = vec_apply_llama_rope_interleave( + q_pe + (mapped_batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_kpe, freq, q_offset_val); + } } // init paged-cache read offset to be used @@ -1022,20 +1025,21 @@ __global__ void BatchDecodeWithPagedKVCacheKernelMLA(typename AttentionVariant:: if (bdz != 1) { #pragma unroll for (int i = 0; i < tile_size_qo_heads; ++i) { - sync_state(variant, st[i], (float*)smem, smem_md); + if (qo_head_idx[i] < num_qo_heads) + sync_state(variant, st[i], (float*)smem, smem_md); } } if (tz == 0) { #pragma unroll for (int i = 0; i < tile_size_qo_heads; ++i) { - const uint32_t qo_head_idx = dim3_offset(bdy, tile_size_qo_heads, blockIdx.y, threadIdx.y, i); - - st[i].normalize(); - st[i].o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx) * head_dim_ckv + tx * vec_size_ckv); + if (qo_head_idx[i] < num_qo_heads) { + st[i].normalize(); + st[i].o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_ckv + tx * vec_size_ckv); - if (lse != nullptr) { - lse[batch_idx * num_qo_heads + qo_head_idx] = st[i].get_lse(); + if (lse != nullptr) { + lse[batch_idx * num_qo_heads + qo_head_idx[i]] = st[i].get_lse(); + } } } } @@ -1062,12 +1066,6 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatchedMLA(typename AttentionVariant:: constexpr uint32_t qo_heads_per_block = bdy * tile_size_qo_heads; constexpr uint32_t num_threads = std::max(128U, bdx * bdy); constexpr uint32_t bdz = num_threads / (bdx * bdy); - if (num_qo_heads % qo_heads_per_block != 0) { - std::ostringstream err_msg; - err_msg << "num_qo_heads " << num_qo_heads << " is not a multiple of qo_heads_per_block " - << qo_heads_per_block; - throw std::invalid_argument(err_msg.str()); - } const uint32_t gdy = ceil_div(num_qo_heads, qo_heads_per_block); auto compute_capacity = GetCudaComputeCapability(); diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index c804887a..a83f9397 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -218,12 +218,6 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMLA( constexpr uint32_t qo_heads_per_block = bdy * tile_size_qo_heads; constexpr uint32_t num_threads = std::max(128U, bdx * bdy); constexpr uint32_t bdz = num_threads / (bdx * bdy); - if (num_qo_heads % qo_heads_per_block != 0) { - std::ostringstream err_msg; - err_msg << "num_qo_heads " << num_qo_heads << " is not a multiple of qo_heads_per_block " - << qo_heads_per_block; - throw std::invalid_argument(err_msg.str()); - } const uint32_t gdy = ceil_div(num_qo_heads, qo_heads_per_block); const uint32_t smem_size = diff --git a/src/bench_batch_decode_mla.cu b/src/bench_batch_decode_mla.cu index 4a7da7a4..43aa8a64 100644 --- a/src/bench_batch_decode_mla.cu +++ b/src/bench_batch_decode_mla.cu @@ -124,6 +124,6 @@ void bench_flashinfer_batch_decode_mla(nvbench::state& state) { .add_int64_axis("page_size", {64}) \ .add_int64_axis("batch_size", {16, 256}) \ .add_int64_axis("seqlen", {1024, 16384}) \ - .add_int64_axis("num_qo_heads", {8, 16, 32, 64, 128}) + .add_int64_axis("num_qo_heads", {8, 16, 32, 40, 64, 128}) BENCH_FLASHINFER_BATCH_DECODE(half);