Skip to content

Commit

Permalink
bugfix: fix the decode kernel segfault in cudagraph mode (#368)
Browse files Browse the repository at this point in the history
The `begin_forward` function in the decode attention wrappers sometimes
triggers a segfault; this PR fixes the issue.
  • Loading branch information
yzh119 authored Jul 11, 2024
1 parent 4f0a9f9 commit c69cfab
Showing 1 changed file with 17 additions and 21 deletions.
38 changes: 17 additions & 21 deletions include/flashinfer/attention/handler.cuh
Original file line number Diff line number Diff line change
@@ -193,20 +193,6 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
return cudaSuccess;
}

/*!
 * \brief A lightweight wrapper that exposes a vector-like interface over pre-allocated memory
 * \tparam T The element type stored in the wrapped buffer
 * \note The wrapper does not own the memory and does not track its capacity; the caller must
 *   make sure the buffer is large enough for every element that is pushed.
 * \note The wrapped pointer may be null (e.g. an optional output buffer). In that case
 *   push_back only advances the element count and stores nothing; operator[] and back()
 *   must not be used on a null-backed wrapper.
 */
template <typename T>
struct vec_from_ptr {
  T* data;      // non-owning pointer to the first element (may be null)
  size_t size;  // number of elements pushed so far

  /*! \brief Wrap an existing buffer; the wrapper starts out logically empty. */
  vec_from_ptr(T* data) : data(data), size(0) {}

  /*! \brief Read the element at idx (returned by value, no bounds check). */
  T operator[](size_t idx) const { return data[idx]; }

  /*! \brief Mutable access to the element at idx (no bounds check). */
  T& operator[](size_t idx) { return data[idx]; }

  /*! \brief Return the most recently pushed element; requires size > 0. */
  T back() const { return data[size - 1]; }

  /*!
   * \brief Append val at position size and advance size by one.
   * \note When the wrapped pointer is null the value is dropped instead of being written
   *   through the null pointer; size is still advanced so element counts stay consistent.
   */
  void push_back(T val) {
    if (data != nullptr) {
      data[size] = val;
    }
    ++size;
  }
};

/*!
* \brief Partition Paged KV-Cache into multiple chunks on KV sequence length
* \tparam IdType A template type indicates the index data type
@@ -226,11 +212,9 @@ cudaError_t PartitionPagedKVCacheComputeAuxiliaryInfo(
IdType* chunk_indptr_h, IdType* batch_idx_map_h, IdType* chunk_start_pos_h,
IdType* seq_lens_before_partition_h, bool* block_valid_mask_h, void* device_buffer,
void* host_buffer, size_t num_bytes_to_copy, cudaStream_t stream = nullptr) {
vec_from_ptr<IdType> new_page_indptr_vec(new_page_indptr_h),
new_last_page_len_vec(new_last_page_len_h), chunk_indptr_vec(chunk_indptr_h),
batch_idx_map_vec(batch_idx_map_h), chunk_start_pos_vec(chunk_start_pos_h),
seq_lens_before_partition_vec(seq_lens_before_partition_h);
vec_from_ptr<bool> block_valid_mask_vec(block_valid_mask_h);
std::vector<IdType> new_page_indptr_vec, new_last_page_len_vec, chunk_indptr_vec,
batch_idx_map_vec, chunk_start_pos_vec, seq_lens_before_partition_vec;
std::vector<bool> block_valid_mask_vec;

new_page_indptr_vec.push_back(0);
chunk_indptr_vec.push_back(0);
@@ -267,8 +251,20 @@ cudaError_t PartitionPagedKVCacheComputeAuxiliaryInfo(
}
}
}
std::fill(new_page_indptr_h + new_page_indptr_vec.size, new_page_indptr_h + padded_batch_size + 1,
new_page_indptr_vec.back());
IdType last_page_indptr = new_page_indptr_vec.back();
while (new_page_indptr_vec.size() < padded_batch_size + 1) {
new_page_indptr_vec.push_back(last_page_indptr);
}
std::copy(new_page_indptr_vec.begin(), new_page_indptr_vec.end(), new_page_indptr_h);
std::copy(new_last_page_len_vec.begin(), new_last_page_len_vec.end(), new_last_page_len_h);
std::copy(chunk_indptr_vec.begin(), chunk_indptr_vec.end(), chunk_indptr_h);
std::copy(batch_idx_map_vec.begin(), batch_idx_map_vec.end(), batch_idx_map_h);
std::copy(chunk_start_pos_vec.begin(), chunk_start_pos_vec.end(), chunk_start_pos_h);
std::copy(seq_lens_before_partition_vec.begin(), seq_lens_before_partition_vec.end(),
seq_lens_before_partition_h);
if (block_valid_mask_h != nullptr) {
std::copy(block_valid_mask_vec.begin(), block_valid_mask_vec.end(), block_valid_mask_h);
}

FLASHINFER_CUDA_CALL(cudaMemcpyAsync(device_buffer, host_buffer, num_bytes_to_copy,
cudaMemcpyHostToDevice, stream));

0 comments on commit c69cfab

Please sign in to comment.