Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow the cascade kernels to be executed using varying sequence lenghts #627

Merged
merged 1 commit into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions include/flashinfer/attention/cascade.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,10 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
/*!
* \brief The CUDA kernel to merge self-attention states of multiple index sets, the number of index
* sets at each position might vary.
*
* For CUDA graph support, the kernel can be built with a maximum sequence length and executed
* using a truncated, dynamic sequence length passed through `seq_len_ptr`.
*
* \tparam vec_size The vector size used in the kernel.
* \tparam bdx The blockDim.x used in the kernel.
* \tparam bdy The blockDim.y used in the kernel.
Expand All @@ -336,20 +340,22 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
* \param indptr The start offsets of each position in the variable length array.
* \param v_merged The merged v of index sets union. (n, h, d)
* \param s_merged The merged logsumexp value of index sets union. (n, h)
* \param max_seq_len The maximum sequence length supported by the kernel.
* \param seq_len_ptr The current sequence length (number of positions populated in indptr).
* \param num_heads The number of heads of v.
* \param head_dim The dimension of each head.
* \note s are logsumexp values with base 2.
*/
template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t num_smem_stages, typename DTypeIn,
typename DTypeO, typename IdType>
__global__ void PersistentVariableLengthMergeStatesKernel(DTypeIn* __restrict__ V,
float* __restrict__ S, IdType* indptr,
DTypeO* __restrict__ v_merged,
float* __restrict__ s_merged,
uint32_t seq_len, uint32_t num_heads) {
__global__ void PersistentVariableLengthMergeStatesKernel(
DTypeIn* __restrict__ V, float* __restrict__ S, IdType* indptr, DTypeO* __restrict__ v_merged,
float* __restrict__ s_merged, uint32_t max_seq_len, uint32_t* __restrict__ seq_len_ptr,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another option is to make seq_len a cuda array with length 1 and always read seq_len[0]'s value inside kernels, but currently I think having another max_seq_len argument is okay.

uint32_t num_heads) {
uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t cta_id = blockIdx.x;
uint32_t num_ctas = gridDim.x;
const uint32_t seq_len = seq_len_ptr ? *seq_len_ptr : max_seq_len;
uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas);
constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
constexpr uint32_t head_dim = vec_size * bdx;
Expand Down Expand Up @@ -437,10 +443,13 @@ template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t num_smem_stage
typename DTypeO, typename IdType>
__global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ V, IdType* indptr,
DTypeO* __restrict__ v_sum,
uint32_t seq_len, uint32_t num_heads) {
uint32_t max_seq_len,
uint32_t* __restrict__ seq_len_ptr,
uint32_t num_heads) {
uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t cta_id = blockIdx.x;
uint32_t num_ctas = gridDim.x;
const uint32_t seq_len = seq_len_ptr ? *seq_len_ptr : max_seq_len;
uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas);
constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
constexpr uint32_t head_dim = vec_size * bdx;
Expand Down Expand Up @@ -641,8 +650,9 @@ cudaError_t AttentionSum(DTypeIn* v, DTypeO* v_sum, uint32_t num_index_sets, uin

template <typename DTypeIn, typename DTypeO, typename IdType>
cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeO* v_merged,
float* s_merged, uint32_t seq_len, uint32_t num_heads,
uint32_t head_dim, cudaStream_t stream = nullptr) {
float* s_merged, uint32_t max_seq_len, uint32_t* seq_len,
uint32_t num_heads, uint32_t head_dim,
cudaStream_t stream = nullptr) {
int dev_id = 0;
int num_sms = 0;
int num_blocks_per_sm = 0;
Expand All @@ -661,11 +671,11 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp
DTypeIn, DTypeO, IdType>;
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
num_threads, smem_size));
num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms));
num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(max_seq_len * num_heads, num_sms));

dim3 nblks(num_sms * num_blocks_per_sm);
dim3 nthrs(bdx, bdy);
void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &seq_len, &num_heads};
void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &max_seq_len, &seq_len, &num_heads};
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
Expand All @@ -674,9 +684,9 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp
}

template <typename DTypeIn, typename DTypeO, typename IdType>
cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum, uint32_t seq_len,
uint32_t num_heads, uint32_t head_dim,
cudaStream_t stream = nullptr) {
cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum,
uint32_t max_seq_len, uint32_t* seq_len, uint32_t num_heads,
uint32_t head_dim, cudaStream_t stream = nullptr) {
int dev_id = 0;
int num_sms = 0;
int num_blocks_per_sm = 0;
Expand All @@ -694,11 +704,11 @@ cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum
DTypeIn, DTypeO, IdType>;
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
num_threads, smem_size));
num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms));
num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(max_seq_len * num_heads, num_sms));

dim3 nblks(num_sms * num_blocks_per_sm);
dim3 nthrs(bdx, bdy);
void* args[] = {&v, &indptr, &v_sum, &seq_len, &num_heads};
void* args[] = {&v, &indptr, &v_sum, &max_seq_len, &seq_len, &num_heads};
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
Expand Down
12 changes: 6 additions & 6 deletions include/flashinfer/attention/decode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -764,12 +764,12 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(typename AttentionVariant::Par
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
if constexpr (AttentionVariant::use_softmax) {
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.o_indptr, o, lse,
params.paged_kv.batch_size, num_qo_heads,
HEAD_DIM, stream));
params.paged_kv.batch_size, nullptr,
num_qo_heads, HEAD_DIM, stream));
} else {
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, o, params.o_indptr,
params.paged_kv.batch_size, num_qo_heads,
HEAD_DIM, stream));
params.paged_kv.batch_size, nullptr,
num_qo_heads, HEAD_DIM, stream));
}
}
});
Expand Down Expand Up @@ -1087,8 +1087,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatchedMLA(typename AttentionVariant::
dim3 nthrs(bdx, bdy, bdz);
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.o_indptr, o, lse,
params.paged_kv.batch_size, num_qo_heads,
HEAD_DIM_CKV, stream));
params.paged_kv.batch_size, nullptr,
num_qo_heads, HEAD_DIM_CKV, stream));
}
});
return cudaSuccess;
Expand Down
18 changes: 10 additions & 8 deletions include/flashinfer/attention/prefill.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2199,11 +2199,12 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
if constexpr (AttentionVariant::use_softmax) {
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse,
total_num_rows, num_qo_heads, HEAD_DIM,
stream));
total_num_rows, nullptr, num_qo_heads,
HEAD_DIM, stream));
} else {
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(
tmp_v, params.merge_indptr, o, total_num_rows, num_qo_heads, HEAD_DIM, stream));
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o,
total_num_rows, nullptr, num_qo_heads,
HEAD_DIM, stream));
}
}
}
Expand Down Expand Up @@ -2300,11 +2301,12 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
if constexpr (AttentionVariant::use_softmax) {
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse,
total_num_rows, num_qo_heads, HEAD_DIM,
stream));
total_num_rows, nullptr, num_qo_heads,
HEAD_DIM, stream));
} else {
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(
tmp_v, params.merge_indptr, o, total_num_rows, num_qo_heads, HEAD_DIM, stream));
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o,
total_num_rows, nullptr, num_qo_heads,
HEAD_DIM, stream));
}
}
}
Expand Down
89 changes: 87 additions & 2 deletions src/test_cascade.cu
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads,
thrust::raw_pointer_cast(S_ragged_device.data()),
thrust::raw_pointer_cast(indptr_device.data()),
thrust::raw_pointer_cast(V_merged_1_device.data()),
thrust::raw_pointer_cast(S_merged_1_device.data()), seq_len, num_heads,
head_dim);
thrust::raw_pointer_cast(S_merged_1_device.data()), seq_len, nullptr,
num_heads, head_dim);

thrust::host_vector<T> V_merged_0_host(V_merged_0_device), V_merged_1_host(V_merged_1_device);
thrust::host_vector<float> S_merged_0_host(S_merged_0_device), S_merged_1_host(S_merged_1_device);
Expand Down Expand Up @@ -133,6 +133,81 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads,
EXPECT_GT(S_result_accuracy, 0.99) << "S result correctness test failed.";
}

template <typename T>
void _TestVariableLengthMergeKernelPaddedCorrectness(size_t max_seq_len, size_t seq_len) {
ASSERT_LE(seq_len, max_seq_len);

const size_t num_heads = 4;
const size_t head_dim = 64;
const uint32_t max_num_index_sets = 512;

std::vector<int32_t> lengths(max_seq_len);
utils::vec_randint_(lengths, 1, max_num_index_sets);
std::vector<int32_t> indptr(max_seq_len + 1, 0);
for (size_t i = 0; i < seq_len; ++i) {
indptr[i + 1] = indptr[i] + lengths[i];
}

uint32_t last_indptr = indptr[seq_len];
std::vector<T> V_ragged_host(last_indptr * num_heads * head_dim);
std::vector<float> S_ragged_host(last_indptr * num_heads);

utils::vec_normal_(V_ragged_host);
utils::vec_uniform_(S_ragged_host, -10, 10);

thrust::device_vector<T> V_ragged_device(V_ragged_host);
thrust::device_vector<float> S_ragged_device(S_ragged_host);
thrust::device_vector<int32_t> indptr_device(indptr);
thrust::device_vector<T> V_merged_0_device(max_seq_len * num_heads * head_dim);
thrust::device_vector<T> V_merged_1_device(max_seq_len * num_heads * head_dim);
thrust::device_vector<float> S_merged_0_device(max_seq_len * num_heads);
thrust::device_vector<float> S_merged_1_device(max_seq_len * num_heads);
thrust::device_vector<uint32_t> seq_len_device(
std::vector<uint32_t>{static_cast<uint32_t>(seq_len)});

// Reference: use VariableLengthMergeStates on the precisely-sized input.
VariableLengthMergeStates(thrust::raw_pointer_cast(V_ragged_device.data()),
thrust::raw_pointer_cast(S_ragged_device.data()),
thrust::raw_pointer_cast(indptr_device.data()),
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, nullptr,
num_heads, head_dim);
// Expected: use VariableLengthMergeStates on a padded input
VariableLengthMergeStates(thrust::raw_pointer_cast(V_ragged_device.data()),
thrust::raw_pointer_cast(S_ragged_device.data()),
thrust::raw_pointer_cast(indptr_device.data()),
thrust::raw_pointer_cast(V_merged_1_device.data()),
thrust::raw_pointer_cast(S_merged_1_device.data()), max_seq_len,
thrust::raw_pointer_cast(seq_len_device.data()), num_heads, head_dim);

thrust::host_vector<T> V_merged_0_host(V_merged_0_device), V_merged_1_host(V_merged_1_device);
thrust::host_vector<float> S_merged_0_host(S_merged_0_device), S_merged_1_host(S_merged_1_device);

// Compare results
size_t num_V_result_errors_atol_1e_3_rtol_1e_3 = 0, num_S_result_errors_atol_1e_3_rtol_1e_3 = 0;
for (size_t i = 0; i < seq_len * num_heads * head_dim; ++i) {
EXPECT_FALSE(std::isnan(float(V_merged_1_host[i]))) << "V_merged_1_host[" << i << "] is nan";
num_V_result_errors_atol_1e_3_rtol_1e_3 +=
(!utils::isclose(float(V_merged_0_host[i]), float(V_merged_1_host[i]), 1e-3, 1e-3));
}
for (size_t i = 0; i < seq_len * num_heads; ++i) {
EXPECT_FALSE(std::isnan(float(S_merged_0_host[i]))) << "S_merged_0_host[" << i << "] is nan";
EXPECT_FALSE(std::isnan(float(S_merged_1_host[i]))) << "S_merged_1_host[" << i << "] is nan";
num_S_result_errors_atol_1e_3_rtol_1e_3 +=
(!utils::isclose(float(S_merged_0_host[i]), float(S_merged_1_host[i]), 1e-3, 1e-3));
}
float V_result_accuracy =
1.0 - float(num_V_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads * head_dim);
float S_result_accuracy =
1.0 - float(num_S_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads);
std::cout << "seq_len=" << seq_len << ", num_heads=" << num_heads << ", head_dim=" << head_dim
<< ", V accuracy (atol=1e-3, rtol=1e-3)=" << V_result_accuracy
<< ", S accuracy (atol=1e-3, rtol=1e-3)=" << S_result_accuracy << std::endl;

EXPECT_GT(V_result_accuracy, 0.99) << "V result correctness test failed.";
EXPECT_GT(S_result_accuracy, 0.99) << "S result correctness test failed.";
}

template <typename T>
void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t num_heads,
size_t head_dim, bool sparse_s) {
Expand Down Expand Up @@ -515,6 +590,12 @@ void TestVariableLengthMergeKernelCorrectness() {
}
}

template <typename T>
void TestVariableLengthMergeKernelPaddedCorrectness() {
_TestVariableLengthMergeKernelPaddedCorrectness<T>(8, 1);
_TestVariableLengthMergeKernelPaddedCorrectness<T>(128, 77);
}

template <typename T>
void TestTwoLevelSinglePrefixCascadeDecodeCorrectness() {
for (size_t batch_size : {1, 8, 16, 64, 128}) {
Expand Down Expand Up @@ -563,6 +644,10 @@ TEST(FlashInferCorrectnessTest, VariableLengthMergeKernelCorrectnessTestFP16) {
TestVariableLengthMergeKernelCorrectness<half>();
}

TEST(FlashInferCorrectnessTest, VariableLengthMergeKernelPaddedCorrectnessTestFP16) {
TestVariableLengthMergeKernelPaddedCorrectness<half>();
}

TEST(FlashInferCorrectnessTest, TwoLevelSinglePrefixCascadeDecodeTestFP16) {
TestTwoLevelSinglePrefixCascadeDecodeCorrectness<half>();
}
Expand Down