From 01a3c70d508dd601b783ae0d9b5d20a62e8d05bd Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 27 May 2024 23:50:33 +0000 Subject: [PATCH 01/14] wip --- CMakeLists.txt | 24 ++--- include/flashinfer/attention/prefill.cuh | 93 +++++++++++-------- include/flashinfer/prefill_attention_decl.cuh | 14 +-- include/flashinfer/utils.cuh | 29 ++++-- python/csrc/batch_prefill.cu | 15 +-- python/csrc/flashinfer_ops.h | 6 +- python/csrc/pytorch_extension_utils.h | 4 +- python/csrc/single_prefill.cu | 8 +- python/flashinfer/utils.py | 5 + python/generate_batch_paged_prefill_inst.py | 9 +- python/generate_batch_ragged_prefill_inst.py | 9 +- python/generate_dispatch_inc.py | 20 ++-- python/generate_single_prefill_inst.py | 10 +- python/literal_map.py | 6 ++ python/setup.py | 28 +++--- 15 files changed, 165 insertions(+), 115 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 38fcd7b7..fd50b4da 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -81,7 +81,7 @@ set (HEAD_DIMS ${FLASHINFER_GEN_HEAD_DIMS}) set (KV_LAYOUTS ${FLASHINFER_GEN_KV_LAYOUTS}) set (POS_ENCODING_MODES ${FLASHINFER_GEN_POS_ENCODING_MODES}) set (ALLOW_FP16_QK_REDUCTIONS ${FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS}) -set (CAUSALS ${FLASHINFER_GEN_CASUALS}) +set (MASK_MODES ${FLASHINFER_GEN_MASK_MODES}) set (DECODE_DTYPES "f16") set (PREFILL_DTYPES "f16") set (DECODE_F8_DTYPES) @@ -104,14 +104,14 @@ message(STATUS "FLASHINFER_HEAD_DIMS=${HEAD_DIMS}") message(STATUS "FLASHINFER_KV_LAYOUTS=${KV_LAYOUTS}") message(STATUS "FLASHINFER_POS_ENCODING_MODES=${POS_ENCODING_MODES}") message(STATUS "FLASHINFER_ALLOW_FP16_QK_REDUCTIONS=${ALLOW_FP16_QK_REDUCTIONS}") -message(STATUS "FLASHINFER_CAUSALS=${CAUSALS}") +message(STATUS "FLASHINFER_MASK_MODES=${MASK_MODES}") file(MAKE_DIRECTORY ${PROJECT_SOURCE_DIR}/src/generated) set(dispatch_inc_file ${PROJECT_SOURCE_DIR}/src/dispatch.inc) add_custom_command( OUTPUT ${dispatch_inc_file} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --page_sizes ${FLASHINFER_GEN_PAGE_SIZES} --group_sizes ${GROUP_SIZES} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --causals ${CAUSALS} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --page_sizes ${FLASHINFER_GEN_PAGE_SIZES} --group_sizes ${GROUP_SIZES} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES} DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py COMMENT "Generating additional source file ${generated_dispatch_inc}" VERBATIM @@ -225,9 +225,9 @@ foreach(group_size IN LISTS GROUP_SIZES) foreach(kv_layout IN LISTS KV_LAYOUTS) foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) - foreach(causal IN LISTS CAUSALS) + foreach(mask_mode IN LISTS MASK_MODES) foreach(dtype IN LISTS PREFILL_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_causal_${causal}_dtypein_${dtype}_dtypeout_${dtype}.cu) + set(generated_kernel_src 
${PROJECT_SOURCE_DIR}/src/generated/single_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src} @@ -237,7 +237,7 @@ foreach(group_size IN LISTS GROUP_SIZES) ) list(APPEND single_prefill_kernels_src ${generated_kernel_src}) endforeach(dtype) - endforeach(causal) + endforeach(mask_mode) endforeach(allow_fp16_qk_reduction) endforeach(pos_encoding_mode) endforeach(kv_layout) @@ -251,10 +251,10 @@ foreach(group_size IN LISTS GROUP_SIZES) foreach(kv_layout IN LISTS KV_LAYOUTS) foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) - foreach(causal IN LISTS CAUSALS) + foreach(mask_mode IN LISTS MASK_MODES) foreach(dtype IN LISTS PREFILL_DTYPES) foreach(idtype IN LISTS IDTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_group_${group_size}_page_${page_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_causal_${causal}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_group_${group_size}_page_${page_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src} @@ -265,7 +265,7 @@ foreach(group_size IN LISTS GROUP_SIZES) list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src}) endforeach(idtype) endforeach(dtype) - endforeach(causal) + endforeach(mask_mode) endforeach(allow_fp16_qk_reduction) endforeach(pos_encoding_mode) endforeach(kv_layout) @@ -279,10 +279,10 @@ foreach(group_size IN LISTS GROUP_SIZES) foreach(kv_layout IN LISTS KV_LAYOUTS) foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) - foreach(causal IN LISTS CAUSALS) + foreach(mask_mode IN LISTS MASK_MODES) foreach(dtype IN LISTS PREFILL_DTYPES) foreach(idtype IN LISTS IDTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_causal_${causal}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src} @@ -293,7 +293,7 @@ foreach(group_size IN LISTS GROUP_SIZES) list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src}) endforeach(idtype) endforeach(dtype) - endforeach(causal) + endforeach(mask_mode) endforeach(allow_fp16_qk_reduction) endforeach(pos_encoding_mode) endforeach(kv_layout) diff --git 
a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 90edae5e..82dc368f 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -44,6 +44,12 @@ namespace cg = cooperative_groups; using cp_async::SharedMemFillMode; using mma::MMAMode; +enum class MaskMode { + kNone = 0U, // No mask + kCausal = 1U, // Causal mask + kCustom = 2U, // Custom mask +}; + constexpr uint32_t warp_size = 32; namespace { @@ -550,11 +556,12 @@ __device__ __forceinline__ void apply_alibi_bias(const uint32_t qo_idx_base, } } -template __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, const uint32_t kv_len, const uint32_t chunk_end, + float *custom_mask, DTypeQKAccum (*s_frag)[num_frags_z][8]) { const uint32_t tx = threadIdx.x; #pragma unroll @@ -568,9 +575,10 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, const uint32_ kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + 8 * (reg_id / 4) + reg_id % 2; const bool out_of_boundary = - (causal ? (kv_idx > kv_len + q_idx - qo_len || (partition_kv && kv_idx >= chunk_end)) + (mask_mode == MaskMode::kCausal ? (kv_idx > kv_len + q_idx - qo_len || (partition_kv && kv_idx >= chunk_end)) : kv_idx >= chunk_end); - s_frag[fx][fz][reg_id] = out_of_boundary ? DTypeQKAccum(-5e4) : s_frag[fx][fz][reg_id]; + s_frag[fx][fz][reg_id] = out_of_boundary ? DTypeQKAccum(-5e4) : s_frag[fx][fz][reg_id] + + DTypeQKAccum(mask_mode == MaskMode::kCustom ? custom_mask[q_idx * kv_len + kv_idx]: 0.f); } } } @@ -870,7 +878,7 @@ __device__ __forceinline__ void write_o_reg_gmem(float (*o_frag)[num_frags_y][8] * \brief FlashAttention prefill CUDA kernel for a single request. * \tparam partition_kv Whether to split kv_len into chunks. * \tparam group_size The number of qo heads that maps to a kv head (used in GQA). - * \tparam causal Whether to use causal attention. + * \tparam mask_mode The mask mode used in the attention operation. * \tparam kv_layout The layout of the input tensor. * \tparam pos_encoding_mode The positional encoding mode. * \tparam num_frags_x The number of fragments in x dimension. @@ -892,7 +900,7 @@ __device__ __forceinline__ void write_o_reg_gmem(float (*o_frag)[num_frags_y][8] * \param log2_rope_rcp_theta log2(1/(rope_theta)), where rope_theta is the theta * used in RoPE. */ -template @@ -983,7 +991,7 @@ __global__ void SinglePrefillWithKVCacheKernel( v_smem(smem + (num_warps * num_frags_x + num_frags_z) * 16 * head_dim * sizeof(DTypeIn)); const uint32_t num_iterations = ceil_div( - causal ? min(chunk_end - chunk_start, + mask_mode == MaskMode::kCausal ? min(chunk_end - chunk_start, sub_if_greater_or_zero( kv_len - qo_len + ((bx + 1) * num_frags_x * num_warps * 16) / group_size, chunk_start)) @@ -991,7 +999,7 @@ __global__ void SinglePrefillWithKVCacheKernel( 16 * num_frags_z); const uint32_t mask_iteration = - (causal ? min(chunk_end - chunk_start, + (mask_mode == MaskMode::kCausal ? 
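// For MaskMode::kCausal a query at position q_idx can only attend to kv_idx <= kv_len + q_idx - qo_len
// (the condition enforced in mask_s above), so num_iterations stops at the last KV tile visible to the
// final query row of this CTA, and mask_iteration below counts the leading 16 * num_frags_z tiles that
// every query row can fully attend to, letting the mask_s call be skipped for them. This single-request
// kernel passes nullptr for the custom mask; here only the batch ragged kernel consumes MaskMode::kCustom.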
min(chunk_end - chunk_start, sub_if_greater_or_zero( kv_len + (bx * num_warps * num_frags_x * 16) / group_size - qo_len, chunk_start)) @@ -1036,8 +1044,8 @@ __global__ void SinglePrefillWithKVCacheKernel( } // apply mask if (iter >= mask_iteration) { - mask_s( - qo_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, s_frag); + mask_s( + qo_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, nullptr, s_frag); } // compute m,d states in online softmax @@ -1097,13 +1105,15 @@ __global__ void SinglePrefillWithKVCacheKernel( } } -template __global__ void BatchPrefillWithRaggedKVCacheKernel( DTypeIn* __restrict__ q, IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices, IdType* __restrict__ qo_indptr, DTypeIn* __restrict__ k, - DTypeIn* __restrict__ v, IdType* __restrict__ kv_indptr, IdType* __restrict__ q_offset, + DTypeIn* __restrict__ v, IdType* __restrict__ kv_indptr, + float* custom_mask, IdType* qk_indptr, + IdType* __restrict__ q_offset, IdType* __restrict__ k_rope_pos_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp, float* __restrict__ lse, uint32_t batch_size, float sm_scale, float log2_rope_rcp_scale, float log2_rope_rcp_theta) { @@ -1194,13 +1204,13 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( } const uint32_t num_iterations = ceil_div( - (causal ? min(kv_len, + (mask_mode == MaskMode::kCausal ? min(kv_len, kv_len - qo_len + ((tile_idx + 1) * num_frags_x * num_warps * 16) / group_size) : kv_len), 16 * num_frags_z); const uint32_t mask_iteration = - (causal + (mask_mode == MaskMode::kCausal ? min(kv_len + (tile_idx * num_warps * num_frags_x * 16) / group_size - qo_len, kv_len) : kv_len) / (16 * num_frags_z); @@ -1251,9 +1261,14 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( qo_idx_base, iter * 16 * num_frags_z, int(kv_len) - int(qo_len), alibi_slopes, s_frag); } // apply mask - if (iter >= mask_iteration) { - mask_s( - qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, s_frag); + if constexpr (mask_mode == MaskMode::kCustom) { + mask_s( + qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, custom_mask + qk_indptr[request_idx], s_frag); + } else { + if (iter >= mask_iteration) { + mask_s( + qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, nullptr, s_frag); + } } // compute m,d states in online softmax @@ -1304,7 +1319,7 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( } } -template @@ -1420,14 +1435,14 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( cp_async::commit_group(); const uint32_t num_iterations = ceil_div( - (causal + (mask_mode == MaskMode::kCausal ? min(kv_len, kv_len - qo_len + ((tile_idx + 1) * num_frags_x * num_warps * 16) / aligned_group_size) : kv_len), 16 * num_frags_z); const uint32_t mask_iteration = - (causal + (mask_mode == MaskMode::kCausal ? min(kv_len + (tile_idx * num_warps * num_frags_x * 16) / aligned_group_size - qo_len, kv_len) : kv_len) / @@ -1457,8 +1472,8 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( } // apply mask if (iter >= mask_iteration) { - mask_s(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, s_frag); + mask_s(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, nullptr, s_frag); } // compute m,d states in online softmax @@ -1523,7 +1538,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( * \param qo_len The length of query and output. * \param kv_len The length of key and value. * \param head_dim The dimension of each head. 
- * \param causal Whether to use causal attention. + * \param mask_mode The mask mode applied in the attention score. * \param kv_layout The layout of KV Cache. * \param pos_encoding_mode The positional encoding mode. * \param allow_fp16_qk_reduction Whether to allow accumulating q*k^T with fp16. @@ -1533,13 +1548,13 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( template cudaError_t SinglePrefillWithKVCacheWorkEstimation( uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, bool causal = true, + uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, MaskMode mask_mode, QKVLayout kv_layout = QKVLayout::kNHD, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, bool allow_fp16_qk_reduction = false, cudaStream_t stream = nullptr) { - if (kv_len < qo_len && causal) { + if (kv_len < qo_len && mask_mode == MaskMode::kCausal) { std::ostringstream err_msg; - err_msg << "When causal is true, kv_len must be greater than or equal to qo_len, " + err_msg << "When setting mask_mode to kCausal, kv_len must be greater than or equal to qo_len, " << "got kv_len " << kv_len << " and qo_len " << qo_len; throw std::invalid_argument(err_msg.str()); } @@ -1551,8 +1566,8 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( (qo_len * group_size > 64 && head_dim < 256 ? 2 : 1), num_frags_x, {DISPATCH_GQA_GROUP_SIZE( group_size, GROUP_SIZE, - {DISPATCH_CAUSAL( - causal, CAUSAL, {DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + {DISPATCH_MASK_MODE( + mask_mode, MASK_MODE, {DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { constexpr uint32_t num_frags_y = HEAD_DIM / 16; DISPATCH_POS_ENCODING_MODE( pos_encoding_mode, pos_encoding_mode, @@ -1605,7 +1620,7 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; auto partition_kv_kernel = SinglePrefillWithKVCacheKernel< - /*partition_kv=*/true, GROUP_SIZE, CAUSAL, KV_LAYOUT, + /*partition_kv=*/true, GROUP_SIZE, MASK_MODE, KV_LAYOUT, pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z, num_warps, DTypeIn, DTypeQKAccum, DTypeOut>; tensor_info_t qkv_info( @@ -1657,7 +1672,7 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( } template cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp, float* lse, uint32_t num_kv_heads, @@ -1666,9 +1681,9 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* cudaStream_t stream) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); - if (kv_len < qo_len && CAUSAL) { + if (kv_len < qo_len && MASK_MODE == MaskMode::kCausal) { std::ostringstream err_msg; - err_msg << "When causal is true, kv_len must be greater than or equal to qo_len, got kv_len" + err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be greater than or equal to qo_len, got kv_len" << kv_len << " and qo_len " << qo_len; throw std::invalid_argument(err_msg.str()); } @@ -1713,7 +1728,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* constexpr uint32_t num_threads = num_warps * warp_size; constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; auto partition_kv_kernel = - SinglePrefillWithKVCacheKernel; tensor_info_t qkv_info(qo_len, kv_len, num_kv_heads); @@ -1741,7 +1756,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* if (num_chunks <= 1 || tmp == nullptr) { // Enough 
parallelism, do not split-kv auto kernel = SinglePrefillWithKVCacheKernel< - /*partition_kv=*/false, GROUP_SIZE, CAUSAL, KV_LAYOUT, pos_encoding_mode, num_frags_x, + /*partition_kv=*/false, GROUP_SIZE, MASK_MODE, KV_LAYOUT, pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z, num_warps, DTypeIn, DTypeQKAccum, DTypeOut>; void* args[] = {(void*)&q, (void*)&k, @@ -1788,11 +1803,13 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* } template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, - DTypeIn* v, IdType* kv_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, + DTypeIn* v, IdType* kv_indptr, + float* mask, IdType* qk_indptr, + IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size, const uint32_t num_qo_tiles, const uint32_t num_kv_heads, const float sm_scale, const float rope_scale, const float rope_theta, cudaStream_t stream = nullptr) { @@ -1836,7 +1853,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( throw std::invalid_argument(err_msg.str()); } else { auto kernel = - BatchPrefillWithRaggedKVCacheKernel; uint32_t smem_size = @@ -1867,7 +1884,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( template cudaError_t BatchPrefillWithPagedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, @@ -1917,7 +1934,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( throw std::invalid_argument(err_msg.str()); } else { auto kernel = BatchPrefillWithPagedKVCacheKernel< - GROUP_SIZE, PAGE_SIZE, CAUSAL, pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z, + GROUP_SIZE, PAGE_SIZE, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z, num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>; uint32_t smem_size = (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); diff --git a/include/flashinfer/prefill_attention_decl.cuh b/include/flashinfer/prefill_attention_decl.cuh index 3e41dc7d..32cca12e 100644 --- a/include/flashinfer/prefill_attention_decl.cuh +++ b/include/flashinfer/prefill_attention_decl.cuh @@ -29,7 +29,7 @@ namespace flashinfer { template cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp, float* lse, uint32_t num_kv_heads, @@ -38,7 +38,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* cudaStream_t stream); template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, @@ -48,7 +48,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( template cudaError_t BatchPrefillWithPagedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, @@ -58,7 +58,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( template + MaskMode mask_mode, typename DTypeIn, typename DTypeOut, typename IdType> cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset, paged_kv_t paged_kv, DTypeOut* o, float* lse, @@ -83,7 +83,7 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, { return BatchPrefillWithPagedKVCacheDispatched< page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, 
pos_encoding_mode, - ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( + ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, o, tmp, lse, num_qo_tiles, sm_scale, rope_scale, rope_theta, stream); }); @@ -91,7 +91,7 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( } template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, @@ -118,7 +118,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, { return BatchPrefillWithRaggedKVCacheDispatched( + MASK_MODE, DTypeIn, DTypeOut, IdType>( q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, q_offset, k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale, rope_scale, rope_theta, stream); diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index f4511e56..06b72507 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -137,13 +137,28 @@ throw std::invalid_argument(err_msg.str()); \ } -#define DISPATCH_CAUSAL(causal, CAUSAL, ...) \ - if (causal) { \ - constexpr bool CAUSAL = true; \ - __VA_ARGS__ \ - } else { \ - constexpr bool CAUSAL = false; \ - __VA_ARGS__ \ +#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \ + switch (mask_mode) { \ + case MaskMode::kNone: { \ + constexpr MaskMode MASK_MODE = MaskMode::kNone; \ + __VA_ARGS__ \ + break; \ + } \ + case MaskMode::kCausal: { \ + constexpr MaskMode MASK_MODE = MaskMode::kCausal; \ + __VA_ARGS__ \ + break; \ + } \ + case MaskMode::kCustom: { \ + constexpr MaskMode MASK_MODE = MaskMode::kCustom; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported mask_mode: " << int(mask_mode); \ + throw std::invalid_argument(err_msg.str()); \ + } \ } #define DISPATCH_LAYOUT(layout, LAYOUT, ...) 
\ diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 47a0df17..f49c4f28 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -54,7 +54,7 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, bool causal, unsigned int pos_encoding_mode, + torch::Tensor paged_kv_last_page_len, unsigned int mask_mode_value, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { CHECK_INPUT(q); @@ -101,6 +101,7 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( if (return_lse) { lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); } + MaskMode mask_mode = MaskMode(mask_mode_value); DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { @@ -112,7 +113,7 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( static_cast(paged_kv_last_page_len.data_ptr())); return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_causal(causal, CAUSAL, [&] { + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { return DISPATCH_allow_fp16_qk_reduction( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { return DISPATCH_pos_encoding_mode( @@ -120,7 +121,7 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( return DISPATCH_page_size(page_size, PAGE_SIZE, [&] { cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< PageStorage::kIndices, KV_LAYOUT, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, - POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, c_type, c_type, + POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( handler_.get(), static_cast(q.data_ptr()), static_cast(qo_indptr.data_ptr()), @@ -181,7 +182,7 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, - torch::Tensor kv_indptr, bool causal, unsigned int pos_encoding_mode, + torch::Tensor kv_indptr, unsigned int mask_mode_value, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { CHECK_INPUT(q); @@ -216,10 +217,12 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); } + MaskMode mask_mode = MaskMode(mask_mode_value); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_causal(causal, CAUSAL, [&] { + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { return DISPATCH_allow_fp16_qk_reduction( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { return DISPATCH_pos_encoding_mode( @@ -227,7 +230,7 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, - 
ALLOW_FP16_QK_REDUCTION, CAUSAL, c_type, c_type, int32_t>( + ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( handler_.get(), static_cast(q.data_ptr()), static_cast(qo_indptr.data_ptr()), static_cast(k.data_ptr()), static_cast(v.data_ptr()), diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index d826d71f..e8fdcabe 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -31,7 +31,7 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc float rope_theta); std::vector single_prefill_with_kv_cache( - torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal, + torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, unsigned int mask_mode_value, unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); @@ -117,7 +117,7 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper { std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, bool causal, + torch::Tensor paged_kv_last_page_len, unsigned int mask_mode_value, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); @@ -139,7 +139,7 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper { void EndForward(); void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, - torch::Tensor v, torch::Tensor kv_indptr, bool causal, + torch::Tensor v, torch::Tensor kv_indptr, unsigned int mask_mode_value, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index 21c93899..8d5a6952 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -134,8 +134,8 @@ using namespace flashinfer; _DISPATCH_SWITCH("allow_fp16_qk_reduction", expr, \ _DISPATCH_CASES_allow_fp16_qk_reduction(const_expr, __VA_ARGS__)) -#define DISPATCH_causal(expr, const_expr, ...) \ - _DISPATCH_SWITCH("causal", expr, _DISPATCH_CASES_causal(const_expr, __VA_ARGS__)) +#define DISPATCH_mask_mode(expr, const_expr, ...) 
\ + _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) inline void check_shape(const torch::Tensor& a, const torch::Tensor& b, const char* a_name, const char* b_name) { diff --git a/python/csrc/single_prefill.cu b/python/csrc/single_prefill.cu index e5212170..18368f3d 100644 --- a/python/csrc/single_prefill.cu +++ b/python/csrc/single_prefill.cu @@ -21,7 +21,7 @@ using namespace flashinfer; std::vector single_prefill_with_kv_cache( - torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal, + torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, unsigned int mask_mode, unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { CHECK_INPUT(q); @@ -54,10 +54,12 @@ std::vector single_prefill_with_kv_cache( lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32)); } + MaskMode mask_mode = MaskMode(mask_mode_value); + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_causal(causal, CAUSAL, [&] { + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { return DISPATCH_allow_fp16_qk_reduction( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { @@ -65,7 +67,7 @@ std::vector single_prefill_with_kv_cache( PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { cudaError_t status = SinglePrefillWithKVCacheDispatched< GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, CAUSAL>( + ALLOW_FP16_QK_REDUCTION, MASK_MODE>( static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(v.data_ptr()), static_cast(o.data_ptr()), static_cast(tmp.data_ptr()), diff --git a/python/flashinfer/utils.py b/python/flashinfer/utils.py index 664ac879..8d909291 100644 --- a/python/flashinfer/utils.py +++ b/python/flashinfer/utils.py @@ -17,6 +17,11 @@ import torch from enum import Enum +class MaskMode(Enum): + NONE = 0 + CAUSAL = 1 + CUSTOM = 2 + class PosEncodingMode(Enum): NONE = 0 diff --git a/python/generate_batch_paged_prefill_inst.py b/python/generate_batch_paged_prefill_inst.py index 3a061491..58d8b273 100644 --- a/python/generate_batch_paged_prefill_inst.py +++ b/python/generate_batch_paged_prefill_inst.py @@ -18,6 +18,7 @@ import re import itertools from literal_map import ( + mask_mode_literal, kv_layout_literal, pos_encoding_mode_literal, dtype_literal, @@ -33,7 +34,7 @@ def get_cu_file_str( kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, dtype_in, dtype_out, idtype, @@ -41,7 +42,7 @@ def get_cu_file_str( num_frags_x_choices = [1, 2] insts = "\n".join( [ - """template cudaError_t BatchPrefillWithPagedKVCacheDispatched( + """template cudaError_t BatchPrefillWithPagedKVCacheDispatched( {dtype_in}* q, {idtype}* request_indices, {idtype}* tile_indices, {idtype}* qo_indptr, {idtype}* q_offset, paged_kv_t paged_kv, @@ -57,7 +58,7 @@ def get_cu_file_str( head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], allow_fp16_qk_reduction=allow_fp16_qk_reduction, - causal=causal, + mask_mode=mask_mode_literal[int(mask_mode)], dtype_in=dtype_literal[dtype_in], dtype_out=dtype_literal[dtype_out], idtype=idtype_literal[idtype], @@ -81,7 +82,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( 
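# The mask component of each generated filename is now the integer MaskMode value
# (0 = kNone, 1 = kCausal, 2 = kCustom, the keys of mask_mode_literal) rather than the
# old causal_{true,false} suffix, so the pattern below captures it as mask_([0-9]+).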
r"batch_paged_prefill_group_([0-9]+)_page_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" - r"fp16qkred_([a-z]+)_causal_([a-z]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" + r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) path = Path(sys.argv[1]) diff --git a/python/generate_batch_ragged_prefill_inst.py b/python/generate_batch_ragged_prefill_inst.py index 65f862f4..c0fb843e 100644 --- a/python/generate_batch_ragged_prefill_inst.py +++ b/python/generate_batch_ragged_prefill_inst.py @@ -17,6 +17,7 @@ import sys import re from literal_map import ( + mask_mode_literal, kv_layout_literal, pos_encoding_mode_literal, dtype_literal, @@ -31,7 +32,7 @@ def get_cu_file_str( kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, dtype_in, dtype_out, idtype, @@ -39,7 +40,7 @@ def get_cu_file_str( num_frags_x_choices = [1, 2] insts = "\n".join( [ - """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{num_frags_x}, {group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {causal}, {dtype_in}, {dtype_out}, {idtype}>( + """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{num_frags_x}, {group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>( {dtype_in}* q, {idtype}* request_indices, {idtype}* tile_indices, {idtype}* qo_indptr, {dtype_in}* k, {dtype_in}* v, {idtype}* kv_indptr, {idtype}* q_offset, {idtype}* k_rope_pos_offset, {dtype_out}* o, float* tmp, float* lse, @@ -53,7 +54,7 @@ def get_cu_file_str( head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], allow_fp16_qk_reduction=allow_fp16_qk_reduction, - causal=causal, + mask_mode=mask_mode_literal[int(mask_mode)], dtype_in=dtype_literal[dtype_in], dtype_out=dtype_literal[dtype_out], idtype=idtype_literal[idtype], @@ -76,7 +77,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( r"batch_ragged_prefill_group_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" - r"fp16qkred_([a-z]+)_causal_([a-z]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" + r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) path = Path(sys.argv[1]) diff --git a/python/generate_dispatch_inc.py b/python/generate_dispatch_inc.py index 1a2b104b..c99cbe1f 100644 --- a/python/generate_dispatch_inc.py +++ b/python/generate_dispatch_inc.py @@ -16,7 +16,7 @@ import argparse from pathlib import Path -from literal_map import kv_layout_literal, pos_encoding_mode_literal, bool_literal +from literal_map import kv_layout_literal, pos_encoding_mode_literal, bool_literal, mask_mode_literal def get_dispatch_inc_str(args: argparse.Namespace) -> str: @@ -90,15 +90,15 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: {dispatch_allow_fp16_qk_reduction_entries} // EOL """ - # causal - dispatch_causal_entries = "\n".join( + # mask_mode + dispatch_mask_mode_entries = "\n".join( [ - " _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(bool_literal[_]) - for _ in args.causals + " _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(mask_mode_literal[_]) + for _ in args.mask_modes ] ) - dispatch_causal_str = f"""#define _DISPATCH_CASES_causal(case_var, ...) \\ -{dispatch_causal_entries} + dispatch_mask_mode_str = f"""#define _DISPATCH_CASES_mask_mode(case_var, ...) 
\\ +{dispatch_mask_mode_entries} // EOL """ @@ -110,7 +110,7 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: dispatch_kv_layouts_str, dispatch_pos_encoding_modes_str, dispatch_allow_fp16_qk_reductions_str, - dispatch_causal_str, + dispatch_mask_mode_str, ] ) @@ -151,11 +151,11 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: help="Allow fp16 qk reductions", ) parser.add_argument( - "--causals", + "--mask_modes", type=lambda x: x if isinstance(x, int) else x.lower() == "true", required=True, nargs="+", - help="Causals", + help="Mask modes", ) args = parser.parse_args() print(args) diff --git a/python/generate_single_prefill_inst.py b/python/generate_single_prefill_inst.py index d3018375..a6357be1 100644 --- a/python/generate_single_prefill_inst.py +++ b/python/generate_single_prefill_inst.py @@ -16,7 +16,7 @@ import sys import re -from literal_map import kv_layout_literal, pos_encoding_mode_literal, dtype_literal +from literal_map import kv_layout_literal, pos_encoding_mode_literal, dtype_literal, mask_mode_literal from pathlib import Path @@ -26,7 +26,7 @@ def get_cu_file_str( kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, dtype_in, dtype_out, ): @@ -35,7 +35,7 @@ def get_cu_file_str( namespace flashinfer {{ -template cudaError_t SinglePrefillWithKVCacheDispatched<{group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {causal}, {dtype_in}, {dtype_out}>( +template cudaError_t SinglePrefillWithKVCacheDispatched<{group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}>( {dtype_in}* q, {dtype_in}* k, {dtype_in}* v, {dtype_out}* o, float* tmp, float* lse, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, float sm_scale, float rope_scale, @@ -48,7 +48,7 @@ def get_cu_file_str( head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], allow_fp16_qk_reduction=allow_fp16_qk_reduction, - causal=causal, + mask_mode=mask_mode_literal[int(mask_mode)], dtype_in=dtype_literal[dtype_in], dtype_out=dtype_literal[dtype_out], ) @@ -58,7 +58,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( r"single_prefill_group_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" - r"fp16qkred_([a-z]+)_causal_([a-z]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" + r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) diff --git a/python/literal_map.py b/python/literal_map.py index 3587572e..bf4ac679 100644 --- a/python/literal_map.py +++ b/python/literal_map.py @@ -14,6 +14,12 @@ limitations under the License. 
""" +mask_mode_literal = { + 0: "MaskMode::kNone", + 1: "MaskMode::kCausal", + 2: "MaskMode::kCustom", +} + kv_layout_literal = { 0: "QKVLayout::kNHD", 1: "QKVLayout::kHND", diff --git a/python/setup.py b/python/setup.py index 3eec45df..db9b77be 100644 --- a/python/setup.py +++ b/python/setup.py @@ -73,7 +73,7 @@ def get_instantiation_cu() -> List[str]: allow_fp16_qk_reduction_options = os.environ.get( "FLASHINFER_ALLOW_FP16_QK_REDUCTION_OPTIONS", "0" ).split(",") - causal_options = os.environ.get("FLASHINFER_CAUSAL_OPTIONS", "0,1").split(",") + mask_modes = os.environ.get("FLASHINFER_MASK_MODES", "0,1,2").split(",") # dispatch.inc path = root / prefix / "dispatch.inc" write_if_different( @@ -86,7 +86,7 @@ def get_instantiation_cu() -> List[str]: kv_layouts=map(int, kv_layouts), pos_encoding_modes=map(int, pos_encoding_modes), allow_fp16_qk_reductions=map(int, allow_fp16_qk_reduction_options), - causals=map(int, causal_options), + mask_modes=map(int, mask_modes), ) ), ) @@ -217,17 +217,17 @@ def get_instantiation_cu() -> List[str]: kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, ) in itertools.product( group_sizes, head_dims, kv_layouts, pos_encoding_modes, allow_fp16_qk_reduction_options, - causal_options, + mask_modes, ): for dtype in prefill_dtypes: - fname = f"single_prefill_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_causal_{causal}_dtypein_{dtype}_dtypeout_{dtype}.cu" + fname = f"single_prefill_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_modes}_dtypein_{dtype}_dtypeout_{dtype}.cu" files.append(prefix + "/" + fname) content = generate_single_prefill_inst.get_cu_file_str( group_size, @@ -235,7 +235,7 @@ def get_instantiation_cu() -> List[str]: kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, dtype, dtype, ) @@ -249,7 +249,7 @@ def get_instantiation_cu() -> List[str]: kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, idtype, ) in itertools.product( group_sizes, @@ -258,11 +258,11 @@ def get_instantiation_cu() -> List[str]: kv_layouts, pos_encoding_modes, allow_fp16_qk_reduction_options, - causal_options, + mask_modes, idtypes, ): for dtype in prefill_dtypes: - fname = f"batch_paged_prefill_group_{group_size}_page_{page_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_causal_{causal}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" + fname = f"batch_paged_prefill_group_{group_size}_page_{page_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" files.append(prefix + "/" + fname) content = generate_batch_paged_prefill_inst.get_cu_file_str( group_size, @@ -271,7 +271,7 @@ def get_instantiation_cu() -> List[str]: kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, dtype, dtype, idtype, @@ -285,7 +285,7 @@ def get_instantiation_cu() -> List[str]: kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, idtype, ) in itertools.product( group_sizes, @@ -293,11 +293,11 @@ def get_instantiation_cu() -> List[str]: kv_layouts, pos_encoding_modes, allow_fp16_qk_reduction_options, - causal_options, + mask_modes, idtypes, ): for dtype in prefill_dtypes: - fname = 
f"batch_ragged_prefill_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_causal_{causal}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" + fname = f"batch_ragged_prefill_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" files.append(prefix + "/" + fname) content = generate_batch_ragged_prefill_inst.get_cu_file_str( group_size, @@ -305,7 +305,7 @@ def get_instantiation_cu() -> List[str]: kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, dtype, dtype, idtype, From 589b0bcbe61e537a691cd588dff92f38f5eb5aaa Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 27 May 2024 23:55:40 +0000 Subject: [PATCH 02/14] wip --- include/flashinfer/attention/prefill.cuh | 4 +- python/csrc/batch_prefill.cu | 78 +++++++++++++++ python/csrc/flashinfer_ops.cu | 1 + python/csrc/flashinfer_ops.h | 6 ++ python/flashinfer/prefill.py | 119 +++++++++++++++++------ 5 files changed, 179 insertions(+), 29 deletions(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 82dc368f..b07c75d8 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -1112,7 +1112,7 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( DTypeIn* __restrict__ q, IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices, IdType* __restrict__ qo_indptr, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, IdType* __restrict__ kv_indptr, - float* custom_mask, IdType* qk_indptr, + float* __restrict__ custom_mask, IdType* __restrict__ qk_indptr, IdType* __restrict__ q_offset, IdType* __restrict__ k_rope_pos_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp, float* __restrict__ lse, uint32_t batch_size, float sm_scale, float log2_rope_rcp_scale, @@ -1867,6 +1867,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( (void*)&k, (void*)&v, (void*)&kv_indptr, + (void*)&mask, + (void*)&qk_indptr, (void*)&q_offset, (void*)&k_rope_pos_offset, (void*)&o, diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index f49c4f28..5b45ea67 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -258,3 +258,81 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( return {o}; } } + +std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardWithMask( + torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, + torch::Tensor kv_indptr, torch::Tensor mask, torch::Tensor qk_indptr, + bool causal, unsigned int pos_encoding_mode, + bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, + bool return_lse) { + CHECK_INPUT(q); + CHECK_INPUT(qo_indptr); + CHECK_INPUT(k); + CHECK_INPUT(v); + CHECK_INPUT(kv_indptr); + CHECK_DIM(3, q); // (nnz_qo, H_qo, D) + CHECK_DIM(1, qo_indptr); // (B + 1,) + CHECK_DIM(3, k); // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D) + CHECK_DIM(3, v); // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D) + CHECK_DIM(1, kv_indptr); // (B + 1,) + int64_t batch_size = qo_indptr.size(0) - 1; + int64_t nnz_qo = q.size(0); + int64_t num_qo_heads = q.size(1); + int64_t head_dim = q.size(2); + CHECK_EQ(kv_indptr.size(0), batch_size + 1); + int64_t num_kv_heads = (kv_layout_ == QKVLayout::kNHD) ? 
k.size(1) : k.size(0); + CHECK_EQ(k.size(0), v.size(0)); + CHECK_EQ(k.size(1), v.size(1)); + CHECK_EQ(k.size(2), v.size(2)); + CHECK_EQ(k.size(2), head_dim); + CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); + // TODO(Zihao): support dispatching to different index data types. + CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32); + CHECK_EQ(kv_indptr.scalar_type(), torch::kInt32); + + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); + torch::Tensor o = torch::empty_like(q, q.options()); + torch::Tensor lse = torch::empty({0}); + if (return_lse) { + lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); + } + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_causal(causal, CAUSAL, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< + GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, CAUSAL, c_type, c_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + static_cast(k.data_ptr()), static_cast(v.data_ptr()), + static_cast(kv_indptr.data_ptr()), + /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, + static_cast(o.data_ptr()), + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithRaggedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); + }); + }); + }); + }); + }); + }); + + if (return_lse) { + return {o, lse}; + } else { + return {o}; + } +} diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu index 05118bd1..d73d2059 100644 --- a/python/csrc/flashinfer_ops.cu +++ b/python/csrc/flashinfer_ops.cu @@ -71,4 +71,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("update_page_locked_buffer_size", &BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) .def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward); + .def("forward_with_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardWithMask); } diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index e8fdcabe..77d80417 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -143,6 +143,12 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper { unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); + std::vector ForwardWithMask(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, + torch::Tensor v, torch::Tensor kv_indptr, torch::Tensor mask, + torch::Tensor qk_indptr, bool causal, + unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, + float sm_scale, float rope_scale, float rope_theta, + bool return_lse); BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout, unsigned int max_workspace_size_in_bytes) : kv_layout_(flashinfer::QKVLayout(layout)), diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index f024a1ac..b5120714 100644 --- a/python/flashinfer/prefill.py +++ 
b/python/flashinfer/prefill.py @@ -593,6 +593,20 @@ def forward_return_lse( ) +def _compute_qk_indptr( + qo_indptr: torch.Tensor, kv_indptr: torch.Tensor +): + if len(qo_indptr) != len(kv_indptr): + raise ValueError("The length of qo_indptr and kv_indptr should be the same.") + qk_indptr = torch.empty_like(qo_indptr) + qk_indptr[0] = 0 + qk_indptr[1:] = torch.cumsum( + (qo_indptr[1:] - qo_indptr[:-1]) * (kv_indptr[1:] - kv_indptr[:-1]), + 0, + ) + return qk_indptr + + class BatchPrefillWithRaggedKVCacheWrapper: r"""Wrapper class for prefill/append attention with ragged (tensor) kv-cache for batch of requests. @@ -672,6 +686,8 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): ) self._qo_indptr = None self._kv_indptr = None + self._mask = None + self._qk_indptr = None def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor): r"""Reset the workspace buffer. @@ -691,6 +707,7 @@ def begin_forward( num_qo_heads: int, num_kv_heads: int, head_dim: int, + mask: Optional[torch.Tensor] = None, ): r"""Create auxiliary data structures for batch prefill/append attention for multiple forward calls within the same prefill/append step. @@ -707,6 +724,9 @@ def begin_forward( The number of key/value heads. head_dim : int The dimension of the heads. + mask : Optional[torch.Tensor] + The flattened mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size))`` + The mask tensor will be added to the attention matrix before softmax. Notes ----- @@ -721,6 +741,9 @@ def begin_forward( batch_size = len(qo_indptr) - 1 self._qo_indptr = qo_indptr self._kv_indptr = kv_indptr + if mask is not None: + self._qk_indptr = _compute_qk_indptr(qo_indptr, kv_indptr) + self._mask = mask self._wrapper.begin_forward( self._workspace_buffer, qo_indptr, @@ -734,6 +757,8 @@ def end_forward(self): r"""Clear the auxiliary data structures created by :meth:`begin_forward`.""" self._qo_indptr = None self._kv_indptr = None + self._mask = None + self._qk_indptr = None self._wrapper.end_forward() def forward( @@ -796,20 +821,40 @@ def forward( q = q.to(torch.float16) k = k.to(torch.float16) v = v.to(torch.float16) - return self._wrapper.forward( - q, - self._qo_indptr, - k, - v, - self._kv_indptr, - causal, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - sm_scale, - rope_scale, - rope_theta, - False, - )[0] + if self._mask is None: + return self._wrapper.forward( + q, + self._qo_indptr, + k, + v, + self._kv_indptr, + self._mask, + self._qk_indptr, + causal, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + False, + )[0] + else: + return self._wrapper.forward_with_mask( + q, + self._qo_indptr, + k, + v, + self._kv_indptr, + self._mask, + self._qk_indptr, + causal, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + False, + )[0] def forward_return_lse( self, @@ -873,17 +918,35 @@ def forward_return_lse( q = q.to(torch.float16) k = k.to(torch.float16) v = v.to(torch.float16) - return self._wrapper.forward( - q, - self._qo_indptr, - k, - v, - self._kv_indptr, - causal, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - sm_scale, - rope_scale, - rope_theta, - True, - ) + if self._mask is None: + return self._wrapper.forward( + q, + self._qo_indptr, + k, + v, + self._kv_indptr, + causal, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + True, + ) + else: + return 
self._wrapper.forward_with_mask( + q, + self._qo_indptr, + k, + v, + self._kv_indptr, + self._mask, + self._qk_indptr, + causal, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + True, + ) From ba4ae633f68663478a5d174742211059f71bb9ff Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 28 May 2024 00:25:02 +0000 Subject: [PATCH 03/14] wip --- include/flashinfer/attention/prefill.cuh | 104 +++++++------- include/flashinfer/utils.cuh | 42 +++--- python/csrc/batch_prefill.cu | 13 +- python/csrc/flashinfer_ops.cu | 6 +- python/csrc/flashinfer_ops.h | 34 +++-- python/flashinfer/prefill.py | 165 ++++++++++++++++------- python/flashinfer/utils.py | 5 - python/generate_dispatch_inc.py | 11 +- python/generate_single_prefill_inst.py | 7 +- 9 files changed, 245 insertions(+), 142 deletions(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index b07c75d8..ede25fc0 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -560,8 +560,7 @@ template __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, const uint32_t kv_len, - const uint32_t chunk_end, - float *custom_mask, + const uint32_t chunk_end, float* custom_mask, DTypeQKAccum (*s_frag)[num_frags_z][8]) { const uint32_t tx = threadIdx.x; #pragma unroll @@ -575,10 +574,15 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, const uint32_ kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + 8 * (reg_id / 4) + reg_id % 2; const bool out_of_boundary = - (mask_mode == MaskMode::kCausal ? (kv_idx > kv_len + q_idx - qo_len || (partition_kv && kv_idx >= chunk_end)) - : kv_idx >= chunk_end); - s_frag[fx][fz][reg_id] = out_of_boundary ? DTypeQKAccum(-5e4) : s_frag[fx][fz][reg_id] + - DTypeQKAccum(mask_mode == MaskMode::kCustom ? custom_mask[q_idx * kv_len + kv_idx]: 0.f); + (mask_mode == MaskMode::kCausal + ? (kv_idx > kv_len + q_idx - qo_len || (partition_kv && kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + s_frag[fx][fz][reg_id] = + out_of_boundary + ? DTypeQKAccum(-5e4) + : s_frag[fx][fz][reg_id] + DTypeQKAccum(mask_mode == MaskMode::kCustom + ? custom_mask[q_idx * kv_len + kv_idx] + : 0.f); } } } @@ -991,19 +995,21 @@ __global__ void SinglePrefillWithKVCacheKernel( v_smem(smem + (num_warps * num_frags_x + num_frags_z) * 16 * head_dim * sizeof(DTypeIn)); const uint32_t num_iterations = ceil_div( - mask_mode == MaskMode::kCausal ? min(chunk_end - chunk_start, - sub_if_greater_or_zero( - kv_len - qo_len + ((bx + 1) * num_frags_x * num_warps * 16) / group_size, - chunk_start)) - : chunk_end - chunk_start, + mask_mode == MaskMode::kCausal + ? min(chunk_end - chunk_start, + sub_if_greater_or_zero( + kv_len - qo_len + ((bx + 1) * num_frags_x * num_warps * 16) / group_size, + chunk_start)) + : chunk_end - chunk_start, 16 * num_frags_z); const uint32_t mask_iteration = - (mask_mode == MaskMode::kCausal ? min(chunk_end - chunk_start, - sub_if_greater_or_zero( - kv_len + (bx * num_warps * num_frags_x * 16) / group_size - qo_len, - chunk_start)) - : (chunk_end - chunk_start)) / + (mask_mode == MaskMode::kCausal + ? 
min(chunk_end - chunk_start, + sub_if_greater_or_zero( + kv_len + (bx * num_warps * num_frags_x * 16) / group_size - qo_len, + chunk_start)) + : (chunk_end - chunk_start)) / (16 * num_frags_z); DTypeIn* k_ptr = k + qkv_info.get_kv_elem_offset(chunk_start + ty * 4 + tx / 8, kv_head_idx, @@ -1045,7 +1051,8 @@ __global__ void SinglePrefillWithKVCacheKernel( // apply mask if (iter >= mask_iteration) { mask_s( - qo_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, nullptr, s_frag); + qo_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, nullptr, + s_frag); } // compute m,d states in online softmax @@ -1105,15 +1112,15 @@ __global__ void SinglePrefillWithKVCacheKernel( } } -template +template __global__ void BatchPrefillWithRaggedKVCacheKernel( DTypeIn* __restrict__ q, IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices, IdType* __restrict__ qo_indptr, DTypeIn* __restrict__ k, - DTypeIn* __restrict__ v, IdType* __restrict__ kv_indptr, - float* __restrict__ custom_mask, IdType* __restrict__ qk_indptr, - IdType* __restrict__ q_offset, + DTypeIn* __restrict__ v, IdType* __restrict__ kv_indptr, float* __restrict__ custom_mask, + IdType* __restrict__ qk_indptr, IdType* __restrict__ q_offset, IdType* __restrict__ k_rope_pos_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp, float* __restrict__ lse, uint32_t batch_size, float sm_scale, float log2_rope_rcp_scale, float log2_rope_rcp_theta) { @@ -1203,11 +1210,12 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); } - const uint32_t num_iterations = ceil_div( - (mask_mode == MaskMode::kCausal ? min(kv_len, - kv_len - qo_len + ((tile_idx + 1) * num_frags_x * num_warps * 16) / group_size) - : kv_len), - 16 * num_frags_z); + const uint32_t num_iterations = + ceil_div((mask_mode == MaskMode::kCausal + ? 
min(kv_len, kv_len - qo_len + + ((tile_idx + 1) * num_frags_x * num_warps * 16) / group_size) + : kv_len), + 16 * num_frags_z); const uint32_t mask_iteration = (mask_mode == MaskMode::kCausal @@ -1263,11 +1271,13 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( // apply mask if constexpr (mask_mode == MaskMode::kCustom) { mask_s( - qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, custom_mask + qk_indptr[request_idx], s_frag); + qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, + custom_mask + qk_indptr[request_idx], s_frag); } else { if (iter >= mask_iteration) { - mask_s( - qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, nullptr, s_frag); + mask_s(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, nullptr, + s_frag); } } @@ -1319,10 +1329,10 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( } } -template +template __global__ void BatchPrefillWithPagedKVCacheKernel( IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices, DTypeIn* __restrict__ q, paged_kv_t paged_kv, @@ -1473,7 +1483,8 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( // apply mask if (iter >= mask_iteration) { mask_s(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, nullptr, s_frag); + num_frags_z>(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, nullptr, + s_frag); } // compute m,d states in online softmax @@ -1683,7 +1694,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* const float log2_rope_rcp_theta = -std::log2f(rope_theta); if (kv_len < qo_len && MASK_MODE == MaskMode::kCausal) { std::ostringstream err_msg; - err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be greater than or equal to qo_len, got kv_len" + err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be greater than or equal " + "to qo_len, got kv_len" << kv_len << " and qo_len " << qo_len; throw std::invalid_argument(err_msg.str()); } @@ -1756,8 +1768,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* if (num_chunks <= 1 || tmp == nullptr) { // Enough parallelism, do not split-kv auto kernel = SinglePrefillWithKVCacheKernel< - /*partition_kv=*/false, GROUP_SIZE, MASK_MODE, KV_LAYOUT, pos_encoding_mode, num_frags_x, - num_frags_y, num_frags_z, num_warps, DTypeIn, DTypeQKAccum, DTypeOut>; + /*partition_kv=*/false, GROUP_SIZE, MASK_MODE, KV_LAYOUT, pos_encoding_mode, + num_frags_x, num_frags_y, num_frags_z, num_warps, DTypeIn, DTypeQKAccum, DTypeOut>; void* args[] = {(void*)&q, (void*)&k, (void*)&v, @@ -1807,12 +1819,10 @@ template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, - DTypeIn* v, IdType* kv_indptr, - float* mask, IdType* qk_indptr, - IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, - float* tmp, float* lse, const uint32_t batch_size, const uint32_t num_qo_tiles, - const uint32_t num_kv_heads, const float sm_scale, const float rope_scale, - const float rope_theta, cudaStream_t stream = nullptr) { + DTypeIn* v, IdType* kv_indptr, float* mask, IdType* qk_indptr, IdType* q_offset, + IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size, + const uint32_t num_qo_tiles, const uint32_t num_kv_heads, const float sm_scale, + const float rope_scale, const float rope_theta, cudaStream_t stream = nullptr) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = 
-std::log2f(rope_theta); constexpr uint32_t num_warps = 4; @@ -1936,8 +1946,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( throw std::invalid_argument(err_msg.str()); } else { auto kernel = BatchPrefillWithPagedKVCacheKernel< - GROUP_SIZE, PAGE_SIZE, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z, - num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>; + GROUP_SIZE, PAGE_SIZE, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y, + num_frags_z, num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>; uint32_t smem_size = (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); FLASHINFER_CUDA_CALL( diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 06b72507..358674a6 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -137,28 +137,28 @@ throw std::invalid_argument(err_msg.str()); \ } -#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \ - switch (mask_mode) { \ - case MaskMode::kNone: { \ - constexpr MaskMode MASK_MODE = MaskMode::kNone; \ - __VA_ARGS__ \ - break; \ - } \ - case MaskMode::kCausal: { \ - constexpr MaskMode MASK_MODE = MaskMode::kCausal; \ - __VA_ARGS__ \ - break; \ - } \ - case MaskMode::kCustom: { \ - constexpr MaskMode MASK_MODE = MaskMode::kCustom; \ - __VA_ARGS__ \ - break; \ - } \ - default: { \ - std::ostringstream err_msg; \ +#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \ + switch (mask_mode) { \ + case MaskMode::kNone: { \ + constexpr MaskMode MASK_MODE = MaskMode::kNone; \ + __VA_ARGS__ \ + break; \ + } \ + case MaskMode::kCausal: { \ + constexpr MaskMode MASK_MODE = MaskMode::kCausal; \ + __VA_ARGS__ \ + break; \ + } \ + case MaskMode::kCustom: { \ + constexpr MaskMode MASK_MODE = MaskMode::kCustom; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ err_msg << "Unsupported mask_mode: " << int(mask_mode); \ - throw std::invalid_argument(err_msg.str()); \ - } \ + throw std::invalid_argument(err_msg.str()); \ + } \ } #define DISPATCH_LAYOUT(layout, LAYOUT, ...) 
\ diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 5b45ea67..2a574ff5 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -54,9 +54,9 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, unsigned int mask_mode_value, unsigned int pos_encoding_mode, - bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, - bool return_lse) { + torch::Tensor paged_kv_last_page_len, unsigned int mask_mode_value, + unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, + float rope_theta, bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(qo_indptr); CHECK_INPUT(paged_kv_data); @@ -261,10 +261,9 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardWithMask( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, - torch::Tensor kv_indptr, torch::Tensor mask, torch::Tensor qk_indptr, - bool causal, unsigned int pos_encoding_mode, - bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, - bool return_lse) { + torch::Tensor kv_indptr, torch::Tensor mask, torch::Tensor qk_indptr, bool causal, + unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, + float rope_theta, bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(qo_indptr); CHECK_INPUT(k); diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu index d73d2059..8e468057 100644 --- a/python/csrc/flashinfer_ops.cu +++ b/python/csrc/flashinfer_ops.cu @@ -22,6 +22,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Single-request decode with KV-Cache operator"); m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache, "Single-request prefill with KV-Cache operator, return logsumexp"); + m.def( + "single_prefill_with_kv_cache_custom_mask", &single_prefill_with_kv_cache_custom_mask, + "Single-request prefill with KV-Cache operator, user defined custom mask, return logsumexp"); m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); m.def("merge_state", &merge_state, "Merge two self-attention states"); m.def("merge_state_in_place", &merge_state_in_place, @@ -63,6 +66,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("update_page_locked_buffer_size", &BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) .def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward); + .def("forward_custom_mask", &BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask); py::class_( m, "BatchPrefillWithRaggedKVCachePyTorchWrapper") .def(py::init()) @@ -71,5 +75,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("update_page_locked_buffer_size", &BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) .def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward); - .def("forward_with_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardWithMask); + .def("forward_custom_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask); } diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 77d80417..2944219b 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h 
@@ -31,7 +31,12 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc
                                           float rope_theta);
 
 std::vector<torch::Tensor> single_prefill_with_kv_cache(
-    torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, unsigned int mask_mode_value,
+    torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal,
+    unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
+    float sm_scale, float rope_scale, float rope_theta, bool return_lse);
+
+std::vector<torch::Tensor> single_prefill_with_kv_cache_custom_mask(
+    torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, torch::Tensor custom_mask,
     unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
     float sm_scale, float rope_scale, float rope_theta, bool return_lse);
 
@@ -117,12 +122,18 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper {
   std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr,
                                      torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
                                      torch::Tensor paged_kv_indices,
-                                     torch::Tensor paged_kv_last_page_len, unsigned int mask_mode_value,
+                                     torch::Tensor paged_kv_last_page_len, bool causal,
                                      unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
                                      float sm_scale, float rope_scale, float rope_theta,
                                      bool return_lse);
-  BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout,
-                                             unsigned int max_workspace_size_in_bytes)
+  std::vector<torch::Tensor> ForwardCustomMask(
+      torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data,
+      torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
+      torch::Tensor paged_kv_last_page_len, torch::Tensor custom_mask, torch::Tensor qk_indptr,
+      unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale,
+      float rope_scale, float rope_theta, bool return_lse);
+  BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout,
+                                             unsigned int max_workspace_size_in_bytes)
       : kv_layout_(flashinfer::QKVLayout(layout)),
         handler_(std::make_shared<flashinfer::BatchPrefillHandler>(max_workspace_size_in_bytes)) {}
 
@@ -139,16 +150,17 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper {
   void EndForward();
   void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
   std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k,
-                                     torch::Tensor v, torch::Tensor kv_indptr, unsigned int mask_mode_value,
-                                     unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
-                                     float sm_scale, float rope_scale, float rope_theta,
-                                     bool return_lse);
-  std::vector<torch::Tensor> ForwardWithMask(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k,
-                                             torch::Tensor v, torch::Tensor kv_indptr, torch::Tensor mask,
-                                             torch::Tensor qk_indptr, bool causal,
+                                     torch::Tensor v, torch::Tensor kv_indptr, bool causal,
                                      unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
                                      float sm_scale, float rope_scale, float rope_theta,
                                      bool return_lse);
+  std::vector<torch::Tensor> ForwardCustomMask(torch::Tensor q, torch::Tensor qo_indptr,
+                                               torch::Tensor k, torch::Tensor v,
+                                               torch::Tensor kv_indptr, torch::Tensor custom_mask,
+                                               torch::Tensor qk_indptr,
+                                               unsigned int pos_encoding_mode,
+                                               bool allow_fp16_qk_reduction, float sm_scale,
+                                               float rope_scale, float rope_theta, bool return_lse);
   BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout,
                                               unsigned int max_workspace_size_in_bytes)
       : kv_layout_(flashinfer::QKVLayout(layout)),
diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py
index b5120714..0bd37c79 100644
--- a/python/flashinfer/prefill.py
+++ b/python/flashinfer/prefill.py
@@ -57,6 +57,7 @@ def single_prefill_with_kv_cache(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
+    custom_mask: Optional[torch.Tensor] = None,
     causal: bool = False,
     kv_layout: str = "NHD",
     pos_encoding_mode: str = "NONE",
@@ -80,8 +81,11 @@ def single_prefill_with_kv_cache(
         The key tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is
         ``NHD``, ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is ``HND``.
+    custom_mask : Optional[torch.Tensor]
+        The custom mask tensor, shape: ``[qo_len, kv_len]``.
     causal : bool
         Whether to apply causal mask to the attention matrix.
+        This is only effective when :attr:`custom_mask` is not provided.
     kv_layout : str
         The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
     pos_encoding_mode : str
@@ -135,26 +139,43 @@ def single_prefill_with_kv_cache(
         rope_scale = 1.0
     if rope_theta is None:
         rope_theta = 1e4
-    return _kernels.single_prefill_with_kv_cache(
-        q,
-        k,
-        v,
-        tmp,
-        causal,
-        TensorLayout[kv_layout].value,
-        PosEncodingMode[pos_encoding_mode].value,
-        allow_fp16_qk_reduction,
-        sm_scale,
-        rope_scale,
-        rope_theta,
-        False,
-    )[0]
+    if custom_mask is not None:
+        return _kernels.single_prefill_with_kv_cache_custom_mask(
+            q,
+            k,
+            v,
+            tmp,
+            custom_mask,
+            TensorLayout[kv_layout].value,
+            PosEncodingMode[pos_encoding_mode].value,
+            allow_fp16_qk_reduction,
+            sm_scale,
+            rope_scale,
+            rope_theta,
+            False,
+        )[0]
+    else:
+        return _kernels.single_prefill_with_kv_cache(
+            q,
+            k,
+            v,
+            tmp,
+            causal,
+            TensorLayout[kv_layout].value,
+            PosEncodingMode[pos_encoding_mode].value,
+            allow_fp16_qk_reduction,
+            sm_scale,
+            rope_scale,
+            rope_theta,
+            False,
+        )[0]
 
 
 def single_prefill_with_kv_cache_return_lse(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
+    custom_mask: Optional[torch.Tensor] = None,
     causal: bool = False,
     kv_layout: str = "NHD",
     pos_encoding_mode: str = "NONE",
@@ -178,8 +199,11 @@ def single_prefill_with_kv_cache_return_lse(
         The key tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is
         ``NHD``, or ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is ``HND``.
+    custom_mask : Optional[torch.Tensor]
+        The custom mask tensor, shape: ``[qo_len, kv_len]``.
     causal : bool
         Whether to apply causal mask to the attention matrix.
+        This is only effective when :attr:`custom_mask` is not provided.
     kv_layout : str
         The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
     pos_encoding_mode : str
@@ -249,20 +273,59 @@ def single_prefill_with_kv_cache_return_lse(
         q = q.to(torch.float16)
         k = k.to(torch.float16)
         v = v.to(torch.float16)
-    return _kernels.single_prefill_with_kv_cache(
-        q,
-        k,
-        v,
-        tmp,
-        causal,
-        TensorLayout[kv_layout].value,
-        PosEncodingMode[pos_encoding_mode].value,
-        allow_fp16_qk_reduction,
-        sm_scale,
-        rope_scale,
-        rope_theta,
-        True,
+    if custom_mask is not None:
+        return _kernels.single_prefill_with_kv_cache_custom_mask(
+            q,
+            k,
+            v,
+            tmp,
+            custom_mask,
+            TensorLayout[kv_layout].value,
+            PosEncodingMode[pos_encoding_mode].value,
+            allow_fp16_qk_reduction,
+            sm_scale,
+            rope_scale,
+            rope_theta,
+            True,
+        )
+    else:
+        return _kernels.single_prefill_with_kv_cache(
+            q,
+            k,
+            v,
+            tmp,
+            causal,
+            TensorLayout[kv_layout].value,
+            PosEncodingMode[pos_encoding_mode].value,
+            allow_fp16_qk_reduction,
+            sm_scale,
+            rope_scale,
+            rope_theta,
+            True,
+        )
+
+
+def _compute_page_qk_indptr(
+    qo_indptr: torch.Tensor,
+    paged_kv_indptr: torch.Tensor,
+    paged_kv_last_page_len: torch.Tensor,
+    page_size: int,
+):
+    if len(qo_indptr) != len(paged_kv_indptr):
+        raise ValueError(
+            "The length of qo_indptr and paged_kv_indptr should be the same."
+        )
+    qk_indptr = torch.empty_like(qo_indptr)
+    qk_indptr[0] = 0
+    qk_indptr[1:] = torch.cumsum(
+        (qo_indptr[1:] - qo_indptr[:-1])
+        * (
+            (paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1) * page_size
+            + paged_kv_last_page_len
+        ),
+        0,
     )
+    return qk_indptr
 
 
 class BatchPrefillWithPagedKVCacheWrapper:
@@ -381,6 +444,8 @@ def begin_forward(
         num_qo_heads: int,
         num_kv_heads: int,
         head_dim: int,
+        page_size: int,
+        custom_mask: Optional[torch.Tensor] = None,
     ):
         r"""Create auxiliary data structures for batch prefill/append attention for
         multiple forward calls within the same prefill/append step.
@@ -402,6 +467,9 @@
             The number of key/value heads.
         head_dim : int
             The dimension of the heads.
+        custom_mask : Optional[torch.Tensor]
+            The flattened mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size)))``.
+            The mask tensor will be applied to the attention matrix before softmax if provided.
 
         Notes
         -----
@@ -418,6 +486,13 @@
         self._paged_kv_indptr = paged_kv_indptr
         self._paged_kv_indices = paged_kv_indices
         self._paged_kv_last_page_len = paged_kv_last_page_len
+        if custom_mask is not None:
+            self._qk_indptr = _compute_page_qk_indptr(
+                qo_indptr,
+                paged_kv_indptr,
+                paged_kv_last_page_len,
+                page_size,
+            )
         self._wrapper.begin_forward(
             self._workspace_buffer,
             qo_indptr,
@@ -593,9 +668,7 @@ def forward_return_lse(
         )
 
 
-def _compute_qk_indptr(
-    qo_indptr: torch.Tensor, kv_indptr: torch.Tensor
-):
+def _compute_qk_indptr(qo_indptr: torch.Tensor, kv_indptr: torch.Tensor):
     if len(qo_indptr) != len(kv_indptr):
         raise ValueError("The length of qo_indptr and kv_indptr should be the same.")
     qk_indptr = torch.empty_like(qo_indptr)
@@ -686,7 +759,7 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"):
         )
         self._qo_indptr = None
         self._kv_indptr = None
-        self._mask = None
+        self._custom_mask = None
         self._qk_indptr = None
 
     def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor):
@@ -707,7 +780,7 @@ def begin_forward(
         num_qo_heads: int,
         num_kv_heads: int,
         head_dim: int,
-        mask: Optional[torch.Tensor] = None,
+        custom_mask: Optional[torch.Tensor] = None,
     ):
         r"""Create auxiliary data structures for batch prefill/append attention for
         multiple forward calls within the same prefill/append step.
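# Aside (not part of the patch): a toy sketch of the flattened-mask layout that
# the new custom_mask / qk_indptr arguments assume, based on the docstrings and
# _compute_qk_indptr above. Tensor names and sizes here are made up.
import torch

qo_indptr = torch.tensor([0, 2, 5])   # q_len = [2, 3]
kv_indptr = torch.tensor([0, 4, 9])   # k_len = [4, 5]

# One dense additive mask per request (0.0 keeps a position, a large negative
# value such as -5e4 masks it out), flattened row-major and concatenated.
mask_0 = torch.zeros(2, 4)
mask_1 = torch.zeros(3, 5)
custom_mask = torch.cat([mask_0.flatten(), mask_1.flatten()])

# qk_indptr delimits each request's segment: [0, 2*4, 2*4 + 3*5] = [0, 8, 23].
q_len = qo_indptr[1:] - qo_indptr[:-1]
k_len = kv_indptr[1:] - kv_indptr[:-1]
qk_indptr = torch.zeros_like(qo_indptr)
qk_indptr[1:] = torch.cumsum(q_len * k_len, 0)
assert custom_mask.numel() == int(qk_indptr[-1])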
@@ -724,8 +797,8 @@ def begin_forward(
             The number of key/value heads.
         head_dim : int
             The dimension of the heads.
-        mask : Optional[torch.Tensor]
-            The flattened mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size))``
+        custom_mask : Optional[torch.Tensor]
+            The flattened mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size)))``.
             The mask tensor will be added to the attention matrix before softmax.
 
         Notes
@@ -741,9 +814,9 @@
         batch_size = len(qo_indptr) - 1
         self._qo_indptr = qo_indptr
         self._kv_indptr = kv_indptr
-        if mask is not None:
+        if custom_mask is not None:
             self._qk_indptr = _compute_qk_indptr(qo_indptr, kv_indptr)
-        self._mask = mask
+        self._custom_mask = custom_mask
         self._wrapper.begin_forward(
             self._workspace_buffer,
             qo_indptr,
@@ -757,7 +830,7 @@ def end_forward(self):
         r"""Clear the auxiliary data structures created by :meth:`begin_forward`."""
         self._qo_indptr = None
         self._kv_indptr = None
-        self._mask = None
+        self._custom_mask = None
         self._qk_indptr = None
         self._wrapper.end_forward()
 
@@ -786,6 +859,7 @@ def forward(
             The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]``
         causal : bool
             Whether to apply causal mask to the attention matrix.
+            This argument is ignored if ``custom_mask`` is provided in :meth:`begin_forward`.
         pos_encoding_mode : str
             Whether to apply RoPE on-the-fly inside attention kernels, could be
             ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
@@ -821,15 +895,13 @@ def forward(
             q = q.to(torch.float16)
             k = k.to(torch.float16)
             v = v.to(torch.float16)
-        if self._mask is None:
+        if self._custom_mask is None:
            return self._wrapper.forward(
                q,
                self._qo_indptr,
                k,
                v,
                self._kv_indptr,
-                self._mask,
-                self._qk_indptr,
                causal,
                PosEncodingMode[pos_encoding_mode].value,
                allow_fp16_qk_reduction,
@@ -839,15 +911,14 @@
                False,
            )[0]
         else:
-            return self._wrapper.forward_with_mask(
+            return self._wrapper.forward_custom_mask(
                q,
                self._qo_indptr,
                k,
                v,
                self._kv_indptr,
-                self._mask,
+                self._custom_mask,
                self._qk_indptr,
-                causal,
                PosEncodingMode[pos_encoding_mode].value,
                allow_fp16_qk_reduction,
                sm_scale,
@@ -881,6 +952,7 @@ def forward_return_lse(
             The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]``
         causal : bool
             Whether to apply causal mask to the attention matrix.
+            This argument is ignored if ``custom_mask`` is provided in :meth:`begin_forward`.
         pos_encoding_mode : str
             Whether to apply RoPE on-the-fly inside attention kernels, could be
             ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
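# Aside (not part of the patch): a rough end-to-end sketch of how the new
# custom-mask path is meant to be driven from Python, reusing the toy sizes
# from the previous sketch. The public wrapper class name is assumed to be
# flashinfer.BatchPrefillWithRaggedKVCacheWrapper, matching the methods in
# this diff; all shapes and the workspace size are illustrative.
import torch
import flashinfer

device = "cuda"
num_qo_heads = num_kv_heads = 32
head_dim = 128

qo_indptr = torch.tensor([0, 2, 5], dtype=torch.int32, device=device)
kv_indptr = torch.tensor([0, 4, 9], dtype=torch.int32, device=device)
custom_mask = torch.zeros(2 * 4 + 3 * 5, dtype=torch.float32, device=device)

q = torch.randn(5, num_qo_heads, head_dim, dtype=torch.float16, device=device)
k = torch.randn(9, num_kv_heads, head_dim, dtype=torch.float16, device=device)
v = torch.randn(9, num_kv_heads, head_dim, dtype=torch.float16, device=device)

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, kv_layout="NHD")
wrapper.begin_forward(
    qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim,
    custom_mask=custom_mask,  # flattened additive mask; `causal` is then ignored
)
o = wrapper.forward(q, k, v)  # output shape [5, num_qo_heads, head_dim]
wrapper.end_forward()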
@@ -918,7 +990,7 @@ def forward_return_lse( q = q.to(torch.float16) k = k.to(torch.float16) v = v.to(torch.float16) - if self._mask is None: + if self._custom_mask is None: return self._wrapper.forward( q, self._qo_indptr, @@ -934,15 +1006,14 @@ def forward_return_lse( True, ) else: - return self._wrapper.forward_with_mask( + return self._wrapper.forward_custom_custom_mask( q, self._qo_indptr, k, v, self._kv_indptr, - self._mask, + self._custom_mask, self._qk_indptr, - causal, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, sm_scale, diff --git a/python/flashinfer/utils.py b/python/flashinfer/utils.py index 8d909291..664ac879 100644 --- a/python/flashinfer/utils.py +++ b/python/flashinfer/utils.py @@ -17,11 +17,6 @@ import torch from enum import Enum -class MaskMode(Enum): - NONE = 0 - CAUSAL = 1 - CUSTOM = 2 - class PosEncodingMode(Enum): NONE = 0 diff --git a/python/generate_dispatch_inc.py b/python/generate_dispatch_inc.py index c99cbe1f..923290b4 100644 --- a/python/generate_dispatch_inc.py +++ b/python/generate_dispatch_inc.py @@ -16,7 +16,12 @@ import argparse from pathlib import Path -from literal_map import kv_layout_literal, pos_encoding_mode_literal, bool_literal, mask_mode_literal +from literal_map import ( + kv_layout_literal, + pos_encoding_mode_literal, + bool_literal, + mask_mode_literal, +) def get_dispatch_inc_str(args: argparse.Namespace) -> str: @@ -93,7 +98,9 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: # mask_mode dispatch_mask_mode_entries = "\n".join( [ - " _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(mask_mode_literal[_]) + " _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format( + mask_mode_literal[_] + ) for _ in args.mask_modes ] ) diff --git a/python/generate_single_prefill_inst.py b/python/generate_single_prefill_inst.py index a6357be1..93f55f14 100644 --- a/python/generate_single_prefill_inst.py +++ b/python/generate_single_prefill_inst.py @@ -16,7 +16,12 @@ import sys import re -from literal_map import kv_layout_literal, pos_encoding_mode_literal, dtype_literal, mask_mode_literal +from literal_map import ( + kv_layout_literal, + pos_encoding_mode_literal, + dtype_literal, + mask_mode_literal, +) from pathlib import Path From 132d04e00eaf31e5ca324af4a7848e76189a7bee Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 28 May 2024 01:05:08 +0000 Subject: [PATCH 04/14] wip --- include/flashinfer/attention/prefill.cuh | 43 ++++++++++++++++++------ python/csrc/batch_prefill.cu | 31 +++++++++++------ python/csrc/single_prefill.cu | 4 ++- 3 files changed, 57 insertions(+), 21 deletions(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index ede25fc0..b8502355 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -910,7 +910,9 @@ template __global__ void SinglePrefillWithKVCacheKernel( DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, - DTypeOut* __restrict__ o, void* __restrict__ tmp, float* __restrict__ lse, + DTypeOut* __restrict__ o, + float* __restrict__ custom_mask, + void* __restrict__ tmp, float* __restrict__ lse, const tensor_info_t qkv_info, float sm_scale, const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { static_assert(sizeof(DTypeIn) == 2); @@ -1049,10 +1051,16 @@ __global__ void SinglePrefillWithKVCacheKernel( alibi_slopes, s_frag); } // apply mask - if (iter >= mask_iteration) { + if constexpr (mask_mode == MaskMode::kCustom) { mask_s( - qo_idx_base, chunk_start + iter * 
16 * num_frags_z, qo_len, kv_len, chunk_end, nullptr, + qo_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, custom_mask, s_frag); + } else { + if (iter >= mask_iteration) { + mask_s( + qo_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, nullptr, + s_frag); + } } // compute m,d states in online softmax @@ -1336,7 +1344,10 @@ template paged_kv, - IdType* __restrict__ qo_indptr, IdType* __restrict__ q_offset, DTypeOut* __restrict__ o, + IdType* __restrict__ qo_indptr, + float* __restrict__ custom_mask, + IdType* __restrict__ qk_indptr, + IdType* __restrict__ q_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp, float* __restrict__ lse, float sm_scale, float log2_rope_rcp_scale, float log2_rope_rcp_theta) { constexpr uint32_t rows_per_warp = 16 / (2 * group_size) * 2; @@ -1481,10 +1492,16 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( qo_idx_base, iter * 16 * num_frags_z, int(kv_len) - int(qo_len), alibi_slopes, s_frag); } // apply mask - if (iter >= mask_iteration) { + if constexpr (mask_mode == MaskMode::kCustom) { mask_s(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, nullptr, + num_frags_z>(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, custom_mask + qk_indptr[request_idx], s_frag); + } else { + if (iter >= mask_iteration) { + mask_s(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, nullptr, + s_frag); + } } // compute m,d states in online softmax @@ -1685,7 +1702,7 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( template -cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, +cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, float* custom_mask, DTypeOut* o, float* tmp, float* lse, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, float sm_scale, float rope_scale, float rope_theta, @@ -1773,6 +1790,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* void* args[] = {(void*)&q, (void*)&k, (void*)&v, + (void*)&custom_mask, (void*)&o, (void*)&tmp, (void*)&lse, @@ -1791,6 +1809,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* void* args[] = {(void*)&q, (void*)&k, (void*)&v, + (void*)&custom_mask, (void*)&o, (void*)&tmp, (void*)&lse, @@ -1819,7 +1838,7 @@ template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, - DTypeIn* v, IdType* kv_indptr, float* mask, IdType* qk_indptr, IdType* q_offset, + DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size, const uint32_t num_qo_tiles, const uint32_t num_kv_heads, const float sm_scale, const float rope_scale, const float rope_theta, cudaStream_t stream = nullptr) { @@ -1877,7 +1896,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( (void*)&k, (void*)&v, (void*)&kv_indptr, - (void*)&mask, + (void*)&custom_mask, (void*)&qk_indptr, (void*)&q_offset, (void*)&k_rope_pos_offset, @@ -1899,7 +1918,9 @@ template cudaError_t BatchPrefillWithPagedKVCacheDispatched( - DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, + DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, + float* custom_mask, IdType* qk_indptr, + IdType* q_offset, paged_kv_t paged_kv, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, 
float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { @@ -1957,6 +1978,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( (void*)&q, (void*)&paged_kv, (void*)&qo_indptr, + (void*)&custom_mask, + (void*)&qk_indptr, (void*)&q_offset, (void*)&o, (void*)&tmp, diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 2a574ff5..29304f01 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -54,7 +54,7 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, unsigned int mask_mode_value, + torch::Tensor paged_kv_last_page_len, bool causal, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { CHECK_INPUT(q); @@ -101,7 +101,7 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( if (return_lse) { lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); } - MaskMode mask_mode = MaskMode(mask_mode_value); + MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { @@ -125,6 +125,8 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( int32_t>( handler_.get(), static_cast(q.data_ptr()), static_cast(qo_indptr.data_ptr()), + /*custom_mask=*/nullptr, + /*qk_indptr=*/nullptr, /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, sm_scale, rope_scale, rope_theta, @@ -182,7 +184,7 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, - torch::Tensor kv_indptr, unsigned int mask_mode_value, unsigned int pos_encoding_mode, + torch::Tensor kv_indptr, bool causal, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { CHECK_INPUT(q); @@ -217,7 +219,7 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); } - MaskMode mask_mode = MaskMode(mask_mode_value); + MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { @@ -235,6 +237,7 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( static_cast(qo_indptr.data_ptr()), static_cast(k.data_ptr()), static_cast(v.data_ptr()), static_cast(kv_indptr.data_ptr()), + /*custom_mask=*/nullptr, /*qk_indptr=*/nullptr, /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, static_cast(o.data_ptr()), /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, @@ -261,7 +264,7 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardWithMask( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, - torch::Tensor kv_indptr, torch::Tensor mask, torch::Tensor qk_indptr, bool causal, + torch::Tensor kv_indptr, torch::Tensor custom_mask, torch::Tensor qk_indptr, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { CHECK_INPUT(q); @@ -269,16 +272,21 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardW CHECK_INPUT(k); CHECK_INPUT(v); CHECK_INPUT(kv_indptr); + CHECK_INPUT(custom_mask); + CHECK_INPUT(qk_indptr); CHECK_DIM(3, q); // (nnz_qo, H_qo, D) CHECK_DIM(1, qo_indptr); // (B + 1,) CHECK_DIM(3, k); // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D) CHECK_DIM(3, v); // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D) CHECK_DIM(1, kv_indptr); // (B + 1,) + CHECK_DIM(1, custom_mask); // (nnz_qk,) + CHECK_DIM(1, qk_indptr); // (B + 1,) int64_t batch_size = qo_indptr.size(0) - 1; int64_t nnz_qo = q.size(0); int64_t num_qo_heads = q.size(1); int64_t head_dim = q.size(2); CHECK_EQ(kv_indptr.size(0), batch_size + 1); + CHECK_EQ(qk_indptr.size(0), batch_size + 1); int64_t num_kv_heads = (kv_layout_ == QKVLayout::kNHD) ? k.size(1) : k.size(0); CHECK_EQ(k.size(0), v.size(0)); CHECK_EQ(k.size(1), v.size(1)); @@ -286,8 +294,10 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardW CHECK_EQ(k.size(2), head_dim); CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); // TODO(Zihao): support dispatching to different index data types. - CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32); - CHECK_EQ(kv_indptr.scalar_type(), torch::kInt32); + qo_indptr = qo_indptr.to(torch::kInt32); + kv_indptr = kv_indptr.to(torch::kInt32); + qk_indptr = qk_indptr.to(torch::kInt32); + custom_mask = custom_mask.to(torch::kFloat32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); torch::Tensor o = torch::empty_like(q, q.options()); @@ -296,10 +306,10 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardW lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); } + constexpr MaskMode MASK_MODE = MaskMode::kCustom; DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_causal(causal, CAUSAL, [&] { return DISPATCH_allow_fp16_qk_reduction( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { return DISPATCH_pos_encoding_mode( @@ -307,11 +317,13 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardW return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, CAUSAL, c_type, c_type, int32_t>( + ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( handler_.get(), static_cast(q.data_ptr()), static_cast(qo_indptr.data_ptr()), static_cast(k.data_ptr()), static_cast(v.data_ptr()), static_cast(kv_indptr.data_ptr()), + static_cast(custom_mask.data_ptr()), + static_cast(qk_indptr.data_ptr()), /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, static_cast(o.data_ptr()), /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, @@ -323,7 +335,6 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardW return true; }); }); - }); }); }); }); diff --git a/python/csrc/single_prefill.cu b/python/csrc/single_prefill.cu index 18368f3d..325d8663 100644 --- a/python/csrc/single_prefill.cu +++ b/python/csrc/single_prefill.cu @@ -69,7 +69,9 @@ std::vector single_prefill_with_kv_cache( GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE>( static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), static_cast(o.data_ptr()), + static_cast(v.data_ptr()), + /*custom_mask=*/nullptr, + static_cast(o.data_ptr()), static_cast(tmp.data_ptr()), /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, rope_theta, From 530bed42d745a97460bb7fe507cde9b9a864c268 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 28 May 2024 01:47:55 +0000 Subject: [PATCH 05/14] upd --- include/flashinfer/attention/prefill.cuh | 57 +++--- include/flashinfer/prefill_attention_decl.cuh | 44 ++--- python/csrc/batch_prefill.cu | 170 ++++++++++++++---- python/csrc/single_prefill.cu | 100 +++++++++-- python/generate_batch_paged_prefill_inst.py | 1 + python/generate_batch_ragged_prefill_inst.py | 4 +- python/generate_single_prefill_inst.py | 2 +- src/flashinfer_ops.cuh | 53 +++--- 8 files changed, 306 insertions(+), 125 deletions(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index b8502355..7c1fe88c 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -910,11 +910,9 @@ template __global__ void SinglePrefillWithKVCacheKernel( DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, - DTypeOut* __restrict__ o, - float* __restrict__ custom_mask, - void* __restrict__ tmp, float* __restrict__ lse, - const tensor_info_t qkv_info, float sm_scale, - const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { + DTypeOut* __restrict__ o, float* __restrict__ custom_mask, void* __restrict__ tmp, + float* __restrict__ lse, const tensor_info_t qkv_info, + float sm_scale, const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { static_assert(sizeof(DTypeIn) == 2); static_assert(sizeof(DTypeOut) == 2); sm_scale *= math::log2e; @@ -1053,13 +1051,13 @@ __global__ void SinglePrefillWithKVCacheKernel( // apply mask if constexpr (mask_mode == MaskMode::kCustom) { mask_s( - qo_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, custom_mask, - s_frag); + qo_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, + custom_mask, s_frag); } else { if (iter >= mask_iteration) { - mask_s( - qo_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, nullptr, - s_frag); + mask_s(qo_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, + chunk_end, nullptr, s_frag); } } @@ -1344,12 +1342,9 @@ template paged_kv, - IdType* __restrict__ qo_indptr, - float* __restrict__ custom_mask, - IdType* __restrict__ qk_indptr, - IdType* __restrict__ q_offset, DTypeOut* __restrict__ o, - float* __restrict__ tmp, float* __restrict__ lse, float sm_scale, float log2_rope_rcp_scale, - float log2_rope_rcp_theta) { + IdType* __restrict__ qo_indptr, float* __restrict__ custom_mask, IdType* __restrict__ qk_indptr, + IdType* __restrict__ q_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp, + float* __restrict__ lse, float sm_scale, float 
log2_rope_rcp_scale, float log2_rope_rcp_theta) { constexpr uint32_t rows_per_warp = 16 / (2 * group_size) * 2; constexpr uint32_t aligned_group_size = 16 / rows_per_warp; static_assert(sizeof(DTypeIn) == 2); @@ -1494,12 +1489,12 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( // apply mask if constexpr (mask_mode == MaskMode::kCustom) { mask_s(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, custom_mask + qk_indptr[request_idx], - s_frag); + num_frags_z>(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, + custom_mask + qk_indptr[request_idx], s_frag); } else { if (iter >= mask_iteration) { mask_s(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, nullptr, + num_frags_z>(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, nullptr, s_frag); } } @@ -1702,11 +1697,11 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( template -cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, float* custom_mask, DTypeOut* o, - float* tmp, float* lse, uint32_t num_kv_heads, - uint32_t qo_len, uint32_t kv_len, float sm_scale, - float rope_scale, float rope_theta, - cudaStream_t stream) { +cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, + float* custom_mask, DTypeOut* o, float* tmp, + float* lse, uint32_t num_kv_heads, uint32_t qo_len, + uint32_t kv_len, float sm_scale, float rope_scale, + float rope_theta, cudaStream_t stream) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); if (kv_len < qo_len && MASK_MODE == MaskMode::kCausal) { @@ -1838,8 +1833,8 @@ template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, - DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, - IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size, + DTypeIn* v, IdType* kv_indptr, IdType* q_offset, IdType* k_rope_pos_offset, float* custom_mask, + IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size, const uint32_t num_qo_tiles, const uint32_t num_kv_heads, const float sm_scale, const float rope_scale, const float rope_theta, cudaStream_t stream = nullptr) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); @@ -1918,12 +1913,10 @@ template cudaError_t BatchPrefillWithPagedKVCacheDispatched( - DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, - float* custom_mask, IdType* qk_indptr, - IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* tmp, - float* lse, uint32_t num_qo_tiles, float sm_scale, float rope_scale, float rope_theta, - cudaStream_t stream) { + DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, + paged_kv_t paged_kv, float* custom_mask, + IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, float sm_scale, + float rope_scale, float rope_theta, cudaStream_t stream) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); constexpr uint32_t num_warps = 4; diff --git a/include/flashinfer/prefill_attention_decl.cuh b/include/flashinfer/prefill_attention_decl.cuh index 32cca12e..6ef5c7ef 100644 --- a/include/flashinfer/prefill_attention_decl.cuh +++ b/include/flashinfer/prefill_attention_decl.cuh @@ -31,20 +31,21 @@ namespace flashinfer { template -cudaError_t 
SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, - float* tmp, float* lse, uint32_t num_kv_heads, - uint32_t qo_len, uint32_t kv_len, float sm_scale, - float rope_scale, float rope_theta, - cudaStream_t stream); +cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, + float* custom_mask, DTypeOut* o, float* tmp, + float* lse, uint32_t num_kv_heads, uint32_t qo_len, + uint32_t kv_len, float sm_scale, float rope_scale, + float rope_theta, cudaStream_t stream); template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, - DTypeIn* v, IdType* kv_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, - float* tmp, float* lse, uint32_t batch_size, uint32_t num_qo_tiles, uint32_t num_kv_heads, - float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream = nullptr); + DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, + IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, uint32_t batch_size, + uint32_t num_qo_tiles, uint32_t num_kv_heads, float sm_scale, float rope_scale, + float rope_theta, cudaStream_t stream = nullptr); template cudaError_t BatchPrefillWithPagedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* tmp, - float* lse, uint32_t num_qo_tiles, float sm_scale, float rope_scale, float rope_theta, - cudaStream_t stream); + paged_kv_t paged_kv, float* custom_mask, + IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, float sm_scale, + float rope_scale, float rope_theta, cudaStream_t stream); template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* lse, - float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { + paged_kv_t paged_kv, float* custom_mask, + IdType* qk_indptr, DTypeOut* o, float* lse, float sm_scale, float rope_scale, float rope_theta, + cudaStream_t stream) { float* tmp = nullptr; IdType* request_indices = nullptr; IdType* tile_indices = nullptr; @@ -84,8 +86,8 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( return BatchPrefillWithPagedKVCacheDispatched< page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( - q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, o, tmp, lse, num_qo_tiles, - sm_scale, rope_scale, rope_theta, stream); + q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, custom_mask, qk_indptr, o, + tmp, lse, num_qo_tiles, sm_scale, rope_scale, rope_theta, stream); }); return cudaSuccess; } @@ -95,9 +97,9 @@ template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, - IdType* kv_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* lse, - uint32_t batch_size, uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta, - cudaStream_t stream) { + IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, + IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t batch_size, uint32_t num_kv_heads, + float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { 
float* tmp = nullptr; IdType* request_indices = nullptr; IdType* tile_indices = nullptr; @@ -119,9 +121,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( return BatchPrefillWithRaggedKVCacheDispatched( - q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, q_offset, k_rope_pos_offset, - o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale, rope_scale, rope_theta, - stream); + q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, custom_mask, qk_indptr, + q_offset, k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale, + rope_scale, rope_theta, stream); }); return cudaSuccess; } diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 29304f01..d8c41b00 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -54,9 +54,109 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, bool causal, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, - float rope_theta, bool return_lse) { + torch::Tensor paged_kv_last_page_len, bool causal, unsigned int pos_encoding_mode, + bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, + bool return_lse) { + CHECK_INPUT(q); + CHECK_INPUT(qo_indptr); + CHECK_INPUT(paged_kv_data); + CHECK_INPUT(paged_kv_indptr); + CHECK_INPUT(paged_kv_indices); + CHECK_INPUT(paged_kv_last_page_len); + CHECK_DIM(3, q); // (nnz_qo, H_qo, D) + CHECK_DIM(1, qo_indptr); // (B + 1,) + // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND + // [max_num_pages, 2, page_size, num_kv_heads, head_dim] for HND + CHECK_DIM(5, paged_kv_data); + CHECK_DIM(1, paged_kv_indptr); // (B + 1,) + CHECK_DIM(1, paged_kv_indices); // (nnz_kv,) + CHECK_DIM(1, paged_kv_last_page_len); // (B,) + int64_t batch_size = qo_indptr.size(0) - 1; + int64_t nnz_qo = q.size(0); + int64_t num_qo_heads = q.size(1); + int64_t head_dim = q.size(2); + int64_t num_kv_heads, page_size; + if (kv_layout_ == QKVLayout::kHND) { + num_kv_heads = paged_kv_data.size(2); + page_size = paged_kv_data.size(3); + } else { + page_size = paged_kv_data.size(2); + num_kv_heads = paged_kv_data.size(3); + } + CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); + CHECK_EQ(qo_indptr.size(0), batch_size + 1); + CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1); + CHECK_EQ(paged_kv_last_page_len.size(0), batch_size); + CHECK_EQ(paged_kv_data.size(1), 2); + CHECK_EQ(paged_kv_data.size(4), head_dim); + // TODO(Zihao): support dispatching to different index data types. + CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32); + CHECK_EQ(paged_kv_indptr.scalar_type(), torch::kInt32); + CHECK_EQ(paged_kv_indices.scalar_type(), torch::kInt32); + CHECK_EQ(paged_kv_last_page_len.scalar_type(), torch::kInt32); + + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); + torch::Tensor o = torch::empty_like(q, q.options()); + torch::Tensor lse = torch::empty({0}); + if (return_lse) { + lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); + } + MaskMode mask_mode = causal ? 
MaskMode::kCausal : MaskMode::kNone; + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, + static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_page_size(page_size, PAGE_SIZE, [&] { + cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, KV_LAYOUT, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, + POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, + int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + /*custom_mask=*/nullptr, + /*qk_indptr=*/nullptr, + /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithPagedKVCache failed with error code ", + cudaGetErrorString(status)); + return true; + }); + }); + }); + }); + }); + }); + }); + }); + + if (return_lse) { + return {o, lse}; + } else { + return {o}; + } +} + +std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask( + torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, + torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, bool causal, unsigned int pos_encoding_mode, + bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, + bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(qo_indptr); CHECK_INPUT(paged_kv_data); @@ -274,13 +374,13 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardW CHECK_INPUT(kv_indptr); CHECK_INPUT(custom_mask); CHECK_INPUT(qk_indptr); - CHECK_DIM(3, q); // (nnz_qo, H_qo, D) - CHECK_DIM(1, qo_indptr); // (B + 1,) - CHECK_DIM(3, k); // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D) - CHECK_DIM(3, v); // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D) - CHECK_DIM(1, kv_indptr); // (B + 1,) - CHECK_DIM(1, custom_mask); // (nnz_qk,) - CHECK_DIM(1, qk_indptr); // (B + 1,) + CHECK_DIM(3, q); // (nnz_qo, H_qo, D) + CHECK_DIM(1, qo_indptr); // (B + 1,) + CHECK_DIM(3, k); // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D) + CHECK_DIM(3, v); // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D) + CHECK_DIM(1, kv_indptr); // (B + 1,) + CHECK_DIM(1, custom_mask); // (nnz_qk,) + CHECK_DIM(1, qk_indptr); // (B + 1,) int64_t batch_size = qo_indptr.size(0) - 1; int64_t nnz_qo = q.size(0); int64_t num_qo_heads = q.size(1); @@ -310,32 +410,32 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardW DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), 
POS_ENCODING_MODE, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< - GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - static_cast(k.data_ptr()), static_cast(v.data_ptr()), - static_cast(kv_indptr.data_ptr()), - static_cast(custom_mask.data_ptr()), - static_cast(qk_indptr.data_ptr()), - /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, - static_cast(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithRaggedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< + GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + static_cast(k.data_ptr()), static_cast(v.data_ptr()), + static_cast(kv_indptr.data_ptr()), + static_cast(custom_mask.data_ptr()), + static_cast(qk_indptr.data_ptr()), + /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, + static_cast(o.data_ptr()), + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithRaggedKVCache failed with error ", + cudaGetErrorString(status)); + return true; }); - }); + }); + }); }); }); }); diff --git a/python/csrc/single_prefill.cu b/python/csrc/single_prefill.cu index 325d8663..4335666e 100644 --- a/python/csrc/single_prefill.cu +++ b/python/csrc/single_prefill.cu @@ -65,17 +65,18 @@ std::vector single_prefill_with_kv_cache( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { return DISPATCH_pos_encoding_mode( PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = SinglePrefillWithKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE>( - static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), - /*custom_mask=*/nullptr, - static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, rope_theta, - torch_current_stream); + cudaError_t status = + SinglePrefillWithKVCacheDispatched( + static_cast(q.data_ptr()), + static_cast(k.data_ptr()), + static_cast(v.data_ptr()), + /*custom_mask=*/nullptr, static_cast(o.data_ptr()), + static_cast(tmp.data_ptr()), + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, rope_theta, + torch_current_stream); TORCH_CHECK(status == cudaSuccess, "SinglePrefillWithKVCache kernel launch failed, error: " + std::string(cudaGetErrorString(status))); @@ -94,3 +95,80 @@ std::vector single_prefill_with_kv_cache( return {o}; } } + +std::vector single_prefill_with_kv_cache_custom_mask( + torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor custom_mask, torch::Tensor tmp, + unsigned int mask_mode, unsigned int layout, unsigned int pos_encoding_mode, + bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, + bool return_lse) { + CHECK_INPUT(q); + CHECK_INPUT(k); + CHECK_INPUT(v); + CHECK_INPUT(custom_mask); + CHECK_DIM(3, q); + CHECK_DIM(3, k); + CHECK_DIM(3, v); + CHECK_DIM(2, custom_mask); + CHECK_SHAPE(k, v); + CHECK_EQ(q.size(2), k.size(2)); + unsigned int head_dim = q.size(2); + unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads; + QKVLayout kv_layout = static_cast(layout); + if (kv_layout == QKVLayout::kNHD) { + kv_len = k.size(0); + qo_len = q.size(0); + num_kv_heads = k.size(1); + num_qo_heads = q.size(1); + } else { + kv_len = k.size(1); + qo_len = q.size(1); + num_kv_heads = k.size(0); + num_qo_heads = q.size(0); + } + CHECK_EQ(custom_mask.size(0), qo_len); + CHECK_EQ(custom_mask.size(1), kv_len); + CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); + auto o = torch::empty_like(q, q.options()); + torch::Tensor lse = torch::empty({0}); + if (return_lse) { + lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32)); + } + + constexpr MaskMode MASK_MODE = MaskMode::kCustom; + + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = SinglePrefillWithKVCacheDispatched< + GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE>( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), + static_cast(custom_mask.data_ptr()), + static_cast(o.data_ptr()), static_cast(tmp.data_ptr()), + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, rope_theta, + torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SinglePrefillWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + }); + }); + }); + }); + }); + + if (return_lse) { + return {o, lse}; + } else { + return {o}; + } +} diff --git a/python/generate_batch_paged_prefill_inst.py b/python/generate_batch_paged_prefill_inst.py index 58d8b273..5d301dd2 100644 --- a/python/generate_batch_paged_prefill_inst.py +++ b/python/generate_batch_paged_prefill_inst.py @@ -46,6 +46,7 @@ def get_cu_file_str( {dtype_in}* q, {idtype}* request_indices, {idtype}* tile_indices, {idtype}* qo_indptr, {idtype}* q_offset, paged_kv_t paged_kv, + float* custom_mask, {idtype}* qk_indptr, {dtype_out}* o, float* tmp, float* lse, uint32_t num_qo_tiles, float sm_scale, float rope_scale, diff --git a/python/generate_batch_ragged_prefill_inst.py b/python/generate_batch_ragged_prefill_inst.py index c0fb843e..7eeab91e 100644 --- a/python/generate_batch_ragged_prefill_inst.py +++ b/python/generate_batch_ragged_prefill_inst.py @@ -42,7 +42,9 @@ def get_cu_file_str( [ """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{num_frags_x}, {group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>( {dtype_in}* q, {idtype}* request_indices, {idtype}* tile_indices, {idtype}* qo_indptr, - {dtype_in}* k, {dtype_in}* v, {idtype}* kv_indptr, {idtype}* q_offset, {idtype}* k_rope_pos_offset, + {dtype_in}* k, {dtype_in}* v, {idtype}* kv_indptr, + float* custom_mask, {idtype}* qk_indptr, + {idtype}* q_offset, {idtype}* k_rope_pos_offset, {dtype_out}* o, float* tmp, float* lse, uint32_t batch_size, uint32_t num_qo_tiles, uint32_t num_kv_heads, float sm_scale, float rope_scale, diff --git a/python/generate_single_prefill_inst.py b/python/generate_single_prefill_inst.py index 93f55f14..7ffad989 100644 --- a/python/generate_single_prefill_inst.py +++ b/python/generate_single_prefill_inst.py @@ -41,7 +41,7 @@ def get_cu_file_str( namespace flashinfer {{ template cudaError_t SinglePrefillWithKVCacheDispatched<{group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}>( - {dtype_in}* q, {dtype_in}* k, {dtype_in}* v, {dtype_out}* o, + {dtype_in}* q, {dtype_in}* k, {dtype_in}* v, float* custom_mask, {dtype_out}* o, float* tmp, float* lse, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index 1924f20e..73829278 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -56,21 +56,23 @@ cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOu cudaStream_t stream = nullptr) { const uint32_t group_size = num_qo_heads / num_kv_heads; const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); + const MaskMode mask_mode = causal ? 
MaskMode::kCausal : MaskMode::kNone; DISPATCH_allow_fp16_qk_reduction( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {DISPATCH_group_size( group_size, GROUP_SIZE, - {DISPATCH_causal( - causal, CAUSAL, + {DISPATCH_mask_mode( + mask_mode, MASK_MODE, {DISPATCH_head_dim(head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode( pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { return SinglePrefillWithKVCacheDispatched< GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, CAUSAL>( - q, k, v, o, tmp, lse, num_kv_heads, qo_len, kv_len, - sm_scale, rope_scale, rope_theta, stream); + ALLOW_FP16_QK_REDUCTION, MASK_MODE>( + q, k, v, /*custom_mask=*/nullptr, o, tmp, lse, + num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, + rope_theta, stream); })})})})})}); return cudaSuccess; } @@ -85,24 +87,25 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( bool allow_fp16_qk_reduction = false, std::optional maybe_sm_scale = std::nullopt, const float rope_scale = 1.f, const float rope_theta = 1e4, cudaStream_t stream = nullptr) { const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); + const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; DISPATCH_kv_layout( kv_layout, KV_LAYOUT, {DISPATCH_group_size( num_qo_heads / num_kv_heads, GROUP_SIZE, {DISPATCH_head_dim( head_dim, HEAD_DIM, - {DISPATCH_causal( - causal, CAUSAL, + {DISPATCH_mask_mode( + mask_mode, MASK_MODE, {DISPATCH_pos_encoding_mode( pos_encoding_mode, pos_encoding_mode, {DISPATCH_allow_fp16_qk_reduction( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, { return BatchPrefillWithRaggedKVCacheWrapperDispatched< GROUP_SIZE, HEAD_DIM, KV_LAYOUT, pos_encoding_mode, - ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( - handler, q, qo_indptr, k, v, kv_indptr, q_offset, k_rope_pos_offset, - o, lse, batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, - stream); + ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( + handler, q, qo_indptr, k, v, kv_indptr, /*custom_mask=*/nullptr, + /*qk_indptr=*/nullptr, q_offset, k_rope_pos_offset, o, lse, + batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, stream); })})})})})}); return cudaSuccess; } @@ -119,23 +122,25 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper( const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim))); const uint32_t num_kv_heads = paged_kv.num_heads; const uint32_t head_dim = paged_kv.head_dim; + const MaskMode mask_mode = causal ? 
MaskMode::kCausal : MaskMode::kNone; DISPATCH_group_size( num_qo_heads / num_kv_heads, GROUP_SIZE, {DISPATCH_head_dim( head_dim, HEAD_DIM, - {DISPATCH_causal(causal, CAUSAL, - {DISPATCH_pos_encoding_mode( - pos_encoding_mode, pos_encoding_mode, - {DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, - {DISPATCH_page_size(paged_kv.page_size, PAGE_SIZE, { - return BatchPrefillWithPagedKVCacheWrapperDispatched< - page_storage, kv_layout, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, - pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, CAUSAL, - DTypeIn, DTypeOut, IdType>(handler, q, qo_indptr, q_offset, - paged_kv, o, lse, sm_scale, - rope_scale, rope_theta, stream); - })})})})})}); + {DISPATCH_mask_mode(mask_mode, MASK_MODE, + {DISPATCH_pos_encoding_mode( + pos_encoding_mode, pos_encoding_mode, + {DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, + {DISPATCH_page_size(paged_kv.page_size, PAGE_SIZE, { + return BatchPrefillWithPagedKVCacheWrapperDispatched< + page_storage, kv_layout, PAGE_SIZE, GROUP_SIZE, + HEAD_DIM, pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, + MASK_MODE, DTypeIn, DTypeOut, IdType>( + handler, q, qo_indptr, /*custom_mask=*/nullptr, + /*qk_indptr=*/nullptr, q_offset, paged_kv, o, lse, + sm_scale, rope_scale, rope_theta, stream); + })})})})})}); return cudaSuccess; } From 07034d0fbe0ebbd70fa94e069fefecb809f2f648 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 28 May 2024 02:02:27 +0000 Subject: [PATCH 06/14] upd --- cmake/config.cmake | 2 +- include/flashinfer/attention/mask.cuh | 29 +++++++++++++++++++ include/flashinfer/attention/prefill.cuh | 15 ++++------ include/flashinfer/prefill_attention_decl.cuh | 11 +++---- python/generate_dispatch_inc.py | 2 +- src/flashinfer_ops.cuh | 4 +-- src/utils.h | 4 +-- 7 files changed, 46 insertions(+), 21 deletions(-) create mode 100644 include/flashinfer/attention/mask.cuh diff --git a/cmake/config.cmake b/cmake/config.cmake index 10a1b843..c2fd48e9 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -24,7 +24,7 @@ set(FLASHINFER_GEN_HEAD_DIMS 64 128 256) set(FLASHINFER_GEN_KV_LAYOUTS 0 1) set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2) set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false" "true") -set(FLASHINFER_GEN_CASUALS "false" "true") +set(FLASHINFER_GEN_MASK_MODES 0 1) # Set target cuda architectures for tests/benchmarks, defaults to native. # "native" is a special value for CMAKE_CUDA_ARCHITECTURES which means use the architectures of the host's GPU. diff --git a/include/flashinfer/attention/mask.cuh b/include/flashinfer/attention/mask.cuh new file mode 100644 index 00000000..9adc6c26 --- /dev/null +++ b/include/flashinfer/attention/mask.cuh @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef FLASHINFER_ATTENTION_MASK_CUH_ +#define FLASHINFER_ATTENTION_MASK_CUH_ + +namespace flashinfer { + +enum class MaskMode { + kNone = 0U, // No mask + kCausal = 1U, // Causal mask + kCustom = 2U, // Custom mask +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_MASK_CUH_ \ No newline at end of file diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 7c1fe88c..b0efe1bc 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -36,6 +36,7 @@ #include "../utils.cuh" #include "cascade.cuh" #include "handler.cuh" +#include "mask.cuh" #include "state.cuh" namespace flashinfer { @@ -44,12 +45,6 @@ namespace cg = cooperative_groups; using cp_async::SharedMemFillMode; using mma::MMAMode; -enum class MaskMode { - kNone = 0U, // No mask - kCausal = 1U, // Causal mask - kCustom = 2U, // Custom mask -}; - constexpr uint32_t warp_size = 32; namespace { @@ -1829,12 +1824,12 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* } template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, - DTypeIn* v, IdType* kv_indptr, IdType* q_offset, IdType* k_rope_pos_offset, float* custom_mask, - IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size, + DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, + IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size, const uint32_t num_qo_tiles, const uint32_t num_kv_heads, const float sm_scale, const float rope_scale, const float rope_theta, cudaStream_t stream = nullptr) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); @@ -1910,7 +1905,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( template cudaError_t BatchPrefillWithPagedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, diff --git a/include/flashinfer/prefill_attention_decl.cuh b/include/flashinfer/prefill_attention_decl.cuh index 6ef5c7ef..a59777c0 100644 --- a/include/flashinfer/prefill_attention_decl.cuh +++ b/include/flashinfer/prefill_attention_decl.cuh @@ -21,6 +21,7 @@ #include #include "attention/handler.cuh" +#include "attention/mask.cuh" #include "layout.cuh" #include "page.cuh" #include "pos_enc.cuh" @@ -29,7 +30,7 @@ namespace flashinfer { template cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, float* custom_mask, DTypeOut* o, float* tmp, @@ -38,7 +39,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* float rope_theta, cudaStream_t stream); template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, @@ -49,7 +50,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( template cudaError_t BatchPrefillWithPagedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, @@ -59,7 +60,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( template + MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType> cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset, paged_kv_t paged_kv, float* custom_mask, @@ -93,7 +94,7 @@ cudaError_t 
BatchPrefillWithPagedKVCacheWrapperDispatched( } template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, diff --git a/python/generate_dispatch_inc.py b/python/generate_dispatch_inc.py index 923290b4..03ec819d 100644 --- a/python/generate_dispatch_inc.py +++ b/python/generate_dispatch_inc.py @@ -159,7 +159,7 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: ) parser.add_argument( "--mask_modes", - type=lambda x: x if isinstance(x, int) else x.lower() == "true", + type=int, required=True, nargs="+", help="Mask modes", diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index 73829278..dc48fb90 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -137,8 +137,8 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper( page_storage, kv_layout, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( - handler, q, qo_indptr, /*custom_mask=*/nullptr, - /*qk_indptr=*/nullptr, q_offset, paged_kv, o, lse, + handler, q, qo_indptr, q_offset, paged_kv, + /*custom_mask=*/nullptr, /*qk_indptr=*/nullptr, o, lse, sm_scale, rope_scale, rope_theta, stream); })})})})})}); return cudaSuccess; diff --git a/src/utils.h b/src/utils.h index b1570450..d77ebecb 100644 --- a/src/utils.h +++ b/src/utils.h @@ -68,8 +68,8 @@ _DISPATCH_SWITCH("allow_fp16_qk_reduction", expr, \ _DISPATCH_CASES_allow_fp16_qk_reduction(const_expr, __VA_ARGS__)) -#define DISPATCH_causal(expr, const_expr, ...) \ - _DISPATCH_SWITCH("causal", expr, _DISPATCH_CASES_causal(const_expr, __VA_ARGS__)) +#define DISPATCH_mask_mode(expr, const_expr, ...) \ + _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) namespace utils { From 17180650264dcacaf8501af833ee22bc03cb850f Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 28 May 2024 02:04:42 +0000 Subject: [PATCH 07/14] fix typo --- python/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/setup.py b/python/setup.py index db9b77be..fd524f63 100644 --- a/python/setup.py +++ b/python/setup.py @@ -227,7 +227,7 @@ def get_instantiation_cu() -> List[str]: mask_modes, ): for dtype in prefill_dtypes: - fname = f"single_prefill_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_modes}_dtypein_{dtype}_dtypeout_{dtype}.cu" + fname = f"single_prefill_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}.cu" files.append(prefix + "/" + fname) content = generate_single_prefill_inst.get_cu_file_str( group_size, From b8e7f36010534ede0f639112e3b495aeaa287e98 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 28 May 2024 02:07:03 +0000 Subject: [PATCH 08/14] fix typo --- python/csrc/flashinfer_ops.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 2944219b..f3cf242f 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -131,9 +131,9 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper { torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, torch::Tensor custom_mask, torch::Tensor qk_indptr, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, - float rope_scale, float rope_theta, bool return_lse) - 
BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout, - unsigned int max_workspace_size_in_bytes) + float rope_scale, float rope_theta, bool return_lse); + BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout, + unsigned int max_workspace_size_in_bytes) : kv_layout_(flashinfer::QKVLayout(layout)), handler_(std::make_shared(max_workspace_size_in_bytes)) {} From 332333c4c32c02f0f853e67ea4a9c846f3749c48 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 28 May 2024 02:08:39 +0000 Subject: [PATCH 09/14] fix typo --- python/csrc/flashinfer_ops.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu index 8e468057..2653d913 100644 --- a/python/csrc/flashinfer_ops.cu +++ b/python/csrc/flashinfer_ops.cu @@ -65,8 +65,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward) .def("update_page_locked_buffer_size", &BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) - .def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward); - .def("forward_custom_mask", &BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask); + .def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward) + .def("forward_custom_mask", &BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask); py::class_( m, "BatchPrefillWithRaggedKVCachePyTorchWrapper") .def(py::init()) @@ -74,6 +74,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("end_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward) .def("update_page_locked_buffer_size", &BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) - .def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward); - .def("forward_custom_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask); + .def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward) + .def("forward_custom_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask); } From ba0463b15e63e3806f8989567f701e88f69284b5 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 28 May 2024 02:19:13 +0000 Subject: [PATCH 10/14] bugfix --- python/csrc/batch_prefill.cu | 54 +++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index d8c41b00..27f6c400 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -30,8 +30,7 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward( CHECK_DIM(1, qo_indptr); CHECK_DIM(1, workspace_buffer); - // TODO(Zihao): support dispatching to different index data types. - CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32); + qo_indptr = qo_indptr.to(torch::kInt32); size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); handler_->SetCUDAStream(torch_current_stream); @@ -89,11 +88,10 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( CHECK_EQ(paged_kv_last_page_len.size(0), batch_size); CHECK_EQ(paged_kv_data.size(1), 2); CHECK_EQ(paged_kv_data.size(4), head_dim); - // TODO(Zihao): support dispatching to different index data types. 
- CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32); - CHECK_EQ(paged_kv_indptr.scalar_type(), torch::kInt32); - CHECK_EQ(paged_kv_indices.scalar_type(), torch::kInt32); - CHECK_EQ(paged_kv_last_page_len.scalar_type(), torch::kInt32); + qo_indptr = qo_indptr.to(torch::kInt32); + paged_kv_indptr = paged_kv_indptr.to(torch::kInt32); + paged_kv_indices = paged_kv_indices.to(torch::kInt32); + paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); torch::Tensor o = torch::empty_like(q, q.options()); @@ -125,9 +123,9 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( int32_t>( handler_.get(), static_cast(q.data_ptr()), static_cast(qo_indptr.data_ptr()), + /*q_offset=*/nullptr, paged_kv, /*custom_mask=*/nullptr, - /*qk_indptr=*/nullptr, - /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), + /*qk_indptr=*/nullptr, static_cast(o.data_ptr()), /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, sm_scale, rope_scale, rope_theta, /*stream=*/torch_current_stream); @@ -154,15 +152,17 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, bool causal, unsigned int pos_encoding_mode, - bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, - bool return_lse) { + torch::Tensor paged_kv_last_page_len, torch::Tensor custom_mask, torch::Tensor qk_indptr, + unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, + float rope_theta, bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(qo_indptr); CHECK_INPUT(paged_kv_data); CHECK_INPUT(paged_kv_indptr); CHECK_INPUT(paged_kv_indices); CHECK_INPUT(paged_kv_last_page_len); + CHECK_INPUT(custom_mask); + CHECK_INPUT(qk_indptr); CHECK_DIM(3, q); // (nnz_qo, H_qo, D) CHECK_DIM(1, qo_indptr); // (B + 1,) // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND @@ -171,6 +171,8 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu CHECK_DIM(1, paged_kv_indptr); // (B + 1,) CHECK_DIM(1, paged_kv_indices); // (nnz_kv,) CHECK_DIM(1, paged_kv_last_page_len); // (B,) + CHECK_DIM(1, custom_mask); // (nnz_qk,) + CHECK_DIM(1, qk_indptr); // (B + 1,) int64_t batch_size = qo_indptr.size(0) - 1; int64_t nnz_qo = q.size(0); int64_t num_qo_heads = q.size(1); @@ -189,11 +191,13 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu CHECK_EQ(paged_kv_last_page_len.size(0), batch_size); CHECK_EQ(paged_kv_data.size(1), 2); CHECK_EQ(paged_kv_data.size(4), head_dim); - // TODO(Zihao): support dispatching to different index data types. 
- CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32); - CHECK_EQ(paged_kv_indptr.scalar_type(), torch::kInt32); - CHECK_EQ(paged_kv_indices.scalar_type(), torch::kInt32); - CHECK_EQ(paged_kv_last_page_len.scalar_type(), torch::kInt32); + CHECK_EQ(qk_indptr.size(0), batch_size + 1); + qo_indptr = qo_indptr.to(torch::kInt32); + paged_kv_indptr = paged_kv_indptr.to(torch::kInt32); + paged_kv_indices = paged_kv_indices.to(torch::kInt32); + paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32); + custom_mask = custom_mask.to(torch::kFloat32); + qk_indptr = qk_indptr.to(torch::kInt32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); torch::Tensor o = torch::empty_like(q, q.options()); @@ -225,9 +229,10 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu int32_t>( handler_.get(), static_cast(q.data_ptr()), static_cast(qo_indptr.data_ptr()), - /*custom_mask=*/nullptr, - /*qk_indptr=*/nullptr, - /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), + /*q_offset=*/nullptr, paged_kv, + static_cast(custom_mask.data_ptr()), + static_cast(qk_indptr.data_ptr()), + static_cast(o.data_ptr()), /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, sm_scale, rope_scale, rope_theta, /*stream=*/torch_current_stream); @@ -261,8 +266,7 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward( CHECK_DIM(1, qo_indptr); CHECK_DIM(1, workspace_buffer); - // TODO(Zihao): support dispatching to different index data types. - CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32); + qo_indptr = qo_indptr.to(torch::kInt32); size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); handler_->SetCUDAStream(torch_current_stream); @@ -308,9 +312,8 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( CHECK_EQ(k.size(2), v.size(2)); CHECK_EQ(k.size(2), head_dim); CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - // TODO(Zihao): support dispatching to different index data types. - CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32); - CHECK_EQ(kv_indptr.scalar_type(), torch::kInt32); + qo_indptr = qo_indptr.to(torch::kInt32); + kv_indptr = kv_indptr.to(torch::kInt32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); torch::Tensor o = torch::empty_like(q, q.options()); @@ -393,7 +396,6 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardW CHECK_EQ(k.size(2), v.size(2)); CHECK_EQ(k.size(2), head_dim); CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - // TODO(Zihao): support dispatching to different index data types. qo_indptr = qo_indptr.to(torch::kInt32); kv_indptr = kv_indptr.to(torch::kInt32); qk_indptr = qk_indptr.to(torch::kInt32); From dc63f7d746021d82d9cdf750c74d1109e499439b Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 28 May 2024 02:25:18 +0000 Subject: [PATCH 11/14] bugfix --- python/csrc/batch_prefill.cu | 54 +++++++++++++++++------------------ python/csrc/single_prefill.cu | 4 +-- 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 27f6c400..c6ed72af 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -205,7 +205,7 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu if (return_lse) { lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); } - MaskMode mask_mode = causal ? 
MaskMode::kCausal : MaskMode::kNone; + constexpr MaskMode MASK_MODE = MaskMode::kCustom; DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { @@ -217,33 +217,31 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu static_cast(paged_kv_last_page_len.data_ptr())); return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_page_size(page_size, PAGE_SIZE, [&] { - cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, KV_LAYOUT, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, - POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, - int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - /*q_offset=*/nullptr, paged_kv, - static_cast(custom_mask.data_ptr()), - static_cast(qk_indptr.data_ptr()), - static_cast(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithPagedKVCache failed with error code ", - cudaGetErrorString(status)); - return true; - }); + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_page_size(page_size, PAGE_SIZE, [&] { + cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, KV_LAYOUT, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, + POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, + int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + /*q_offset=*/nullptr, paged_kv, + static_cast(custom_mask.data_ptr()), + static_cast(qk_indptr.data_ptr()), + static_cast(o.data_ptr()), + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithPagedKVCache failed with error code ", + cudaGetErrorString(status)); + return true; }); - }); - }); + }); + }); }); }); }); @@ -365,7 +363,7 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( } } -std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardWithMask( +std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, torch::Tensor kv_indptr, torch::Tensor custom_mask, torch::Tensor qk_indptr, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, diff --git a/python/csrc/single_prefill.cu b/python/csrc/single_prefill.cu index 4335666e..7b3f31ca 100644 --- a/python/csrc/single_prefill.cu +++ b/python/csrc/single_prefill.cu @@ -21,7 +21,7 @@ using namespace flashinfer; std::vector single_prefill_with_kv_cache( - torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, unsigned int mask_mode, + torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal, unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { CHECK_INPUT(q); @@ -54,7 +54,7 @@ std::vector single_prefill_with_kv_cache( lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32)); } - MaskMode mask_mode = MaskMode(mask_mode_value); + const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { From 0fcec034a6f936e21cf86c2b06cbd9d16e6dc6b5 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 28 May 2024 06:14:10 +0000 Subject: [PATCH 12/14] bugfix --- include/flashinfer/attention/prefill.cuh | 2 +- python/csrc/flashinfer_ops.h | 2 +- python/csrc/single_prefill.cu | 5 +- python/flashinfer/prefill.py | 113 ++++++++++------ python/tests/test_batch_prefill_kernels.py | 144 +++++++++++++++++++++ 5 files changed, 225 insertions(+), 41 deletions(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index b0efe1bc..fb0a0d48 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -905,7 +905,7 @@ template __global__ void SinglePrefillWithKVCacheKernel( DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, - DTypeOut* __restrict__ o, float* __restrict__ custom_mask, void* __restrict__ tmp, + float* __restrict__ custom_mask, DTypeOut* __restrict__ o, void* __restrict__ tmp, float* __restrict__ lse, const tensor_info_t qkv_info, float sm_scale, const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { static_assert(sizeof(DTypeIn) == 2); diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index f3cf242f..1dff6ba3 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -36,7 +36,7 @@ std::vector single_prefill_with_kv_cache( float sm_scale, float rope_scale, float rope_theta, bool return_lse); std::vector single_prefill_with_kv_cache_custom_mask( - torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, torch::Tensor custom_mask, + torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor custom_mask, torch::Tensor tmp, 
unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); diff --git a/python/csrc/single_prefill.cu b/python/csrc/single_prefill.cu index 7b3f31ca..162713a7 100644 --- a/python/csrc/single_prefill.cu +++ b/python/csrc/single_prefill.cu @@ -98,9 +98,8 @@ std::vector single_prefill_with_kv_cache( std::vector single_prefill_with_kv_cache_custom_mask( torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor custom_mask, torch::Tensor tmp, - unsigned int mask_mode, unsigned int layout, unsigned int pos_encoding_mode, - bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, - bool return_lse) { + unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, + float sm_scale, float rope_scale, float rope_theta, bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(k); CHECK_INPUT(v); diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 0bd37c79..ff7eb2aa 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -140,7 +140,7 @@ def single_prefill_with_kv_cache( if rope_theta is None: rope_theta = 1e4 if custom_mask is not None: - return _kernels.single_prefill_with_kv_cache_custom_custom_mask( + return _kernels.single_prefill_with_kv_cache_custom_mask( q, k, v, @@ -274,7 +274,7 @@ def single_prefill_with_kv_cache_return_lse( k = k.to(torch.float16) v = v.to(torch.float16) if custom_mask is not None: - return _kernels.single_prefill_with_kv_cache_custom_custom_mask( + return _kernels.single_prefill_with_kv_cache_custom_mask( q, k, v, @@ -320,7 +320,7 @@ def _compute_page_qk_indptr( qk_indptr[1:] = torch.cumsum( (qo_indptr[1:] - qo_indptr[:-1]) * ( - (paged_kv_indptr[1:] - paged_kv_indptr[:-1]) * page_size + (paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1) * page_size + paged_kv_last_page_len ), 0, @@ -423,6 +423,8 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): self._paged_kv_indptr = None self._paged_kv_indices = None self._paged_kv_last_page_len = None + self._custom_mask = None + self._qk_indptr = None def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor): r"""Reset the workspace buffer. 
@@ -487,6 +489,7 @@ def begin_forward( self._paged_kv_indices = paged_kv_indices self._paged_kv_last_page_len = paged_kv_last_page_len if custom_mask is not None: + self._custom_mask = custom_mask self._qk_indptr = _compute_page_qk_indptr( qo_indptr, paged_kv_indptr, @@ -508,6 +511,8 @@ def end_forward(self): self._paged_kv_indptr = None self._paged_kv_indices = None self._paged_kv_last_page_len = None + self._custom_mask = None + self._qk_indptr = None self._wrapper.end_forward() def forward( @@ -571,21 +576,39 @@ def forward( paged_kv_data = paged_kv_data.to(torch.float16) paged_kv_data = expand_5d(paged_kv_data, self._kv_layout) - return self._wrapper.forward( - q, - self._qo_indptr, - paged_kv_data, - self._paged_kv_indptr, - self._paged_kv_indices, - self._paged_kv_last_page_len, - causal, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - sm_scale, - rope_scale, - rope_theta, - False, - )[0] + if self._custom_mask is None: + return self._wrapper.forward( + q, + self._qo_indptr, + paged_kv_data, + self._paged_kv_indptr, + self._paged_kv_indices, + self._paged_kv_last_page_len, + causal, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + False, + )[0] + else: + return self._wrapper.forward_custom_mask( + q, + self._qo_indptr, + paged_kv_data, + self._paged_kv_indptr, + self._paged_kv_indices, + self._paged_kv_last_page_len, + self._custom_mask, + self._qk_indptr, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + False, + )[0] def forward_return_lse( self, @@ -651,21 +674,39 @@ def forward_return_lse( paged_kv_data = paged_kv_data.to(torch.float16) paged_kv_data = expand_5d(paged_kv_data, self._kv_layout) - return self._wrapper.forward( - q, - self._qo_indptr, - paged_kv_data, - self._paged_kv_indptr, - self._paged_kv_indices, - self._paged_kv_last_page_len, - causal, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - sm_scale, - rope_scale, - rope_theta, - True, - ) + if self._custom_mask is None: + return self._wrapper.forward( + q, + self._qo_indptr, + paged_kv_data, + self._paged_kv_indptr, + self._paged_kv_indices, + self._paged_kv_last_page_len, + causal, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + True, + ) + else: + return self._wrapper.forward( + q, + self._qo_indptr, + paged_kv_data, + self._paged_kv_indptr, + self._paged_kv_indices, + self._paged_kv_last_page_len, + self._custom_mask, + self._qk_indptr, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + True, + ) def _compute_qk_indptr(qo_indptr: torch.Tensor, kv_indptr: torch.Tensor): @@ -815,8 +856,8 @@ def begin_forward( self._qo_indptr = qo_indptr self._kv_indptr = kv_indptr if custom_mask is not None: - self._qk_indptr = _compute_qk_indptr(qo_indptr, kv_indptr) self._custom_mask = custom_mask + self._qk_indptr = _compute_qk_indptr(qo_indptr, kv_indptr) self._wrapper.begin_forward( self._workspace_buffer, qo_indptr, @@ -911,7 +952,7 @@ def forward( False, )[0] else: - return self._wrapper.forward_custom_custom_mask( + return self._wrapper.forward_custom_mask( q, self._qo_indptr, k, @@ -1006,7 +1047,7 @@ def forward_return_lse( True, ) else: - return self._wrapper.forward_custom_custom_mask( + return self._wrapper.forward_custom_mask( q, self._qo_indptr, k, diff --git a/python/tests/test_batch_prefill_kernels.py 
b/python/tests/test_batch_prefill_kernels.py index 112cbea4..a7b0754e 100644 --- a/python/tests/test_batch_prefill_kernels.py +++ b/python/tests/test_batch_prefill_kernels.py @@ -72,6 +72,7 @@ def test_batch_prefill_with_paged_kv_cache( num_qo_heads, num_kv_heads, head_dim, + page_size, ) o = wrapper.forward(q, kv_data, causal=causal, pos_encoding_mode=pos_encoding_mode) @@ -117,6 +118,90 @@ def test_batch_prefill_with_paged_kv_cache( numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) +@pytest.mark.parametrize("batch_size", [12, 17]) +@pytest.mark.parametrize("kv_len", [54, 97]) +@pytest.mark.parametrize("qo_len", [37, 17]) +@pytest.mark.parametrize("page_size", [1, 16]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) +@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) +def test_batch_prefill_with_paged_kv_cache_custom_mask( + batch_size, + kv_len, + qo_len, + page_size, + num_kv_heads, + num_qo_heads, + head_dim, + kv_layout, + pos_encoding_mode, +): + q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() + q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_data = ( + torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0).half() + if kv_layout == "HND" + else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim) + .to(0) + .half() + ) + kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * num_pages_per_seq + kv_indices = torch.arange(0, total_num_pages).to(0).int() + kv_last_page_len = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + ).to(0) + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout + ) + custom_mask = ( + torch.triu( + torch.full((batch_size, qo_len, kv_len), -5e4, dtype=torch.float32), + diagonal=(kv_len - qo_len + 1), + ) + .reshape(-1) + .to(0) + ) + + # use custom mask + wrapper.begin_forward( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + custom_mask, + ) + o_custom = wrapper.forward(q, kv_data, pos_encoding_mode=pos_encoding_mode) + wrapper.end_forward() + + # use causal + wrapper.begin_forward( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + page_size, + head_dim, + ) + o_causal = wrapper.forward( + q, kv_data, causal=True, pos_encoding_mode=pos_encoding_mode + ) + numpy.testing.assert_allclose( + o_custom.cpu().numpy(), o_causal.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + + @pytest.mark.parametrize("batch_size", [12, 17]) @pytest.mark.parametrize("kv_len", [54, 97]) @pytest.mark.parametrize("qo_len", [37, 17]) @@ -169,6 +254,61 @@ def test_batch_prefill_with_ragged_kv_cache( numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) +@pytest.mark.parametrize("batch_size", [12, 17]) +@pytest.mark.parametrize("kv_len", [54, 97]) +@pytest.mark.parametrize("qo_len", [37, 17]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) +def 
test_batch_prefill_with_ragged_kv_cache_custom_mask( + batch_size, + kv_len, + qo_len, + num_kv_heads, + num_qo_heads, + head_dim, + pos_encoding_mode, +): + kv_layout = "NHD" + q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() + q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len + + k = torch.randn(batch_size * kv_len, num_kv_heads, head_dim).to(0).half() + v = torch.randn(batch_size * kv_len, num_kv_heads, head_dim).to(0).half() + kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) + wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, kv_layout + ) + + custom_mask = ( + torch.triu( + torch.full((batch_size, qo_len, kv_len), -5e4, dtype=torch.float32), + diagonal=(kv_len - qo_len + 1), + ) + .reshape(-1) + .to(0) + ) + + # use custom mask + wrapper.begin_forward( + q_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim, custom_mask + ) + o_custom = wrapper.forward(q, k, v, pos_encoding_mode=pos_encoding_mode) + wrapper.end_forward() + + # use causal + wrapper.begin_forward(q_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim) + o_causal = wrapper.forward( + q, k, v, causal=True, pos_encoding_mode=pos_encoding_mode + ) + numpy.testing.assert_allclose( + o_custom.cpu().numpy(), o_causal.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + + if __name__ == "__main__": test_batch_prefill_with_paged_kv_cache( 12, 54, 37, 8, 8, 8, 128, True, "HND", "NONE" @@ -176,4 +316,8 @@ def test_batch_prefill_with_ragged_kv_cache( test_batch_prefill_with_paged_kv_cache( 12, 54, 37, 1, 8, 8, 128, True, "HND", "NONE" ) + test_batch_prefill_with_paged_kv_cache_custom_mask( + 12, 137, 137, 1, 8, 8, 128, "HND", "NONE" + ) test_batch_prefill_with_ragged_kv_cache(12, 54, 37, 8, 8, 128, True, "NONE") + test_batch_prefill_with_ragged_kv_cache_custom_mask(12, 137, 137, 8, 8, 128, "NONE") From f21259680a26003b9d9eb04dc9ff98c256c4d328 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 28 May 2024 06:21:57 +0000 Subject: [PATCH 13/14] slight bugfix --- python/flashinfer/cascade.py | 5 +++++ python/flashinfer/prefill.py | 7 ++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py index 52dd3161..d0504d95 100644 --- a/python/flashinfer/cascade.py +++ b/python/flashinfer/cascade.py @@ -583,6 +583,7 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper: ... num_qo_heads, ... num_kv_heads, ... head_dim, + ... page_size, ... ) >>> outputs = [] >>> for i in range(num_layers): @@ -646,6 +647,7 @@ def begin_forward( num_qo_heads: int, num_kv_heads: int, head_dim: int, + page_size: int, ): r"""Create auxiliary data structures for shared-prefix batch prefill/append attention for multiple forward calls within the same prefill/append step. @@ -667,6 +669,8 @@ def begin_forward( The number of key/value heads. head_dim : int The dimension of the heads. + page_size : int + The page size of the paged kv-cache. Notes ----- @@ -687,6 +691,7 @@ def begin_forward( num_qo_heads, num_kv_heads, head_dim, + page_size, ) def end_forward(self): diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index ff7eb2aa..c05f4f53 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -375,7 +375,8 @@ class BatchPrefillWithPagedKVCacheWrapper: ... paged_kv_last_page_len, ... num_qo_heads, ... num_kv_heads, - ... head_dim + ... head_dim, + ... page_size, ... 
) >>> outputs = [] >>> for i in range(num_layers): @@ -469,6 +470,8 @@ def begin_forward( The number of key/value heads. head_dim : int The dimension of the heads. + page_size : int + The size of each page in the paged kv-cache. custom_mask : Optional[torch.Tensor] The flattened mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size))``. The mask tensor will be applied to the attention matrix before softmax if provided. @@ -540,6 +543,8 @@ def forward( if :attr:`kv_layout` is ``HND``. causal : bool Whether to apply causal mask to the attention matrix. + This is only effective when :attr:`custom_mask` is not provided in + :meth:`begin_forward`. pos_encoding_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. From bd4e79cfc24d50a819ee8b298060895c9eafb3ef Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 28 May 2024 06:24:19 +0000 Subject: [PATCH 14/14] trailing empty line --- include/flashinfer/attention/mask.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/flashinfer/attention/mask.cuh b/include/flashinfer/attention/mask.cuh index 9adc6c26..771c2f47 100644 --- a/include/flashinfer/attention/mask.cuh +++ b/include/flashinfer/attention/mask.cuh @@ -26,4 +26,4 @@ enum class MaskMode { } // namespace flashinfer -#endif // FLASHINFER_ATTENTION_MASK_CUH_ \ No newline at end of file +#endif // FLASHINFER_ATTENTION_MASK_CUH_
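For reference, a minimal usage sketch of the custom-mask prefill path introduced by these patches. It is an illustration only, not part of the series: the `single_prefill_with_kv_cache` entry point and its `custom_mask` argument are taken from the `python/flashinfer/prefill.py` diff above, and the mask convention (a float tensor of shape `(qo_len, kv_len)` combined with the pre-softmax logits, with large negative values suppressing entries) mirrors the new unit tests; exact signatures and defaults may differ in a real build.

# Sketch: custom-mask single prefill, assuming a CUDA device and a flashinfer
# build that contains the MaskMode::kCustom changes from this patch series.
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim = 32, 4, 128
qo_len, kv_len = 17, 97

q = torch.randn(qo_len, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")

# (qo_len, kv_len) float mask: 0 keeps a position, -5e4 effectively masks it out.
# This particular mask reproduces the causal pattern, so the two outputs below
# should agree, which is exactly what the new tests in this series check.
custom_mask = torch.triu(
    torch.full((qo_len, kv_len), -5e4, dtype=torch.float32, device="cuda"),
    diagonal=(kv_len - qo_len + 1),
)

o_custom = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=custom_mask)
o_causal = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True)
torch.testing.assert_close(o_custom, o_causal, rtol=1e-3, atol=1e-3)

The batch wrappers follow the same pattern per the diffs above: a flattened mask is passed to `begin_forward`, the wrapper derives `qk_indptr`, and `forward()` is routed through `forward_custom_mask` when a mask was registered, as exercised by the new `test_batch_prefill_*_custom_mask` tests.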