feat: allow the cascade kernels to be executed using varying sequence lengths (#627)

The cascade kernels now accept a dynamic, device-resident sequence length so that
the number of tokens can vary when they are executed under CUDA graphs.

This is the first step towards implementing CUDA graph support for
arbitrary `qo_indptr` contents, as tracked by #626.
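
To illustrate the mechanism, here is a minimal, self-contained sketch of the pattern the commit moves towards (the kernel, buffer names, and launch shape below are hypothetical and not taken from this commit): a kernel reads its effective length from device memory, the grid is sized once for the maximum length, and a captured CUDA graph is replayed with different token counts by rewriting only the device-side value.

#include <cuda_runtime.h>
#include <cstdint>

__global__ void process_tokens(const uint32_t* seq_len_ptr, float* out) {
  // The effective length is read from device memory at run time; threads that
  // fall beyond it exit, so the grid can stay sized for the maximum length.
  uint32_t seq_len = *seq_len_ptr;
  uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < seq_len) out[i] = 1.0f;
}

int main() {
  const uint32_t max_seq_len = 4096;
  uint32_t* seq_len_d;
  float* out_d;
  cudaMalloc(&seq_len_d, sizeof(uint32_t));
  cudaMalloc(&out_d, max_seq_len * sizeof(float));
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Capture a single launch whose grid covers max_seq_len.
  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  process_tokens<<<(max_seq_len + 255) / 256, 256, 0, stream>>>(seq_len_d, out_d);
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0);

  // Replay with different token counts by updating only the device-side value.
  const uint32_t lengths[] = {128, 1024, 37};
  for (uint32_t len : lengths) {
    cudaMemcpyAsync(seq_len_d, &len, sizeof(uint32_t), cudaMemcpyHostToDevice, stream);
    cudaGraphLaunch(graph_exec, stream);
  }
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(out_d);
  cudaFree(seq_len_d);
  return 0;
}
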
nandor authored Nov 23, 2024
1 parent f5842b8 commit 92ac440
Showing 4 changed files with 128 additions and 31 deletions.
40 changes: 25 additions & 15 deletions include/flashinfer/attention/cascade.cuh
@@ -325,6 +325,10 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
/*!
* \brief The CUDA kernel to merge self-attention states of multiple index sets, where the number
* of index sets at each position might vary.
*
* For CUDA graph support, the kernel can be built with a maximum sequence length and executed
* using a truncated, dynamic sequence length passed through `seq_len_ptr`.
*
* \tparam vec_size The vector size used in the kernel.
* \tparam bdx The blockDim.x used in the kernel.
* \tparam bdy The blockDim.y used in the kernel.
@@ -336,20 +340,22 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
* \param indptr The start offsets of each position in the variable length array.
* \param v_merged The merged v of index sets union. (n, h, d)
* \param s_merged The merged logsumexp value of index sets union. (n, h)
* \param max_seq_len The maximum sequence length supported by the kernel.
* \param seq_len_ptr The current sequence length (number of positions populated in indptr).
* \param num_heads The number of heads of v.
* \param head_dim The dimension of each head.
* \note s are logsumexp values with base 2.
*/
template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t num_smem_stages, typename DTypeIn,
typename DTypeO, typename IdType>
__global__ void PersistentVariableLengthMergeStatesKernel(DTypeIn* __restrict__ V,
float* __restrict__ S, IdType* indptr,
DTypeO* __restrict__ v_merged,
float* __restrict__ s_merged,
uint32_t seq_len, uint32_t num_heads) {
__global__ void PersistentVariableLengthMergeStatesKernel(
DTypeIn* __restrict__ V, float* __restrict__ S, IdType* indptr, DTypeO* __restrict__ v_merged,
float* __restrict__ s_merged, uint32_t max_seq_len, uint32_t* __restrict__ seq_len_ptr,
uint32_t num_heads) {
uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t cta_id = blockIdx.x;
uint32_t num_ctas = gridDim.x;
const uint32_t seq_len = seq_len_ptr ? *seq_len_ptr : max_seq_len;
uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas);
constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
constexpr uint32_t head_dim = vec_size * bdx;
@@ -437,10 +443,13 @@ template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t num_smem_stage
typename DTypeO, typename IdType>
__global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ V, IdType* indptr,
DTypeO* __restrict__ v_sum,
uint32_t seq_len, uint32_t num_heads) {
uint32_t max_seq_len,
uint32_t* __restrict__ seq_len_ptr,
uint32_t num_heads) {
uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t cta_id = blockIdx.x;
uint32_t num_ctas = gridDim.x;
const uint32_t seq_len = seq_len_ptr ? *seq_len_ptr : max_seq_len;
uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas);
constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
constexpr uint32_t head_dim = vec_size * bdx;
@@ -641,8 +650,9 @@ cudaError_t AttentionSum(DTypeIn* v, DTypeO* v_sum, uint32_t num_index_sets, uin

template <typename DTypeIn, typename DTypeO, typename IdType>
cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeO* v_merged,
float* s_merged, uint32_t seq_len, uint32_t num_heads,
uint32_t head_dim, cudaStream_t stream = nullptr) {
float* s_merged, uint32_t max_seq_len, uint32_t* seq_len,
uint32_t num_heads, uint32_t head_dim,
cudaStream_t stream = nullptr) {
int dev_id = 0;
int num_sms = 0;
int num_blocks_per_sm = 0;
@@ -661,11 +671,11 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp
DTypeIn, DTypeO, IdType>;
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
num_threads, smem_size));
num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms));
num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(max_seq_len * num_heads, num_sms));

dim3 nblks(num_sms * num_blocks_per_sm);
dim3 nthrs(bdx, bdy);
void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &seq_len, &num_heads};
void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &max_seq_len, &seq_len, &num_heads};
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
@@ -674,9 +684,9 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp
}

template <typename DTypeIn, typename DTypeO, typename IdType>
cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum, uint32_t seq_len,
uint32_t num_heads, uint32_t head_dim,
cudaStream_t stream = nullptr) {
cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum,
uint32_t max_seq_len, uint32_t* seq_len, uint32_t num_heads,
uint32_t head_dim, cudaStream_t stream = nullptr) {
int dev_id = 0;
int num_sms = 0;
int num_blocks_per_sm = 0;
@@ -694,11 +704,11 @@ cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum
DTypeIn, DTypeO, IdType>;
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
num_threads, smem_size));
num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms));
num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(max_seq_len * num_heads, num_sms));

dim3 nblks(num_sms * num_blocks_per_sm);
dim3 nthrs(bdx, bdy);
void* args[] = {&v, &indptr, &v_sum, &seq_len, &num_heads};
void* args[] = {&v, &indptr, &v_sum, &max_seq_len, &seq_len, &num_heads};
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
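
For reference, a hedged sketch of how a caller might drive the updated VariableLengthMergeStates entry point shown above, in the spirit of the new test added below. The wrapper name, the buffer arguments, and the assumption that the function is reachable through the flashinfer header and namespace used here are illustrative, not taken from this commit.

#include <cuda_runtime.h>
#include <cstdint>
#include <flashinfer/attention/cascade.cuh>

using namespace flashinfer;

// v, s, indptr, v_merged and s_merged are device buffers sized for max_seq_len;
// seq_len_d is a persistent device uint32_t holding the current length.
template <typename DTypeIn, typename DTypeO, typename IdType>
cudaError_t MergeBothWays(DTypeIn* v, float* s, IdType* indptr, DTypeO* v_merged,
                          float* s_merged, uint32_t cur_seq_len, uint32_t max_seq_len,
                          uint32_t* seq_len_d, uint32_t num_heads, uint32_t head_dim,
                          cudaStream_t stream) {
  // Eager path: the exact length is known on the host, so it is passed directly
  // and the device-side length pointer is left null (the pre-existing behaviour).
  FLASHINFER_CUDA_CALL(VariableLengthMergeStates(v, s, indptr, v_merged, s_merged,
                                                 cur_seq_len, /*seq_len=*/nullptr,
                                                 num_heads, head_dim, stream));
  // CUDA-graph path: the launch is sized for max_seq_len and the kernel reads the
  // effective length from seq_len_d, so a captured graph can be replayed after
  // seq_len_d has been rewritten with a smaller value.
  FLASHINFER_CUDA_CALL(VariableLengthMergeStates(v, s, indptr, v_merged, s_merged,
                                                 max_seq_len, seq_len_d,
                                                 num_heads, head_dim, stream));
  return cudaSuccess;
}
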
12 changes: 6 additions & 6 deletions include/flashinfer/attention/decode.cuh
@@ -764,12 +764,12 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(typename AttentionVariant::Par
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
if constexpr (AttentionVariant::use_softmax) {
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.o_indptr, o, lse,
params.paged_kv.batch_size, num_qo_heads,
HEAD_DIM, stream));
params.paged_kv.batch_size, nullptr,
num_qo_heads, HEAD_DIM, stream));
} else {
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, o, params.o_indptr,
params.paged_kv.batch_size, num_qo_heads,
HEAD_DIM, stream));
params.paged_kv.batch_size, nullptr,
num_qo_heads, HEAD_DIM, stream));
}
}
});
@@ -1087,8 +1087,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatchedMLA(typename AttentionVariant::
dim3 nthrs(bdx, bdy, bdz);
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.o_indptr, o, lse,
params.paged_kv.batch_size, num_qo_heads,
HEAD_DIM_CKV, stream));
params.paged_kv.batch_size, nullptr,
num_qo_heads, HEAD_DIM_CKV, stream));
}
});
return cudaSuccess;
18 changes: 10 additions & 8 deletions include/flashinfer/attention/prefill.cuh
@@ -2199,11 +2199,12 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
if constexpr (AttentionVariant::use_softmax) {
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse,
total_num_rows, num_qo_heads, HEAD_DIM,
stream));
total_num_rows, nullptr, num_qo_heads,
HEAD_DIM, stream));
} else {
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(
tmp_v, params.merge_indptr, o, total_num_rows, num_qo_heads, HEAD_DIM, stream));
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o,
total_num_rows, nullptr, num_qo_heads,
HEAD_DIM, stream));
}
}
}
@@ -2300,11 +2301,12 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
if constexpr (AttentionVariant::use_softmax) {
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse,
total_num_rows, num_qo_heads, HEAD_DIM,
stream));
total_num_rows, nullptr, num_qo_heads,
HEAD_DIM, stream));
} else {
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(
tmp_v, params.merge_indptr, o, total_num_rows, num_qo_heads, HEAD_DIM, stream));
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o,
total_num_rows, nullptr, num_qo_heads,
HEAD_DIM, stream));
}
}
}
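
Note that the decode and prefill dispatchers above pass nullptr for the new device-side length argument, so those call sites keep their previous behaviour. A minimal sketch of the selection the kernels perform, mirroring the seq_len_ptr ? *seq_len_ptr : max_seq_len reads added in cascade.cuh (the helper name is hypothetical):

#include <cstdint>

// With a null pointer the kernel covers all max_seq_len positions; with a
// non-null pointer it truncates to the value currently stored on the device.
__device__ __forceinline__ uint32_t EffectiveSeqLen(uint32_t max_seq_len,
                                                    const uint32_t* seq_len_ptr) {
  return seq_len_ptr != nullptr ? *seq_len_ptr : max_seq_len;
}
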
89 changes: 87 additions & 2 deletions src/test_cascade.cu
@@ -100,8 +100,8 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads,
thrust::raw_pointer_cast(S_ragged_device.data()),
thrust::raw_pointer_cast(indptr_device.data()),
thrust::raw_pointer_cast(V_merged_1_device.data()),
thrust::raw_pointer_cast(S_merged_1_device.data()), seq_len, num_heads,
head_dim);
thrust::raw_pointer_cast(S_merged_1_device.data()), seq_len, nullptr,
num_heads, head_dim);

thrust::host_vector<T> V_merged_0_host(V_merged_0_device), V_merged_1_host(V_merged_1_device);
thrust::host_vector<float> S_merged_0_host(S_merged_0_device), S_merged_1_host(S_merged_1_device);
@@ -133,6 +133,81 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads,
EXPECT_GT(S_result_accuracy, 0.99) << "S result correctness test failed.";
}

template <typename T>
void _TestVariableLengthMergeKernelPaddedCorrectness(size_t max_seq_len, size_t seq_len) {
ASSERT_LE(seq_len, max_seq_len);

const size_t num_heads = 4;
const size_t head_dim = 64;
const uint32_t max_num_index_sets = 512;

std::vector<int32_t> lengths(max_seq_len);
utils::vec_randint_(lengths, 1, max_num_index_sets);
std::vector<int32_t> indptr(max_seq_len + 1, 0);
for (size_t i = 0; i < seq_len; ++i) {
indptr[i + 1] = indptr[i] + lengths[i];
}

uint32_t last_indptr = indptr[seq_len];
std::vector<T> V_ragged_host(last_indptr * num_heads * head_dim);
std::vector<float> S_ragged_host(last_indptr * num_heads);

utils::vec_normal_(V_ragged_host);
utils::vec_uniform_(S_ragged_host, -10, 10);

thrust::device_vector<T> V_ragged_device(V_ragged_host);
thrust::device_vector<float> S_ragged_device(S_ragged_host);
thrust::device_vector<int32_t> indptr_device(indptr);
thrust::device_vector<T> V_merged_0_device(max_seq_len * num_heads * head_dim);
thrust::device_vector<T> V_merged_1_device(max_seq_len * num_heads * head_dim);
thrust::device_vector<float> S_merged_0_device(max_seq_len * num_heads);
thrust::device_vector<float> S_merged_1_device(max_seq_len * num_heads);
thrust::device_vector<uint32_t> seq_len_device(
std::vector<uint32_t>{static_cast<uint32_t>(seq_len)});

// Reference: use VariableLengthMergeStates on the precisely-sized input.
VariableLengthMergeStates(thrust::raw_pointer_cast(V_ragged_device.data()),
thrust::raw_pointer_cast(S_ragged_device.data()),
thrust::raw_pointer_cast(indptr_device.data()),
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, nullptr,
num_heads, head_dim);
// Expected: use VariableLengthMergeStates on a padded input
VariableLengthMergeStates(thrust::raw_pointer_cast(V_ragged_device.data()),
thrust::raw_pointer_cast(S_ragged_device.data()),
thrust::raw_pointer_cast(indptr_device.data()),
thrust::raw_pointer_cast(V_merged_1_device.data()),
thrust::raw_pointer_cast(S_merged_1_device.data()), max_seq_len,
thrust::raw_pointer_cast(seq_len_device.data()), num_heads, head_dim);

thrust::host_vector<T> V_merged_0_host(V_merged_0_device), V_merged_1_host(V_merged_1_device);
thrust::host_vector<float> S_merged_0_host(S_merged_0_device), S_merged_1_host(S_merged_1_device);

// Compare results
size_t num_V_result_errors_atol_1e_3_rtol_1e_3 = 0, num_S_result_errors_atol_1e_3_rtol_1e_3 = 0;
for (size_t i = 0; i < seq_len * num_heads * head_dim; ++i) {
EXPECT_FALSE(std::isnan(float(V_merged_1_host[i]))) << "V_merged_1_host[" << i << "] is nan";
num_V_result_errors_atol_1e_3_rtol_1e_3 +=
(!utils::isclose(float(V_merged_0_host[i]), float(V_merged_1_host[i]), 1e-3, 1e-3));
}
for (size_t i = 0; i < seq_len * num_heads; ++i) {
EXPECT_FALSE(std::isnan(float(S_merged_0_host[i]))) << "S_merged_0_host[" << i << "] is nan";
EXPECT_FALSE(std::isnan(float(S_merged_1_host[i]))) << "S_merged_1_host[" << i << "] is nan";
num_S_result_errors_atol_1e_3_rtol_1e_3 +=
(!utils::isclose(float(S_merged_0_host[i]), float(S_merged_1_host[i]), 1e-3, 1e-3));
}
float V_result_accuracy =
1.0 - float(num_V_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads * head_dim);
float S_result_accuracy =
1.0 - float(num_S_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads);
std::cout << "seq_len=" << seq_len << ", num_heads=" << num_heads << ", head_dim=" << head_dim
<< ", V accuracy (atol=1e-3, rtol=1e-3)=" << V_result_accuracy
<< ", S accuracy (atol=1e-3, rtol=1e-3)=" << S_result_accuracy << std::endl;

EXPECT_GT(V_result_accuracy, 0.99) << "V result correctness test failed.";
EXPECT_GT(S_result_accuracy, 0.99) << "S result correctness test failed.";
}

template <typename T>
void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t num_heads,
size_t head_dim, bool sparse_s) {
@@ -515,6 +590,12 @@ void TestVariableLengthMergeKernelCorrectness() {
}
}

template <typename T>
void TestVariableLengthMergeKernelPaddedCorrectness() {
_TestVariableLengthMergeKernelPaddedCorrectness<T>(8, 1);
_TestVariableLengthMergeKernelPaddedCorrectness<T>(128, 77);
}

template <typename T>
void TestTwoLevelSinglePrefixCascadeDecodeCorrectness() {
for (size_t batch_size : {1, 8, 16, 64, 128}) {
@@ -563,6 +644,10 @@ TEST(FlashInferCorrectnessTest, VariableLengthMergeKernelCorrectnessTestFP16) {
TestVariableLengthMergeKernelCorrectness<half>();
}

TEST(FlashInferCorrectnessTest, VariableLengthMergeKernelPaddedCorrectnessTestFP16) {
TestVariableLengthMergeKernelPaddedCorrectness<half>();
}

TEST(FlashInferCorrectnessTest, TwoLevelSinglePrefixCascadeDecodeTestFP16) {
TestTwoLevelSinglePrefixCascadeDecodeCorrectness<half>();
}
