diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh
index b2086ee2..821be5cd 100644
--- a/include/flashinfer/attention/decode.cuh
+++ b/include/flashinfer/attention/decode.cuh
@@ -849,13 +849,11 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKV
     DTypeIn* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
     kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
-    float* lse, bool* block_valid_mask, std::optional<uint32_t> fixed_grid_size, float sm_scale,
+    float* lse, bool* block_valid_mask, uint32_t padded_batch_size, float sm_scale,
     float rope_scale, float rope_theta, cudaStream_t stream) {
   const float rope_rcp_scale = 1.f / rope_scale;
   const float rope_rcp_theta = 1.f / rope_theta;
   const uint32_t num_kv_heads = paged_kv.num_heads;
-  const uint32_t batch_size = paged_kv.batch_size;
-  const uint32_t grid_size = fixed_grid_size.value_or(batch_size * num_kv_heads);
   const uint32_t num_qo_heads = num_kv_heads * GROUP_SIZE;
   constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL);
@@ -872,7 +870,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
   if (tmp_v == nullptr) {
     // do not use partition-kv kernel
-    dim3 nblks(grid_size / num_kv_heads, num_kv_heads);
+    dim3 nblks(padded_batch_size, num_kv_heads);
     dim3 nthrs(bdx, bdy, bdz);
     auto kernel = BatchDecodeWithPagedKVCacheKernel<...>;
diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh
--- a/include/flashinfer/attention/handler.cuh
+++ b/include/flashinfer/attention/handler.cuh
@@ ... @@ class BatchDecodeHandler {
       if (tmp_size > 0) {
-        fixed_grid_size_ = padded_batch_size_after_partition * num_kv_heads;
+        padded_batch_size_ = padded_batch_size_after_partition;
         AlignedAllocator allocator(buffer, workspace_size_in_bytes);
         tmp_v_ = allocator.aligned_alloc<void>(
             num_qo_heads * padded_batch_size_after_partition * HEAD_DIM * sizeof(DTypeOut), 16);
@@ -367,11 +367,13 @@ class BatchDecodeHandler {
             /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
       } else {
         block_valid_mask_ = nullptr;
-        fixed_grid_size_ = num_kv_heads * batch_size;
+        padded_batch_size_ = num_kv_heads * batch_size;
       }
     } else {
      // NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled.
      block_valid_mask_ = nullptr;
+      // do not pad the batch size when not using CUDAGraph
+      padded_batch_size_ = batch_size_after_partition_;
      if (tmp_size > 0) {
        AlignedAllocator allocator(buffer, workspace_size_in_bytes);
        tmp_v_ = allocator.aligned_alloc<void>(tmp_size, 16);
@@ -418,7 +420,7 @@ class BatchDecodeHandler {
   cudaError_t EndForward() {
     forward_started_ = false;
-    padded_batch_size_ = 0;
+    padded_batch_size_ = 0;
     batch_size_before_partition_ = 0;
     batch_size_after_partition_ = 0;
     block_valid_mask_ = nullptr;
@@ -492,7 +494,7 @@ class BatchDecodeHandler {
   void* seq_lengths_before_partition_;
   bool forward_started_;
   bool cuda_graph_enabled_;
-  uint32_t fixed_grid_size_;
+  uint32_t padded_batch_size_;
   uint32_t fixed_batch_size_;
   cudaStream_t stream_;
 };
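The handler hunks above are the heart of the refactor: when CUDAGraph capture is enabled, the handler pads the batch to a fixed size, and the dispatch in decode.cuh launches `dim3 nblks(padded_batch_size, num_kv_heads)`, with `block_valid_mask` telling padding blocks to do nothing. The following self-contained sketch (a hypothetical toy kernel, not FlashInfer code) illustrates that launch pattern:

```cuda
// Standalone sketch: launch a grid whose x-dimension is the *padded* batch
// size and let padding blocks exit early through a validity mask, so the grid
// shape can stay fixed while the real batch size varies.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void ToyDecodeKernel(const bool* block_valid_mask, float* out) {
  const unsigned batch_idx = blockIdx.x;  // one block row per (padded) request
  if (block_valid_mask != nullptr && !block_valid_mask[batch_idx]) {
    return;  // padding block: no real request behind it
  }
  const unsigned head_idx = blockIdx.y;
  out[(head_idx * gridDim.x + batch_idx) * blockDim.x + threadIdx.x] = 1.f;
}

int main() {
  constexpr unsigned padded_batch_size = 8;  // fixed at capture time
  constexpr unsigned num_kv_heads = 4, bdx = 32;
  constexpr unsigned real_batch_size = 5;  // may vary per replay

  bool h_mask[padded_batch_size];
  for (unsigned i = 0; i < padded_batch_size; ++i) h_mask[i] = i < real_batch_size;

  bool* d_mask;
  float* d_out;
  cudaMalloc(&d_mask, sizeof(h_mask));
  cudaMalloc(&d_out, padded_batch_size * num_kv_heads * bdx * sizeof(float));
  cudaMemcpy(d_mask, h_mask, sizeof(h_mask), cudaMemcpyHostToDevice);

  // Grid shape depends only on the padded batch size, mirroring
  // dim3 nblks(padded_batch_size, num_kv_heads) in the dispatch above.
  dim3 nblks(padded_batch_size, num_kv_heads);
  ToyDecodeKernel<<<nblks, dim3(bdx)>>>(d_mask, d_out);
  cudaDeviceSynchronize();
  printf("%s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(d_mask);
  cudaFree(d_out);
  return 0;
}
```

Replaying a captured graph then only requires rewriting the mask (and other buffer contents), never the grid dimensions.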
diff --git a/include/flashinfer/decode_attention_decl.cuh b/include/flashinfer/decode_attention_decl.cuh
index 7e01a140..b28aefa2 100644
--- a/include/flashinfer/decode_attention_decl.cuh
+++ b/include/flashinfer/decode_attention_decl.cuh
@@ -40,7 +40,7 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKV
     DTypeIn* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
     kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
-    float* lse, bool* block_valid_mask, std::optional<uint32_t> fixed_grid_size, float sm_scale,
+    float* lse, bool* block_valid_mask, uint32_t padded_batch_size, float sm_scale,
     float rope_scale, float rope_theta, cudaStream_t stream);
@@ ... @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched(
   return BatchDecodeWithPagedKVCacheDispatched<...>(
       q, q_offset, new_paged_kv, kv_partition_info, o, tmp_v, tmp_s, lse,
-      handler->GetBlockValidMask(),
-      (handler->IsCUDAGraphEnabled() ? std::optional<uint32_t>(handler->GetFixedGridSize())
-                                     : std::nullopt),
-      sm_scale, rope_scale, rope_theta, stream);
+      handler->GetBlockValidMask(), handler->GetPaddedBatchSize(), sm_scale, rope_scale, rope_theta,
+      stream);
 }
 }  // namespace flashinfer
diff --git a/python/generate_batch_paged_decode_inst.py b/python/generate_batch_paged_decode_inst.py
index 3d0da642..1e72a9a8 100644
--- a/python/generate_batch_paged_decode_inst.py
+++ b/python/generate_batch_paged_decode_inst.py
@@ -39,8 +39,7 @@ def get_cu_file_str(
     paged_kv_t<PageStorage::kIndices, {kv_layout}, {dtype_in}, {idtype}> paged_kv,
     kv_partition_info_t<{idtype}> kv_partition_info,
     {dtype_out}* o, {dtype_out}* tmp_v, float* tmp_s, float* lse,
-    bool* block_valid_mask,
-    std::optional<uint32_t> fixed_grid_size,
+    bool* block_valid_mask, uint32_t padded_batch_size,
     float sm_scale, float rope_scale,
     float rope_theta, cudaStream_t stream);
diff --git a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu
index fdcfce9b..e0b793ff 100644
--- a/src/bench_batch_decode.cu
+++ b/src/bench_batch_decode.cu
@@ -89,10 +89,9 @@ void bench_flashinfer_batch_decode(nvbench::state& state) {
   } else {
     state.exec([&](nvbench::launch&) {
       cudaError_t status =
-          BatchDecodeWithPagedKVCache(
+          BatchDecodeWithPagedKVCacheNoSplitKV(
              thrust::raw_pointer_cast(q.data()), /*q_offset=*/nullptr, paged_kv,
-             kv_partition_info_t<int32_t>(), thrust::raw_pointer_cast(o.data()), /*tmp_v=*/nullptr,
-             /*tmp_s=*/nullptr,
+             kv_partition_info_t<int32_t>(), thrust::raw_pointer_cast(o.data()),
              /*lse=*/nullptr, num_qo_heads, pos_encoding_mode);
      if (status != cudaSuccess) {
        state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));
diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh
index 5a4007d8..cb3e1f15 100644
--- a/src/flashinfer_ops.cuh
+++ b/src/flashinfer_ops.cuh
@@ -209,10 +209,10 @@ cudaError_t BatchDecodeWithPaddedKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTy
 template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
           typename IdType>
-cudaError_t BatchDecodeWithPagedKVCache(
+cudaError_t BatchDecodeWithPagedKVCacheNoSplitKV(
     DTypeIn* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
-    kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
-    float* lse, uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
+    kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, float* lse, uint32_t num_qo_heads,
+    PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
     std::optional<float> maybe_sm_scale = std::nullopt, float rope_scale = 1.f,
     float rope_theta = 1e4, cudaStream_t stream = nullptr) {
   const uint32_t num_kv_heads = paged_kv.num_heads;
@@ -233,9 +233,10 @@ cudaError_t BatchDecodeWithPagedKVCache(
               return BatchDecodeWithPagedKVCacheDispatched<...>(
-                  q, q_offset, paged_kv, kv_partition_info, o, tmp_v, tmp_s, lse,
-                  /*block_valid_mask=*/nullptr, std::nullopt, sm_scale, rope_scale, rope_theta,
-                  stream);
+                  q, q_offset, paged_kv, kv_partition_info, o, /*tmp_v=*/nullptr, /*tmp_s=*/nullptr,
+                  lse,
+                  /*block_valid_mask=*/nullptr, /*padded_batch_size=*/paged_kv.batch_size, sm_scale,
+                  rope_scale, rope_theta, stream);
            })})});
  return cudaSuccess;
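For call sites, the renamed entry point in src/flashinfer_ops.cuh makes the single-pass intent explicit: it drops `tmp_v`/`tmp_s` from the signature and hands the dispatcher `padded_batch_size = paged_kv.batch_size`. A hedged before/after sketch, assuming `q`, `o`, `paged_kv`, `num_qo_heads`, and `pos_encoding_mode` are set up as in the bench file above:

```cpp
// Before this PR (no longer compiles against the new header): the generic
// entry point, steered onto the no-split-KV path by passing null temporaries.
//
//   cudaError_t status = BatchDecodeWithPagedKVCache(
//       q, /*q_offset=*/nullptr, paged_kv, kv_partition_info_t<int32_t>(), o,
//       /*tmp_v=*/nullptr, /*tmp_s=*/nullptr, /*lse=*/nullptr, num_qo_heads,
//       pos_encoding_mode);

// After this PR: the dedicated entry point; the temporaries disappear and the
// dispatcher receives padded_batch_size = paged_kv.batch_size internally.
cudaError_t status = BatchDecodeWithPagedKVCacheNoSplitKV(
    q, /*q_offset=*/nullptr, paged_kv, kv_partition_info_t<int32_t>(), o,
    /*lse=*/nullptr, num_qo_heads, pos_encoding_mode);
```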
diff --git a/src/test_batch_decode.cu b/src/test_batch_decode.cu
index 5fcaca42..6b14372d 100644
--- a/src/test_batch_decode.cu
+++ b/src/test_batch_decode.cu
@@ -107,11 +107,11 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si
   if (!cooperative) {
     // use non-cooperative kernel
-    cudaError_t status =
-        flashinfer::BatchDecodeWithPagedKVCache(
-            thrust::raw_pointer_cast(q_device.data()), /*q_offset=*/nullptr, paged_kv,
-            kv_partition_info_t<int32_t>(), thrust::raw_pointer_cast(o_device.data()),
-            /*tmp_v=*/nullptr, /*tmp_s=*/nullptr, /*lse=*/nullptr, num_qo_heads, pos_encoding_mode);
+    cudaError_t status = flashinfer::BatchDecodeWithPagedKVCacheNoSplitKV(
+        thrust::raw_pointer_cast(q_device.data()), /*q_offset=*/nullptr, paged_kv,
+        kv_partition_info_t<int32_t>(), thrust::raw_pointer_cast(o_device.data()),
+        /*lse=*/nullptr, num_qo_heads, pos_encoding_mode);
     EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status));
   } else {
     cudaError_t status = flashinfer::BatchDecodeWithPagedKVCacheWrapper
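Note that dropping the `std::optional`/`value_or` round-trip is behavior-preserving when no padding is requested: the old code materialized `grid_size = batch_size * num_kv_heads` and immediately divided by `num_kv_heads` at launch, while the new code passes the block count directly. A tiny self-contained check of that arithmetic:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

int main() {
  const uint32_t num_kv_heads = 4, batch_size = 7;

  // Old path: no fixed grid size supplied, so grid_size falls back to
  // batch_size * num_kv_heads, and the launch divided it back down.
  std::optional<uint32_t> fixed_grid_size = std::nullopt;
  const uint32_t grid_size = fixed_grid_size.value_or(batch_size * num_kv_heads);
  const uint32_t old_nblks_x = grid_size / num_kv_heads;

  // New path: the dispatcher receives the batch size (padded or not) directly.
  const uint32_t padded_batch_size = batch_size;

  assert(old_nblks_x == padded_batch_size);
  return 0;
}
```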