Skip to content

Commit

Permalink
Allow the cascade kernels to be executed using varying sequence lenghts
Browse files Browse the repository at this point in the history
The cascade kernels can take a dynamic sequence length in order to allow
the number of tokens to vary when executed under CUDA graphs.

This is the first step towards implementing CUDA graph support for arbitrary `qo_indptr` contents, as tracked by flashinfer-ai#626.
  • Loading branch information
nandor committed Nov 21, 2024
1 parent 9cba9fb commit 535ea1f
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 31 deletions.
40 changes: 25 additions & 15 deletions include/flashinfer/attention/cascade.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,10 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
/*!
* \brief The CUDA kernel to merge self-attention states of multiple index sets, the number of index
* sets at each position might vary.
*
* For CUDA graph support, the kernel can be built with a maximum sequence length and executed
* using a truncated, dynamic sequence length passed through `seq_len_ptr`.
*
* \tparam vec_size The vector size used in the kernel.
* \tparam bdx The blockDim.x used in the kernel.
* \tparam bdy The blockDim.y used in the kernel.
Expand All @@ -336,20 +340,22 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
* \param indptr The start offsets of each position in the variable length array.
* \param v_merged The merged v of index sets union. (n, h, d)
* \param s_merged The merged logsumexp value of index sets union. (n, h)
* \param max_seq_len The maximum sequence length supported by the kernel.
* \param seq_len_ptr The current sequence length (number of positions populated in indptr).
* \param num_heads The number of heads of v.
* \param head_dim The dimension of each head.
* \note s are logsumexp values with base 2.
*/
template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t num_smem_stages, typename DTypeIn,
typename DTypeO, typename IdType>
__global__ void PersistentVariableLengthMergeStatesKernel(DTypeIn* __restrict__ V,
float* __restrict__ S, IdType* indptr,
DTypeO* __restrict__ v_merged,
float* __restrict__ s_merged,
uint32_t seq_len, uint32_t num_heads) {
__global__ void PersistentVariableLengthMergeStatesKernel(
DTypeIn* __restrict__ V, float* __restrict__ S, IdType* indptr, DTypeO* __restrict__ v_merged,
float* __restrict__ s_merged, uint32_t max_seq_len, uint32_t* __restrict__ seq_len_ptr,
uint32_t num_heads) {
uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t cta_id = blockIdx.x;
uint32_t num_ctas = gridDim.x;
const uint32_t seq_len = seq_len_ptr ? *seq_len_ptr : max_seq_len;
uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas);
constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
constexpr uint32_t head_dim = vec_size * bdx;
Expand Down Expand Up @@ -437,10 +443,13 @@ template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t num_smem_stage
typename DTypeO, typename IdType>
__global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ V, IdType* indptr,
DTypeO* __restrict__ v_sum,
uint32_t seq_len, uint32_t num_heads) {
uint32_t max_seq_len,
uint32_t* __restrict__ seq_len_ptr,
uint32_t num_heads) {
uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t cta_id = blockIdx.x;
uint32_t num_ctas = gridDim.x;
const uint32_t seq_len = seq_len_ptr ? *seq_len_ptr : max_seq_len;
uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas);
constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
constexpr uint32_t head_dim = vec_size * bdx;
Expand Down Expand Up @@ -641,8 +650,9 @@ cudaError_t AttentionSum(DTypeIn* v, DTypeO* v_sum, uint32_t num_index_sets, uin

template <typename DTypeIn, typename DTypeO, typename IdType>
cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeO* v_merged,
float* s_merged, uint32_t seq_len, uint32_t num_heads,
uint32_t head_dim, cudaStream_t stream = nullptr) {
float* s_merged, uint32_t max_seq_len, uint32_t* seq_len,
uint32_t num_heads, uint32_t head_dim,
cudaStream_t stream = nullptr) {
int dev_id = 0;
int num_sms = 0;
int num_blocks_per_sm = 0;
Expand All @@ -661,11 +671,11 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp
DTypeIn, DTypeO, IdType>;
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
num_threads, smem_size));
num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms));
num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(max_seq_len * num_heads, num_sms));

dim3 nblks(num_sms * num_blocks_per_sm);
dim3 nthrs(bdx, bdy);
void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &seq_len, &num_heads};
void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &max_seq_len, &seq_len, &num_heads};
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
Expand All @@ -674,9 +684,9 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp
}

template <typename DTypeIn, typename DTypeO, typename IdType>
cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum, uint32_t seq_len,
uint32_t num_heads, uint32_t head_dim,
cudaStream_t stream = nullptr) {
cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum,
uint32_t max_seq_len, uint32_t* seq_len, uint32_t num_heads,
uint32_t head_dim, cudaStream_t stream = nullptr) {
int dev_id = 0;
int num_sms = 0;
int num_blocks_per_sm = 0;
Expand All @@ -694,11 +704,11 @@ cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum
DTypeIn, DTypeO, IdType>;
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
num_threads, smem_size));
num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms));
num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(max_seq_len * num_heads, num_sms));

dim3 nblks(num_sms * num_blocks_per_sm);
dim3 nthrs(bdx, bdy);
void* args[] = {&v, &indptr, &v_sum, &seq_len, &num_heads};
void* args[] = {&v, &indptr, &v_sum, &max_seq_len, &seq_len, &num_heads};
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
Expand Down
12 changes: 6 additions & 6 deletions include/flashinfer/attention/decode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -764,12 +764,12 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(typename AttentionVariant::Par
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
if constexpr (AttentionVariant::use_softmax) {
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.o_indptr, o, lse,
params.paged_kv.batch_size, num_qo_heads,
HEAD_DIM, stream));
params.paged_kv.batch_size, nullptr,
num_qo_heads, HEAD_DIM, stream));
} else {
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, o, params.o_indptr,
params.paged_kv.batch_size, num_qo_heads,
HEAD_DIM, stream));
params.paged_kv.batch_size, nullptr,
num_qo_heads, HEAD_DIM, stream));
}
}
});
Expand Down Expand Up @@ -1087,8 +1087,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatchedMLA(typename AttentionVariant::
dim3 nthrs(bdx, bdy, bdz);
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.o_indptr, o, lse,
params.paged_kv.batch_size, num_qo_heads,
HEAD_DIM_CKV, stream));
params.paged_kv.batch_size, nullptr,
num_qo_heads, HEAD_DIM_CKV, stream));
}
});
return cudaSuccess;
Expand Down
18 changes: 10 additions & 8 deletions include/flashinfer/attention/prefill.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2199,11 +2199,12 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
if constexpr (AttentionVariant::use_softmax) {
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse,
total_num_rows, num_qo_heads, HEAD_DIM,
stream));
total_num_rows, nullptr, num_qo_heads,
HEAD_DIM, stream));
} else {
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(
tmp_v, params.merge_indptr, o, total_num_rows, num_qo_heads, HEAD_DIM, stream));
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o,
total_num_rows, nullptr, num_qo_heads,
HEAD_DIM, stream));
}
}
}
Expand Down Expand Up @@ -2300,11 +2301,12 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
if constexpr (AttentionVariant::use_softmax) {
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse,
total_num_rows, num_qo_heads, HEAD_DIM,
stream));
total_num_rows, nullptr, num_qo_heads,
HEAD_DIM, stream));
} else {
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(
tmp_v, params.merge_indptr, o, total_num_rows, num_qo_heads, HEAD_DIM, stream));
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o,
total_num_rows, nullptr, num_qo_heads,
HEAD_DIM, stream));
}
}
}
Expand Down
89 changes: 87 additions & 2 deletions src/test_cascade.cu
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads,
thrust::raw_pointer_cast(S_ragged_device.data()),
thrust::raw_pointer_cast(indptr_device.data()),
thrust::raw_pointer_cast(V_merged_1_device.data()),
thrust::raw_pointer_cast(S_merged_1_device.data()), seq_len, num_heads,
head_dim);
thrust::raw_pointer_cast(S_merged_1_device.data()), seq_len, nullptr,
num_heads, head_dim);

thrust::host_vector<T> V_merged_0_host(V_merged_0_device), V_merged_1_host(V_merged_1_device);
thrust::host_vector<float> S_merged_0_host(S_merged_0_device), S_merged_1_host(S_merged_1_device);
Expand Down Expand Up @@ -133,6 +133,81 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads,
EXPECT_GT(S_result_accuracy, 0.99) << "S result correctness test failed.";
}

template <typename T>
void _TestVariableLengthMergeKernelPaddedCorrectness(size_t max_seq_len, size_t seq_len) {
ASSERT_LE(seq_len, max_seq_len);

const size_t num_heads = 4;
const size_t head_dim = 64;
const uint32_t max_num_index_sets = 512;

std::vector<int32_t> lengths(max_seq_len);
utils::vec_randint_(lengths, 1, max_num_index_sets);
std::vector<int32_t> indptr(max_seq_len + 1, 0);
for (size_t i = 0; i < seq_len; ++i) {
indptr[i + 1] = indptr[i] + lengths[i];
}

uint32_t last_indptr = indptr[seq_len];
std::vector<T> V_ragged_host(last_indptr * num_heads * head_dim);
std::vector<float> S_ragged_host(last_indptr * num_heads);

utils::vec_normal_(V_ragged_host);
utils::vec_uniform_(S_ragged_host, -10, 10);

thrust::device_vector<T> V_ragged_device(V_ragged_host);
thrust::device_vector<float> S_ragged_device(S_ragged_host);
thrust::device_vector<int32_t> indptr_device(indptr);
thrust::device_vector<T> V_merged_0_device(max_seq_len * num_heads * head_dim);
thrust::device_vector<T> V_merged_1_device(max_seq_len * num_heads * head_dim);
thrust::device_vector<float> S_merged_0_device(max_seq_len * num_heads);
thrust::device_vector<float> S_merged_1_device(max_seq_len * num_heads);
thrust::device_vector<uint32_t> seq_len_device(
std::vector<uint32_t>{static_cast<uint32_t>(seq_len)});

// Reference: use VariableLengthMergeStates on the precisely-sized input.
VariableLengthMergeStates(thrust::raw_pointer_cast(V_ragged_device.data()),
thrust::raw_pointer_cast(S_ragged_device.data()),
thrust::raw_pointer_cast(indptr_device.data()),
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, nullptr,
num_heads, head_dim);
// Expected: use VariableLengthMergeStates on a padded input
VariableLengthMergeStates(thrust::raw_pointer_cast(V_ragged_device.data()),
thrust::raw_pointer_cast(S_ragged_device.data()),
thrust::raw_pointer_cast(indptr_device.data()),
thrust::raw_pointer_cast(V_merged_1_device.data()),
thrust::raw_pointer_cast(S_merged_1_device.data()), max_seq_len,
thrust::raw_pointer_cast(seq_len_device.data()), num_heads, head_dim);

thrust::host_vector<T> V_merged_0_host(V_merged_0_device), V_merged_1_host(V_merged_1_device);
thrust::host_vector<float> S_merged_0_host(S_merged_0_device), S_merged_1_host(S_merged_1_device);

// Compare results
size_t num_V_result_errors_atol_1e_3_rtol_1e_3 = 0, num_S_result_errors_atol_1e_3_rtol_1e_3 = 0;
for (size_t i = 0; i < seq_len * num_heads * head_dim; ++i) {
EXPECT_FALSE(std::isnan(float(V_merged_1_host[i]))) << "V_merged_1_host[" << i << "] is nan";
num_V_result_errors_atol_1e_3_rtol_1e_3 +=
(!utils::isclose(float(V_merged_0_host[i]), float(V_merged_1_host[i]), 1e-3, 1e-3));
}
for (size_t i = 0; i < seq_len * num_heads; ++i) {
EXPECT_FALSE(std::isnan(float(S_merged_0_host[i]))) << "S_merged_0_host[" << i << "] is nan";
EXPECT_FALSE(std::isnan(float(S_merged_1_host[i]))) << "S_merged_1_host[" << i << "] is nan";
num_S_result_errors_atol_1e_3_rtol_1e_3 +=
(!utils::isclose(float(S_merged_0_host[i]), float(S_merged_1_host[i]), 1e-3, 1e-3));
}
float V_result_accuracy =
1.0 - float(num_V_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads * head_dim);
float S_result_accuracy =
1.0 - float(num_S_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads);
std::cout << "seq_len=" << seq_len << ", num_heads=" << num_heads << ", head_dim=" << head_dim
<< ", V accuracy (atol=1e-3, rtol=1e-3)=" << V_result_accuracy
<< ", S accuracy (atol=1e-3, rtol=1e-3)=" << S_result_accuracy << std::endl;

EXPECT_GT(V_result_accuracy, 0.99) << "V result correctness test failed.";
EXPECT_GT(S_result_accuracy, 0.99) << "S result correctness test failed.";
}

template <typename T>
void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t num_heads,
size_t head_dim, bool sparse_s) {
Expand Down Expand Up @@ -515,6 +590,12 @@ void TestVariableLengthMergeKernelCorrectness() {
}
}

template <typename T>
void TestVariableLengthMergeKernelPaddedCorrectness() {
_TestVariableLengthMergeKernelPaddedCorrectness<T>(8, 1);
_TestVariableLengthMergeKernelPaddedCorrectness<T>(128, 77);
}

template <typename T>
void TestTwoLevelSinglePrefixCascadeDecodeCorrectness() {
for (size_t batch_size : {1, 8, 16, 64, 128}) {
Expand Down Expand Up @@ -563,6 +644,10 @@ TEST(FlashInferCorrectnessTest, VariableLengthMergeKernelCorrectnessTestFP16) {
TestVariableLengthMergeKernelCorrectness<half>();
}

TEST(FlashInferCorrectnessTest, VariableLengthMergeKernelPaddedCorrectnessTestFP16) {
TestVariableLengthMergeKernelPaddedCorrectness<half>();
}

TEST(FlashInferCorrectnessTest, TwoLevelSinglePrefixCascadeDecodeTestFP16) {
TestTwoLevelSinglePrefixCascadeDecodeCorrectness<half>();
}
Expand Down

0 comments on commit 535ea1f

Please sign in to comment.