From c69cfabc540e4a7edd991713df10d575ff3b0c21 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 10 Jul 2024 23:16:48 -0700 Subject: [PATCH] bugfix: fix the decode kernel segfault in cudagraph mode (#368) The `begin_forward` function in decode attention wrappers sometimes triggers segfault, this PR fixes the issue. --- include/flashinfer/attention/handler.cuh | 38 +++++++++++------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index 632ccb85..6b2b72ee 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -193,20 +193,6 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( return cudaSuccess; } -/*! - * \brief A lightweight wrapper to create a vector from pointer to pre-allocated memory - */ -template -struct vec_from_ptr { - T* data; - size_t size; - vec_from_ptr(T* data) : data(data), size(0) {} - T operator[](size_t idx) const { return data[idx]; } - T& operator[](size_t idx) { return data[idx]; } - T back() const { return data[size - 1]; } - void push_back(T val) { data[size++] = val; } -}; - /*! * \brief Partition Paged KV-Cache into multiple chunks on KV sequence length * \tparam IdType A template type indicates the index data type @@ -226,11 +212,9 @@ cudaError_t PartitionPagedKVCacheComputeAuxiliaryInfo( IdType* chunk_indptr_h, IdType* batch_idx_map_h, IdType* chunk_start_pos_h, IdType* seq_lens_before_partition_h, bool* block_valid_mask_h, void* device_buffer, void* host_buffer, size_t num_bytes_to_copy, cudaStream_t stream = nullptr) { - vec_from_ptr new_page_indptr_vec(new_page_indptr_h), - new_last_page_len_vec(new_last_page_len_h), chunk_indptr_vec(chunk_indptr_h), - batch_idx_map_vec(batch_idx_map_h), chunk_start_pos_vec(chunk_start_pos_h), - seq_lens_before_partition_vec(seq_lens_before_partition_h); - vec_from_ptr block_valid_mask_vec(block_valid_mask_h); + std::vector new_page_indptr_vec, new_last_page_len_vec, chunk_indptr_vec, + batch_idx_map_vec, chunk_start_pos_vec, seq_lens_before_partition_vec; + std::vector block_valid_mask_vec; new_page_indptr_vec.push_back(0); chunk_indptr_vec.push_back(0); @@ -267,8 +251,20 @@ cudaError_t PartitionPagedKVCacheComputeAuxiliaryInfo( } } } - std::fill(new_page_indptr_h + new_page_indptr_vec.size, new_page_indptr_h + padded_batch_size + 1, - new_page_indptr_vec.back()); + IdType last_page_indptr = new_page_indptr_vec.back(); + while (new_page_indptr_vec.size() < padded_batch_size + 1) { + new_page_indptr_vec.push_back(last_page_indptr); + } + std::copy(new_page_indptr_vec.begin(), new_page_indptr_vec.end(), new_page_indptr_h); + std::copy(new_last_page_len_vec.begin(), new_last_page_len_vec.end(), new_last_page_len_h); + std::copy(chunk_indptr_vec.begin(), chunk_indptr_vec.end(), chunk_indptr_h); + std::copy(batch_idx_map_vec.begin(), batch_idx_map_vec.end(), batch_idx_map_h); + std::copy(chunk_start_pos_vec.begin(), chunk_start_pos_vec.end(), chunk_start_pos_h); + std::copy(seq_lens_before_partition_vec.begin(), seq_lens_before_partition_vec.end(), + seq_lens_before_partition_h); + if (block_valid_mask_h != nullptr) { + std::copy(block_valid_mask_vec.begin(), block_valid_mask_vec.end(), block_valid_mask_h); + } FLASHINFER_CUDA_CALL(cudaMemcpyAsync(device_buffer, host_buffer, num_bytes_to_copy, cudaMemcpyHostToDevice, stream));