diff --git a/include/flashinfer/attention/cascade.cuh b/include/flashinfer/attention/cascade.cuh index 7c529330..ce0388bb 100644 --- a/include/flashinfer/attention/cascade.cuh +++ b/include/flashinfer/attention/cascade.cuh @@ -325,6 +325,10 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa /*! * \brief The CUDA kernel to merge self-attention states of multiple index sets, the number of index * sets at each position might vary. + * + * For CUDA graph support, the kernel can be built with a maximum sequence length and executed + * using a truncated, dynamic sequence length passed through `seq_len_ptr`. + * * \tparam vec_size The vector size used in the kernel. * \tparam bdx The blockDim.x used in the kernel. * \tparam bdy The blockDim.y used in the kernel. @@ -336,20 +340,22 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa * \param indptr The start offsets of each position in the variable length array. * \param v_merged The merged v of index sets union. (n, h, d) * \param s_merged The merged logsumexp value of index sets union. (n, h) + * \param max_seq_len The maximum sequence length supported by the kernel. + * \param seq_len_ptr The current sequence length (number of positions populated in indptr). * \param num_heads The number of heads of v. * \param head_dim The dimension of each head. * \note s are logsumexp values with base 2. */ template -__global__ void PersistentVariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, - float* __restrict__ S, IdType* indptr, - DTypeO* __restrict__ v_merged, - float* __restrict__ s_merged, - uint32_t seq_len, uint32_t num_heads) { +__global__ void PersistentVariableLengthMergeStatesKernel( + DTypeIn* __restrict__ V, float* __restrict__ S, IdType* indptr, DTypeO* __restrict__ v_merged, + float* __restrict__ s_merged, uint32_t max_seq_len, uint32_t* __restrict__ seq_len_ptr, + uint32_t num_heads) { uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t cta_id = blockIdx.x; uint32_t num_ctas = gridDim.x; + const uint32_t seq_len = seq_len_ptr ? *seq_len_ptr : max_seq_len; uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas); constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; constexpr uint32_t head_dim = vec_size * bdx; @@ -437,10 +443,13 @@ template __global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ V, IdType* indptr, DTypeO* __restrict__ v_sum, - uint32_t seq_len, uint32_t num_heads) { + uint32_t max_seq_len, + uint32_t* __restrict__ seq_len_ptr, + uint32_t num_heads) { uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t cta_id = blockIdx.x; uint32_t num_ctas = gridDim.x; + const uint32_t seq_len = seq_len_ptr ? *seq_len_ptr : max_seq_len; uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas); constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; constexpr uint32_t head_dim = vec_size * bdx; @@ -641,8 +650,9 @@ cudaError_t AttentionSum(DTypeIn* v, DTypeO* v_sum, uint32_t num_index_sets, uin template cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeO* v_merged, - float* s_merged, uint32_t seq_len, uint32_t num_heads, - uint32_t head_dim, cudaStream_t stream = nullptr) { + float* s_merged, uint32_t max_seq_len, uint32_t* seq_len, + uint32_t num_heads, uint32_t head_dim, + cudaStream_t stream = nullptr) { int dev_id = 0; int num_sms = 0; int num_blocks_per_sm = 0; @@ -661,11 +671,11 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp DTypeIn, DTypeO, IdType>; FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, num_threads, smem_size)); - num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms)); + num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(max_seq_len * num_heads, num_sms)); dim3 nblks(num_sms * num_blocks_per_sm); dim3 nthrs(bdx, bdy); - void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &seq_len, &num_heads}; + void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &max_seq_len, &seq_len, &num_heads}; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); @@ -674,9 +684,9 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp } template -cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum, uint32_t seq_len, - uint32_t num_heads, uint32_t head_dim, - cudaStream_t stream = nullptr) { +cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum, + uint32_t max_seq_len, uint32_t* seq_len, uint32_t num_heads, + uint32_t head_dim, cudaStream_t stream = nullptr) { int dev_id = 0; int num_sms = 0; int num_blocks_per_sm = 0; @@ -694,11 +704,11 @@ cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum DTypeIn, DTypeO, IdType>; FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, num_threads, smem_size)); - num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms)); + num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(max_seq_len * num_heads, num_sms)); dim3 nblks(num_sms * num_blocks_per_sm); dim3 nthrs(bdx, bdy); - void* args[] = {&v, &indptr, &v_sum, &seq_len, &num_heads}; + void* args[] = {&v, &indptr, &v_sum, &max_seq_len, &seq_len, &num_heads}; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index c8543b3f..0394c3eb 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -764,12 +764,12 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(typename AttentionVariant::Par cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); if constexpr (AttentionVariant::use_softmax) { FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.o_indptr, o, lse, - params.paged_kv.batch_size, num_qo_heads, - HEAD_DIM, stream)); + params.paged_kv.batch_size, nullptr, + num_qo_heads, HEAD_DIM, stream)); } else { FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, o, params.o_indptr, - params.paged_kv.batch_size, num_qo_heads, - HEAD_DIM, stream)); + params.paged_kv.batch_size, nullptr, + num_qo_heads, HEAD_DIM, stream)); } } }); @@ -1087,8 +1087,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatchedMLA(typename AttentionVariant:: dim3 nthrs(bdx, bdy, bdz); FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.o_indptr, o, lse, - params.paged_kv.batch_size, num_qo_heads, - HEAD_DIM_CKV, stream)); + params.paged_kv.batch_size, nullptr, + num_qo_heads, HEAD_DIM_CKV, stream)); } }); return cudaSuccess; diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 8464f053..5a7bac6a 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -2199,11 +2199,12 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); if constexpr (AttentionVariant::use_softmax) { FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse, - total_num_rows, num_qo_heads, HEAD_DIM, - stream)); + total_num_rows, nullptr, num_qo_heads, + HEAD_DIM, stream)); } else { - FLASHINFER_CUDA_CALL(VariableLengthAttentionSum( - tmp_v, params.merge_indptr, o, total_num_rows, num_qo_heads, HEAD_DIM, stream)); + FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, + total_num_rows, nullptr, num_qo_heads, + HEAD_DIM, stream)); } } } @@ -2300,11 +2301,12 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); if constexpr (AttentionVariant::use_softmax) { FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse, - total_num_rows, num_qo_heads, HEAD_DIM, - stream)); + total_num_rows, nullptr, num_qo_heads, + HEAD_DIM, stream)); } else { - FLASHINFER_CUDA_CALL(VariableLengthAttentionSum( - tmp_v, params.merge_indptr, o, total_num_rows, num_qo_heads, HEAD_DIM, stream)); + FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, + total_num_rows, nullptr, num_qo_heads, + HEAD_DIM, stream)); } } } diff --git a/src/test_cascade.cu b/src/test_cascade.cu index 4530ea9b..0cd83fa8 100644 --- a/src/test_cascade.cu +++ b/src/test_cascade.cu @@ -100,8 +100,8 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads, thrust::raw_pointer_cast(S_ragged_device.data()), thrust::raw_pointer_cast(indptr_device.data()), thrust::raw_pointer_cast(V_merged_1_device.data()), - thrust::raw_pointer_cast(S_merged_1_device.data()), seq_len, num_heads, - head_dim); + thrust::raw_pointer_cast(S_merged_1_device.data()), seq_len, nullptr, + num_heads, head_dim); thrust::host_vector V_merged_0_host(V_merged_0_device), V_merged_1_host(V_merged_1_device); thrust::host_vector S_merged_0_host(S_merged_0_device), S_merged_1_host(S_merged_1_device); @@ -133,6 +133,81 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads, EXPECT_GT(S_result_accuracy, 0.99) << "S result correctness test failed."; } +template +void _TestVariableLengthMergeKernelPaddedCorrectness(size_t max_seq_len, size_t seq_len) { + ASSERT_LE(seq_len, max_seq_len); + + const size_t num_heads = 4; + const size_t head_dim = 64; + const uint32_t max_num_index_sets = 512; + + std::vector lengths(max_seq_len); + utils::vec_randint_(lengths, 1, max_num_index_sets); + std::vector indptr(max_seq_len + 1, 0); + for (size_t i = 0; i < seq_len; ++i) { + indptr[i + 1] = indptr[i] + lengths[i]; + } + + uint32_t last_indptr = indptr[seq_len]; + std::vector V_ragged_host(last_indptr * num_heads * head_dim); + std::vector S_ragged_host(last_indptr * num_heads); + + utils::vec_normal_(V_ragged_host); + utils::vec_uniform_(S_ragged_host, -10, 10); + + thrust::device_vector V_ragged_device(V_ragged_host); + thrust::device_vector S_ragged_device(S_ragged_host); + thrust::device_vector indptr_device(indptr); + thrust::device_vector V_merged_0_device(max_seq_len * num_heads * head_dim); + thrust::device_vector V_merged_1_device(max_seq_len * num_heads * head_dim); + thrust::device_vector S_merged_0_device(max_seq_len * num_heads); + thrust::device_vector S_merged_1_device(max_seq_len * num_heads); + thrust::device_vector seq_len_device( + std::vector{static_cast(seq_len)}); + + // Reference: use VariableLengthMergeStates on the precisely-sized input. + VariableLengthMergeStates(thrust::raw_pointer_cast(V_ragged_device.data()), + thrust::raw_pointer_cast(S_ragged_device.data()), + thrust::raw_pointer_cast(indptr_device.data()), + thrust::raw_pointer_cast(V_merged_0_device.data()), + thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, nullptr, + num_heads, head_dim); + // Expected: use VariableLengthMergeStates on a padded input + VariableLengthMergeStates(thrust::raw_pointer_cast(V_ragged_device.data()), + thrust::raw_pointer_cast(S_ragged_device.data()), + thrust::raw_pointer_cast(indptr_device.data()), + thrust::raw_pointer_cast(V_merged_1_device.data()), + thrust::raw_pointer_cast(S_merged_1_device.data()), max_seq_len, + thrust::raw_pointer_cast(seq_len_device.data()), num_heads, head_dim); + + thrust::host_vector V_merged_0_host(V_merged_0_device), V_merged_1_host(V_merged_1_device); + thrust::host_vector S_merged_0_host(S_merged_0_device), S_merged_1_host(S_merged_1_device); + + // Compare results + size_t num_V_result_errors_atol_1e_3_rtol_1e_3 = 0, num_S_result_errors_atol_1e_3_rtol_1e_3 = 0; + for (size_t i = 0; i < seq_len * num_heads * head_dim; ++i) { + EXPECT_FALSE(std::isnan(float(V_merged_1_host[i]))) << "V_merged_1_host[" << i << "] is nan"; + num_V_result_errors_atol_1e_3_rtol_1e_3 += + (!utils::isclose(float(V_merged_0_host[i]), float(V_merged_1_host[i]), 1e-3, 1e-3)); + } + for (size_t i = 0; i < seq_len * num_heads; ++i) { + EXPECT_FALSE(std::isnan(float(S_merged_0_host[i]))) << "S_merged_0_host[" << i << "] is nan"; + EXPECT_FALSE(std::isnan(float(S_merged_1_host[i]))) << "S_merged_1_host[" << i << "] is nan"; + num_S_result_errors_atol_1e_3_rtol_1e_3 += + (!utils::isclose(float(S_merged_0_host[i]), float(S_merged_1_host[i]), 1e-3, 1e-3)); + } + float V_result_accuracy = + 1.0 - float(num_V_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads * head_dim); + float S_result_accuracy = + 1.0 - float(num_S_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads); + std::cout << "seq_len=" << seq_len << ", num_heads=" << num_heads << ", head_dim=" << head_dim + << ", V accuracy (atol=1e-3, rtol=1e-3)=" << V_result_accuracy + << ", S accuracy (atol=1e-3, rtol=1e-3)=" << S_result_accuracy << std::endl; + + EXPECT_GT(V_result_accuracy, 0.99) << "V result correctness test failed."; + EXPECT_GT(S_result_accuracy, 0.99) << "S result correctness test failed."; +} + template void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t num_heads, size_t head_dim, bool sparse_s) { @@ -515,6 +590,12 @@ void TestVariableLengthMergeKernelCorrectness() { } } +template +void TestVariableLengthMergeKernelPaddedCorrectness() { + _TestVariableLengthMergeKernelPaddedCorrectness(8, 1); + _TestVariableLengthMergeKernelPaddedCorrectness(128, 77); +} + template void TestTwoLevelSinglePrefixCascadeDecodeCorrectness() { for (size_t batch_size : {1, 8, 16, 64, 128}) { @@ -563,6 +644,10 @@ TEST(FlashInferCorrectnessTest, VariableLengthMergeKernelCorrectnessTestFP16) { TestVariableLengthMergeKernelCorrectness(); } +TEST(FlashInferCorrectnessTest, VariableLengthMergeKernelPaddedCorrectnessTestFP16) { + TestVariableLengthMergeKernelPaddedCorrectness(); +} + TEST(FlashInferCorrectnessTest, TwoLevelSinglePrefixCascadeDecodeTestFP16) { TestTwoLevelSinglePrefixCascadeDecodeCorrectness(); }