diff --git a/CMakeLists.txt b/CMakeLists.txt index 38fcd7b7..fd50b4da 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -81,7 +81,7 @@ set (HEAD_DIMS ${FLASHINFER_GEN_HEAD_DIMS}) set (KV_LAYOUTS ${FLASHINFER_GEN_KV_LAYOUTS}) set (POS_ENCODING_MODES ${FLASHINFER_GEN_POS_ENCODING_MODES}) set (ALLOW_FP16_QK_REDUCTIONS ${FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS}) -set (CAUSALS ${FLASHINFER_GEN_CASUALS}) +set (MASK_MODES ${FLASHINFER_GEN_MASK_MODES}) set (DECODE_DTYPES "f16") set (PREFILL_DTYPES "f16") set (DECODE_F8_DTYPES) @@ -104,14 +104,14 @@ message(STATUS "FLASHINFER_HEAD_DIMS=${HEAD_DIMS}") message(STATUS "FLASHINFER_KV_LAYOUTS=${KV_LAYOUTS}") message(STATUS "FLASHINFER_POS_ENCODING_MODES=${POS_ENCODING_MODES}") message(STATUS "FLASHINFER_ALLOW_FP16_QK_REDUCTIONS=${ALLOW_FP16_QK_REDUCTIONS}") -message(STATUS "FLASHINFER_CAUSALS=${CAUSALS}") +message(STATUS "FLASHINFER_MASK_MODES=${MASK_MODES}") file(MAKE_DIRECTORY ${PROJECT_SOURCE_DIR}/src/generated) set(dispatch_inc_file ${PROJECT_SOURCE_DIR}/src/dispatch.inc) add_custom_command( OUTPUT ${dispatch_inc_file} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --page_sizes ${FLASHINFER_GEN_PAGE_SIZES} --group_sizes ${GROUP_SIZES} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --causals ${CAUSALS} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --page_sizes ${FLASHINFER_GEN_PAGE_SIZES} --group_sizes ${GROUP_SIZES} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES} DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py COMMENT "Generating additional source file ${generated_dispatch_inc}" VERBATIM @@ -225,9 +225,9 @@ foreach(group_size IN LISTS GROUP_SIZES) foreach(kv_layout IN LISTS KV_LAYOUTS) foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) - foreach(causal IN LISTS CAUSALS) + foreach(mask_mode IN LISTS MASK_MODES) foreach(dtype IN LISTS PREFILL_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_causal_${causal}_dtypein_${dtype}_dtypeout_${dtype}.cu) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src} @@ -237,7 +237,7 @@ foreach(group_size IN LISTS GROUP_SIZES) ) list(APPEND single_prefill_kernels_src ${generated_kernel_src}) endforeach(dtype) - endforeach(causal) + endforeach(mask_mode) endforeach(allow_fp16_qk_reduction) endforeach(pos_encoding_mode) endforeach(kv_layout) @@ -251,10 +251,10 @@ foreach(group_size IN LISTS GROUP_SIZES) foreach(kv_layout IN LISTS KV_LAYOUTS) foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) - foreach(causal IN LISTS CAUSALS) + foreach(mask_mode IN LISTS 
MASK_MODES) foreach(dtype IN LISTS PREFILL_DTYPES) foreach(idtype IN LISTS IDTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_group_${group_size}_page_${page_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_causal_${causal}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_group_${group_size}_page_${page_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src} @@ -265,7 +265,7 @@ foreach(group_size IN LISTS GROUP_SIZES) list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src}) endforeach(idtype) endforeach(dtype) - endforeach(causal) + endforeach(mask_mode) endforeach(allow_fp16_qk_reduction) endforeach(pos_encoding_mode) endforeach(kv_layout) @@ -279,10 +279,10 @@ foreach(group_size IN LISTS GROUP_SIZES) foreach(kv_layout IN LISTS KV_LAYOUTS) foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) - foreach(causal IN LISTS CAUSALS) + foreach(mask_mode IN LISTS MASK_MODES) foreach(dtype IN LISTS PREFILL_DTYPES) foreach(idtype IN LISTS IDTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_causal_${causal}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src} @@ -293,7 +293,7 @@ foreach(group_size IN LISTS GROUP_SIZES) list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src}) endforeach(idtype) endforeach(dtype) - endforeach(causal) + endforeach(mask_mode) endforeach(allow_fp16_qk_reduction) endforeach(pos_encoding_mode) endforeach(kv_layout) diff --git a/cmake/config.cmake b/cmake/config.cmake index 10a1b843..c2fd48e9 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -24,7 +24,7 @@ set(FLASHINFER_GEN_HEAD_DIMS 64 128 256) set(FLASHINFER_GEN_KV_LAYOUTS 0 1) set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2) set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false" "true") -set(FLASHINFER_GEN_CASUALS "false" "true") +set(FLASHINFER_GEN_MASK_MODES 0 1) # Set target cuda architectures for tests/benchmarks, defaults to native. # "native" is a special value for CMAKE_CUDA_ARCHITECTURES which means use the architectures of the host's GPU. diff --git a/include/flashinfer/attention/mask.cuh b/include/flashinfer/attention/mask.cuh new file mode 100644 index 00000000..771c2f47 --- /dev/null +++ b/include/flashinfer/attention/mask.cuh @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_ATTENTION_MASK_CUH_ +#define FLASHINFER_ATTENTION_MASK_CUH_ + +namespace flashinfer { + +enum class MaskMode { + kNone = 0U, // No mask + kCausal = 1U, // Causal mask + kCustom = 2U, // Custom mask +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_MASK_CUH_ diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 90edae5e..fb0a0d48 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -36,6 +36,7 @@ #include "../utils.cuh" #include "cascade.cuh" #include "handler.cuh" +#include "mask.cuh" #include "state.cuh" namespace flashinfer { @@ -550,11 +551,11 @@ __device__ __forceinline__ void apply_alibi_bias(const uint32_t qo_idx_base, } } -template __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, const uint32_t kv_len, - const uint32_t chunk_end, + const uint32_t chunk_end, float* custom_mask, DTypeQKAccum (*s_frag)[num_frags_z][8]) { const uint32_t tx = threadIdx.x; #pragma unroll @@ -568,9 +569,15 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, const uint32_ kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + 8 * (reg_id / 4) + reg_id % 2; const bool out_of_boundary = - (causal ? (kv_idx > kv_len + q_idx - qo_len || (partition_kv && kv_idx >= chunk_end)) - : kv_idx >= chunk_end); - s_frag[fx][fz][reg_id] = out_of_boundary ? DTypeQKAccum(-5e4) : s_frag[fx][fz][reg_id]; + (mask_mode == MaskMode::kCausal + ? (kv_idx > kv_len + q_idx - qo_len || (partition_kv && kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + s_frag[fx][fz][reg_id] = + out_of_boundary + ? DTypeQKAccum(-5e4) + : s_frag[fx][fz][reg_id] + DTypeQKAccum(mask_mode == MaskMode::kCustom + ? custom_mask[q_idx * kv_len + kv_idx] + : 0.f); } } } @@ -870,7 +877,7 @@ __device__ __forceinline__ void write_o_reg_gmem(float (*o_frag)[num_frags_y][8] * \brief FlashAttention prefill CUDA kernel for a single request. * \tparam partition_kv Whether to split kv_len into chunks. * \tparam group_size The number of qo heads that maps to a kv head (used in GQA). - * \tparam causal Whether to use causal attention. + * \tparam mask_mode The mask mode used in the attention operation. * \tparam kv_layout The layout of the input tensor. * \tparam pos_encoding_mode The positional encoding mode. * \tparam num_frags_x The number of fragments in x dimension. @@ -892,15 +899,15 @@ __device__ __forceinline__ void write_o_reg_gmem(float (*o_frag)[num_frags_y][8] * \param log2_rope_rcp_theta log2(1/(rope_theta)), where rope_theta is the theta * used in RoPE. 
*/ -template __global__ void SinglePrefillWithKVCacheKernel( DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, - DTypeOut* __restrict__ o, void* __restrict__ tmp, float* __restrict__ lse, - const tensor_info_t qkv_info, float sm_scale, - const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { + float* __restrict__ custom_mask, DTypeOut* __restrict__ o, void* __restrict__ tmp, + float* __restrict__ lse, const tensor_info_t qkv_info, + float sm_scale, const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { static_assert(sizeof(DTypeIn) == 2); static_assert(sizeof(DTypeOut) == 2); sm_scale *= math::log2e; @@ -983,19 +990,21 @@ __global__ void SinglePrefillWithKVCacheKernel( v_smem(smem + (num_warps * num_frags_x + num_frags_z) * 16 * head_dim * sizeof(DTypeIn)); const uint32_t num_iterations = ceil_div( - causal ? min(chunk_end - chunk_start, - sub_if_greater_or_zero( - kv_len - qo_len + ((bx + 1) * num_frags_x * num_warps * 16) / group_size, - chunk_start)) - : chunk_end - chunk_start, + mask_mode == MaskMode::kCausal + ? min(chunk_end - chunk_start, + sub_if_greater_or_zero( + kv_len - qo_len + ((bx + 1) * num_frags_x * num_warps * 16) / group_size, + chunk_start)) + : chunk_end - chunk_start, 16 * num_frags_z); const uint32_t mask_iteration = - (causal ? min(chunk_end - chunk_start, - sub_if_greater_or_zero( - kv_len + (bx * num_warps * num_frags_x * 16) / group_size - qo_len, - chunk_start)) - : (chunk_end - chunk_start)) / + (mask_mode == MaskMode::kCausal + ? min(chunk_end - chunk_start, + sub_if_greater_or_zero( + kv_len + (bx * num_warps * num_frags_x * 16) / group_size - qo_len, + chunk_start)) + : (chunk_end - chunk_start)) / (16 * num_frags_z); DTypeIn* k_ptr = k + qkv_info.get_kv_elem_offset(chunk_start + ty * 4 + tx / 8, kv_head_idx, @@ -1035,9 +1044,16 @@ __global__ void SinglePrefillWithKVCacheKernel( alibi_slopes, s_frag); } // apply mask - if (iter >= mask_iteration) { - mask_s( - qo_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, s_frag); + if constexpr (mask_mode == MaskMode::kCustom) { + mask_s( + qo_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, + custom_mask, s_frag); + } else { + if (iter >= mask_iteration) { + mask_s(qo_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, + chunk_end, nullptr, s_frag); + } } // compute m,d states in online softmax @@ -1097,13 +1113,15 @@ __global__ void SinglePrefillWithKVCacheKernel( } } -template +template __global__ void BatchPrefillWithRaggedKVCacheKernel( DTypeIn* __restrict__ q, IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices, IdType* __restrict__ qo_indptr, DTypeIn* __restrict__ k, - DTypeIn* __restrict__ v, IdType* __restrict__ kv_indptr, IdType* __restrict__ q_offset, + DTypeIn* __restrict__ v, IdType* __restrict__ kv_indptr, float* __restrict__ custom_mask, + IdType* __restrict__ qk_indptr, IdType* __restrict__ q_offset, IdType* __restrict__ k_rope_pos_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp, float* __restrict__ lse, uint32_t batch_size, float sm_scale, float log2_rope_rcp_scale, float log2_rope_rcp_theta) { @@ -1193,14 +1211,15 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); } - const uint32_t num_iterations = ceil_div( - (causal ? 
min(kv_len, - kv_len - qo_len + ((tile_idx + 1) * num_frags_x * num_warps * 16) / group_size) - : kv_len), - 16 * num_frags_z); + const uint32_t num_iterations = + ceil_div((mask_mode == MaskMode::kCausal + ? min(kv_len, kv_len - qo_len + + ((tile_idx + 1) * num_frags_x * num_warps * 16) / group_size) + : kv_len), + 16 * num_frags_z); const uint32_t mask_iteration = - (causal + (mask_mode == MaskMode::kCausal ? min(kv_len + (tile_idx * num_warps * num_frags_x * 16) / group_size - qo_len, kv_len) : kv_len) / (16 * num_frags_z); @@ -1251,9 +1270,16 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( qo_idx_base, iter * 16 * num_frags_z, int(kv_len) - int(qo_len), alibi_slopes, s_frag); } // apply mask - if (iter >= mask_iteration) { - mask_s( - qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, s_frag); + if constexpr (mask_mode == MaskMode::kCustom) { + mask_s( + qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, + custom_mask + qk_indptr[request_idx], s_frag); + } else { + if (iter >= mask_iteration) { + mask_s(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, nullptr, + s_frag); + } } // compute m,d states in online softmax @@ -1304,16 +1330,16 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( } } -template +template __global__ void BatchPrefillWithPagedKVCacheKernel( IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices, DTypeIn* __restrict__ q, paged_kv_t paged_kv, - IdType* __restrict__ qo_indptr, IdType* __restrict__ q_offset, DTypeOut* __restrict__ o, - float* __restrict__ tmp, float* __restrict__ lse, float sm_scale, float log2_rope_rcp_scale, - float log2_rope_rcp_theta) { + IdType* __restrict__ qo_indptr, float* __restrict__ custom_mask, IdType* __restrict__ qk_indptr, + IdType* __restrict__ q_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp, + float* __restrict__ lse, float sm_scale, float log2_rope_rcp_scale, float log2_rope_rcp_theta) { constexpr uint32_t rows_per_warp = 16 / (2 * group_size) * 2; constexpr uint32_t aligned_group_size = 16 / rows_per_warp; static_assert(sizeof(DTypeIn) == 2); @@ -1420,14 +1446,14 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( cp_async::commit_group(); const uint32_t num_iterations = ceil_div( - (causal + (mask_mode == MaskMode::kCausal ? min(kv_len, kv_len - qo_len + ((tile_idx + 1) * num_frags_x * num_warps * 16) / aligned_group_size) : kv_len), 16 * num_frags_z); const uint32_t mask_iteration = - (causal + (mask_mode == MaskMode::kCausal ? min(kv_len + (tile_idx * num_warps * num_frags_x * 16) / aligned_group_size - qo_len, kv_len) : kv_len) / @@ -1456,9 +1482,16 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( qo_idx_base, iter * 16 * num_frags_z, int(kv_len) - int(qo_len), alibi_slopes, s_frag); } // apply mask - if (iter >= mask_iteration) { - mask_s(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, s_frag); + if constexpr (mask_mode == MaskMode::kCustom) { + mask_s(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, + custom_mask + qk_indptr[request_idx], s_frag); + } else { + if (iter >= mask_iteration) { + mask_s(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, nullptr, + s_frag); + } } // compute m,d states in online softmax @@ -1523,7 +1556,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( * \param qo_len The length of query and output. * \param kv_len The length of key and value. * \param head_dim The dimension of each head. - * \param causal Whether to use causal attention. 
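For reviewers, a minimal PyTorch sketch of the masking semantics that mask_s now implements for the three MaskMode values (split-KV chunk boundaries are ignored here, the string mask_mode stands in for the MaskMode enum, and all names are illustrative rather than part of the kernel API):

import torch

def mask_scores_reference(s, mask_mode, qo_len, kv_len, custom_mask=None):
    # s: [qo_len, kv_len] scaled attention scores (q @ k^T * sm_scale).
    q_idx = torch.arange(qo_len).unsqueeze(1)   # column vector of query positions
    kv_idx = torch.arange(kv_len).unsqueeze(0)  # row vector of key positions
    if mask_mode == "causal":
        # kCausal: same predicate as the kernel, positions with
        # kv_idx > kv_len + q_idx - qo_len are masked out.
        s = torch.where(kv_idx > kv_len + q_idx - qo_len, torch.full_like(s, -5e4), s)
    elif mask_mode == "custom":
        # kCustom: the flattened float mask (indexed as q_idx * kv_len + kv_idx in the
        # kernel) is added to the scores before softmax; 0 keeps a position, a large
        # negative value suppresses it.
        s = s + custom_mask.view(qo_len, kv_len)
    # kNone: only the out-of-bounds masking (not modeled here) is applied.
    return s
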
+ * \param mask_mode The mask mode applied in the attention score. * \param kv_layout The layout of KV Cache. * \param pos_encoding_mode The positional encoding mode. * \param allow_fp16_qk_reduction Whether to allow accumulating q*k^T with fp16. @@ -1533,13 +1566,13 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( template cudaError_t SinglePrefillWithKVCacheWorkEstimation( uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, bool causal = true, + uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, MaskMode mask_mode, QKVLayout kv_layout = QKVLayout::kNHD, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, bool allow_fp16_qk_reduction = false, cudaStream_t stream = nullptr) { - if (kv_len < qo_len && causal) { + if (kv_len < qo_len && mask_mode == MaskMode::kCausal) { std::ostringstream err_msg; - err_msg << "When causal is true, kv_len must be greater than or equal to qo_len, " + err_msg << "When setting mask_mode to kCausal, kv_len must be greater than or equal to qo_len, " << "got kv_len " << kv_len << " and qo_len " << qo_len; throw std::invalid_argument(err_msg.str()); } @@ -1551,8 +1584,8 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( (qo_len * group_size > 64 && head_dim < 256 ? 2 : 1), num_frags_x, {DISPATCH_GQA_GROUP_SIZE( group_size, GROUP_SIZE, - {DISPATCH_CAUSAL( - causal, CAUSAL, {DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + {DISPATCH_MASK_MODE( + mask_mode, MASK_MODE, {DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { constexpr uint32_t num_frags_y = HEAD_DIM / 16; DISPATCH_POS_ENCODING_MODE( pos_encoding_mode, pos_encoding_mode, @@ -1605,7 +1638,7 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; auto partition_kv_kernel = SinglePrefillWithKVCacheKernel< - /*partition_kv=*/true, GROUP_SIZE, CAUSAL, KV_LAYOUT, + /*partition_kv=*/true, GROUP_SIZE, MASK_MODE, KV_LAYOUT, pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z, num_warps, DTypeIn, DTypeQKAccum, DTypeOut>; tensor_info_t qkv_info( @@ -1657,18 +1690,19 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( } template -cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, - float* tmp, float* lse, uint32_t num_kv_heads, - uint32_t qo_len, uint32_t kv_len, float sm_scale, - float rope_scale, float rope_theta, - cudaStream_t stream) { +cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, + float* custom_mask, DTypeOut* o, float* tmp, + float* lse, uint32_t num_kv_heads, uint32_t qo_len, + uint32_t kv_len, float sm_scale, float rope_scale, + float rope_theta, cudaStream_t stream) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); - if (kv_len < qo_len && CAUSAL) { + if (kv_len < qo_len && MASK_MODE == MaskMode::kCausal) { std::ostringstream err_msg; - err_msg << "When causal is true, kv_len must be greater than or equal to qo_len, got kv_len" + err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be greater than or equal " + "to qo_len, got kv_len" << kv_len << " and qo_len " << qo_len; throw std::invalid_argument(err_msg.str()); } @@ -1713,7 +1747,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* constexpr uint32_t num_threads = num_warps * warp_size; constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; auto partition_kv_kernel = - 
SinglePrefillWithKVCacheKernel; tensor_info_t qkv_info(qo_len, kv_len, num_kv_heads); @@ -1741,11 +1775,12 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* if (num_chunks <= 1 || tmp == nullptr) { // Enough parallelism, do not split-kv auto kernel = SinglePrefillWithKVCacheKernel< - /*partition_kv=*/false, GROUP_SIZE, CAUSAL, KV_LAYOUT, pos_encoding_mode, num_frags_x, - num_frags_y, num_frags_z, num_warps, DTypeIn, DTypeQKAccum, DTypeOut>; + /*partition_kv=*/false, GROUP_SIZE, MASK_MODE, KV_LAYOUT, pos_encoding_mode, + num_frags_x, num_frags_y, num_frags_z, num_warps, DTypeIn, DTypeQKAccum, DTypeOut>; void* args[] = {(void*)&q, (void*)&k, (void*)&v, + (void*)&custom_mask, (void*)&o, (void*)&tmp, (void*)&lse, @@ -1764,6 +1799,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* void* args[] = {(void*)&q, (void*)&k, (void*)&v, + (void*)&custom_mask, (void*)&o, (void*)&tmp, (void*)&lse, @@ -1788,14 +1824,14 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* } template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, - DTypeIn* v, IdType* kv_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, - float* tmp, float* lse, const uint32_t batch_size, const uint32_t num_qo_tiles, - const uint32_t num_kv_heads, const float sm_scale, const float rope_scale, - const float rope_theta, cudaStream_t stream = nullptr) { + DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, + IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size, + const uint32_t num_qo_tiles, const uint32_t num_kv_heads, const float sm_scale, + const float rope_scale, const float rope_theta, cudaStream_t stream = nullptr) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); constexpr uint32_t num_warps = 4; @@ -1836,7 +1872,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( throw std::invalid_argument(err_msg.str()); } else { auto kernel = - BatchPrefillWithRaggedKVCacheKernel; uint32_t smem_size = @@ -1850,6 +1886,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( (void*)&k, (void*)&v, (void*)&kv_indptr, + (void*)&custom_mask, + (void*)&qk_indptr, (void*)&q_offset, (void*)&k_rope_pos_offset, (void*)&o, @@ -1867,13 +1905,13 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( template cudaError_t BatchPrefillWithPagedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* tmp, - float* lse, uint32_t num_qo_tiles, float sm_scale, float rope_scale, float rope_theta, - cudaStream_t stream) { + paged_kv_t paged_kv, float* custom_mask, + IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, float sm_scale, + float rope_scale, float rope_theta, cudaStream_t stream) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); constexpr uint32_t num_warps = 4; @@ -1917,8 +1955,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( throw std::invalid_argument(err_msg.str()); } else { auto kernel = BatchPrefillWithPagedKVCacheKernel< - GROUP_SIZE, PAGE_SIZE, CAUSAL, pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z, - num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>; + 
GROUP_SIZE, PAGE_SIZE, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y, + num_frags_z, num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>; uint32_t smem_size = (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); FLASHINFER_CUDA_CALL( @@ -1928,6 +1966,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( (void*)&q, (void*)&paged_kv, (void*)&qo_indptr, + (void*)&custom_mask, + (void*)&qk_indptr, (void*)&q_offset, (void*)&o, (void*)&tmp, diff --git a/include/flashinfer/prefill_attention_decl.cuh b/include/flashinfer/prefill_attention_decl.cuh index 3e41dc7d..a59777c0 100644 --- a/include/flashinfer/prefill_attention_decl.cuh +++ b/include/flashinfer/prefill_attention_decl.cuh @@ -21,6 +21,7 @@ #include #include "attention/handler.cuh" +#include "attention/mask.cuh" #include "layout.cuh" #include "page.cuh" #include "pos_enc.cuh" @@ -29,40 +30,42 @@ namespace flashinfer { template -cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, - float* tmp, float* lse, uint32_t num_kv_heads, - uint32_t qo_len, uint32_t kv_len, float sm_scale, - float rope_scale, float rope_theta, - cudaStream_t stream); +cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, + float* custom_mask, DTypeOut* o, float* tmp, + float* lse, uint32_t num_kv_heads, uint32_t qo_len, + uint32_t kv_len, float sm_scale, float rope_scale, + float rope_theta, cudaStream_t stream); template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, - DTypeIn* v, IdType* kv_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, - float* tmp, float* lse, uint32_t batch_size, uint32_t num_qo_tiles, uint32_t num_kv_heads, - float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream = nullptr); + DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, + IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, uint32_t batch_size, + uint32_t num_qo_tiles, uint32_t num_kv_heads, float sm_scale, float rope_scale, + float rope_theta, cudaStream_t stream = nullptr); template cudaError_t BatchPrefillWithPagedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* tmp, - float* lse, uint32_t num_qo_tiles, float sm_scale, float rope_scale, float rope_theta, - cudaStream_t stream); + paged_kv_t paged_kv, float* custom_mask, + IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, float sm_scale, + float rope_scale, float rope_theta, cudaStream_t stream); template + MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType> cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* lse, - float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { + paged_kv_t paged_kv, float* custom_mask, + IdType* qk_indptr, DTypeOut* o, float* lse, float sm_scale, float rope_scale, float rope_theta, + cudaStream_t stream) { float* tmp = nullptr; IdType* request_indices = nullptr; IdType* tile_indices = nullptr; @@ -83,21 +86,21 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, { return BatchPrefillWithPagedKVCacheDispatched< page_storage, kv_layout, 
NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, pos_encoding_mode, - ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( - q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, o, tmp, lse, num_qo_tiles, - sm_scale, rope_scale, rope_theta, stream); + ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( + q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, custom_mask, qk_indptr, o, + tmp, lse, num_qo_tiles, sm_scale, rope_scale, rope_theta, stream); }); return cudaSuccess; } template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, - IdType* kv_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* lse, - uint32_t batch_size, uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta, - cudaStream_t stream) { + IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, + IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t batch_size, uint32_t num_kv_heads, + float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { float* tmp = nullptr; IdType* request_indices = nullptr; IdType* tile_indices = nullptr; @@ -118,10 +121,10 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, { return BatchPrefillWithRaggedKVCacheDispatched( - q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, q_offset, k_rope_pos_offset, - o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale, rope_scale, rope_theta, - stream); + MASK_MODE, DTypeIn, DTypeOut, IdType>( + q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, custom_mask, qk_indptr, + q_offset, k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale, + rope_scale, rope_theta, stream); }); return cudaSuccess; } diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index f4511e56..358674a6 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -137,13 +137,28 @@ throw std::invalid_argument(err_msg.str()); \ } -#define DISPATCH_CAUSAL(causal, CAUSAL, ...) \ - if (causal) { \ - constexpr bool CAUSAL = true; \ - __VA_ARGS__ \ - } else { \ - constexpr bool CAUSAL = false; \ - __VA_ARGS__ \ +#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \ + switch (mask_mode) { \ + case MaskMode::kNone: { \ + constexpr MaskMode MASK_MODE = MaskMode::kNone; \ + __VA_ARGS__ \ + break; \ + } \ + case MaskMode::kCausal: { \ + constexpr MaskMode MASK_MODE = MaskMode::kCausal; \ + __VA_ARGS__ \ + break; \ + } \ + case MaskMode::kCustom: { \ + constexpr MaskMode MASK_MODE = MaskMode::kCustom; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported mask_mode: " << int(mask_mode); \ + throw std::invalid_argument(err_msg.str()); \ + } \ } #define DISPATCH_LAYOUT(layout, LAYOUT, ...) \ diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 47a0df17..c6ed72af 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -30,8 +30,7 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward( CHECK_DIM(1, qo_indptr); CHECK_DIM(1, workspace_buffer); - // TODO(Zihao): support dispatching to different index data types. 
- CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32); + qo_indptr = qo_indptr.to(torch::kInt32); size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); handler_->SetCUDAStream(torch_current_stream); @@ -89,11 +88,10 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( CHECK_EQ(paged_kv_last_page_len.size(0), batch_size); CHECK_EQ(paged_kv_data.size(1), 2); CHECK_EQ(paged_kv_data.size(4), head_dim); - // TODO(Zihao): support dispatching to different index data types. - CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32); - CHECK_EQ(paged_kv_indptr.scalar_type(), torch::kInt32); - CHECK_EQ(paged_kv_indices.scalar_type(), torch::kInt32); - CHECK_EQ(paged_kv_last_page_len.scalar_type(), torch::kInt32); + qo_indptr = qo_indptr.to(torch::kInt32); + paged_kv_indptr = paged_kv_indptr.to(torch::kInt32); + paged_kv_indices = paged_kv_indices.to(torch::kInt32); + paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); torch::Tensor o = torch::empty_like(q, q.options()); @@ -101,6 +99,7 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( if (return_lse) { lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); } + MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { @@ -112,7 +111,7 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( static_cast(paged_kv_last_page_len.data_ptr())); return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_causal(causal, CAUSAL, [&] { + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { return DISPATCH_allow_fp16_qk_reduction( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { return DISPATCH_pos_encoding_mode( @@ -120,11 +119,13 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( return DISPATCH_page_size(page_size, PAGE_SIZE, [&] { cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< PageStorage::kIndices, KV_LAYOUT, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, - POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, c_type, c_type, + POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( handler_.get(), static_cast(q.data_ptr()), static_cast(qo_indptr.data_ptr()), - /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), + /*q_offset=*/nullptr, paged_kv, + /*custom_mask=*/nullptr, + /*qk_indptr=*/nullptr, static_cast(o.data_ptr()), /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, sm_scale, rope_scale, rope_theta, /*stream=*/torch_current_stream); @@ -148,6 +149,111 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( } } +std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask( + torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, + torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, torch::Tensor custom_mask, torch::Tensor qk_indptr, + unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, + float rope_theta, bool return_lse) { + CHECK_INPUT(q); + CHECK_INPUT(qo_indptr); + CHECK_INPUT(paged_kv_data); + CHECK_INPUT(paged_kv_indptr); + CHECK_INPUT(paged_kv_indices); + CHECK_INPUT(paged_kv_last_page_len); + CHECK_INPUT(custom_mask); + CHECK_INPUT(qk_indptr); + CHECK_DIM(3, q); // (nnz_qo, H_qo, D) + CHECK_DIM(1, qo_indptr); // (B + 1,) + // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND + // [max_num_pages, 2, page_size, num_kv_heads, head_dim] for HND + CHECK_DIM(5, paged_kv_data); + CHECK_DIM(1, paged_kv_indptr); // (B + 1,) + CHECK_DIM(1, paged_kv_indices); // (nnz_kv,) + CHECK_DIM(1, paged_kv_last_page_len); // (B,) + CHECK_DIM(1, custom_mask); // (nnz_qk,) + CHECK_DIM(1, qk_indptr); // (B + 1,) + int64_t batch_size = qo_indptr.size(0) - 1; + int64_t nnz_qo = q.size(0); + int64_t num_qo_heads = q.size(1); + int64_t head_dim = q.size(2); + int64_t num_kv_heads, page_size; + if (kv_layout_ == QKVLayout::kHND) { + num_kv_heads = paged_kv_data.size(2); + page_size = paged_kv_data.size(3); + } else { + page_size = paged_kv_data.size(2); + num_kv_heads = paged_kv_data.size(3); + } + CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); + CHECK_EQ(qo_indptr.size(0), batch_size + 1); + CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1); + CHECK_EQ(paged_kv_last_page_len.size(0), batch_size); + CHECK_EQ(paged_kv_data.size(1), 2); + CHECK_EQ(paged_kv_data.size(4), head_dim); + CHECK_EQ(qk_indptr.size(0), batch_size + 1); + qo_indptr = qo_indptr.to(torch::kInt32); + paged_kv_indptr = paged_kv_indptr.to(torch::kInt32); + paged_kv_indices = paged_kv_indices.to(torch::kInt32); + paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32); + custom_mask = custom_mask.to(torch::kFloat32); + qk_indptr = qk_indptr.to(torch::kInt32); + + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); + torch::Tensor o = torch::empty_like(q, q.options()); + torch::Tensor lse = torch::empty({0}); + if (return_lse) { + lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); + } + constexpr MaskMode MASK_MODE = MaskMode::kCustom; + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, + static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_page_size(page_size, PAGE_SIZE, [&] { + cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< + 
PageStorage::kIndices, KV_LAYOUT, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, + POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, + int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + /*q_offset=*/nullptr, paged_kv, + static_cast(custom_mask.data_ptr()), + static_cast(qk_indptr.data_ptr()), + static_cast(o.data_ptr()), + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithPagedKVCache failed with error code ", + cudaGetErrorString(status)); + return true; + }); + }); + }); + }); + }); + }); + }); + + if (return_lse) { + return {o, lse}; + } else { + return {o}; + } +} + void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward( torch::Tensor workspace_buffer, torch::Tensor qo_indptr, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim) { @@ -158,8 +264,7 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward( CHECK_DIM(1, qo_indptr); CHECK_DIM(1, workspace_buffer); - // TODO(Zihao): support dispatching to different index data types. - CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32); + qo_indptr = qo_indptr.to(torch::kInt32); size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); handler_->SetCUDAStream(torch_current_stream); @@ -205,9 +310,8 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( CHECK_EQ(k.size(2), v.size(2)); CHECK_EQ(k.size(2), head_dim); CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - // TODO(Zihao): support dispatching to different index data types. - CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32); - CHECK_EQ(kv_indptr.scalar_type(), torch::kInt32); + qo_indptr = qo_indptr.to(torch::kInt32); + kv_indptr = kv_indptr.to(torch::kInt32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); torch::Tensor o = torch::empty_like(q, q.options()); @@ -216,10 +320,12 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); } + MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_causal(causal, CAUSAL, [&] { + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { return DISPATCH_allow_fp16_qk_reduction( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { return DISPATCH_pos_encoding_mode( @@ -227,11 +333,12 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, CAUSAL, c_type, c_type, int32_t>( + ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( handler_.get(), static_cast(q.data_ptr()), static_cast(qo_indptr.data_ptr()), static_cast(k.data_ptr()), static_cast(v.data_ptr()), static_cast(kv_indptr.data_ptr()), + /*custom_mask=*/nullptr, /*qk_indptr=*/nullptr, /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, static_cast(o.data_ptr()), /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, @@ -255,3 +362,87 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( return {o}; } } + +std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask( + torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, + torch::Tensor kv_indptr, torch::Tensor custom_mask, torch::Tensor qk_indptr, + unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, + float rope_theta, bool return_lse) { + CHECK_INPUT(q); + CHECK_INPUT(qo_indptr); + CHECK_INPUT(k); + CHECK_INPUT(v); + CHECK_INPUT(kv_indptr); + CHECK_INPUT(custom_mask); + CHECK_INPUT(qk_indptr); + CHECK_DIM(3, q); // (nnz_qo, H_qo, D) + CHECK_DIM(1, qo_indptr); // (B + 1,) + CHECK_DIM(3, k); // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D) + CHECK_DIM(3, v); // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D) + CHECK_DIM(1, kv_indptr); // (B + 1,) + CHECK_DIM(1, custom_mask); // (nnz_qk,) + CHECK_DIM(1, qk_indptr); // (B + 1,) + int64_t batch_size = qo_indptr.size(0) - 1; + int64_t nnz_qo = q.size(0); + int64_t num_qo_heads = q.size(1); + int64_t head_dim = q.size(2); + CHECK_EQ(kv_indptr.size(0), batch_size + 1); + CHECK_EQ(qk_indptr.size(0), batch_size + 1); + int64_t num_kv_heads = (kv_layout_ == QKVLayout::kNHD) ? k.size(1) : k.size(0); + CHECK_EQ(k.size(0), v.size(0)); + CHECK_EQ(k.size(1), v.size(1)); + CHECK_EQ(k.size(2), v.size(2)); + CHECK_EQ(k.size(2), head_dim); + CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); + qo_indptr = qo_indptr.to(torch::kInt32); + kv_indptr = kv_indptr.to(torch::kInt32); + qk_indptr = qk_indptr.to(torch::kInt32); + custom_mask = custom_mask.to(torch::kFloat32); + + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); + torch::Tensor o = torch::empty_like(q, q.options()); + torch::Tensor lse = torch::empty({0}); + if (return_lse) { + lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); + } + + constexpr MaskMode MASK_MODE = MaskMode::kCustom; + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< + GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + static_cast(k.data_ptr()), static_cast(v.data_ptr()), + static_cast(kv_indptr.data_ptr()), + static_cast(custom_mask.data_ptr()), + static_cast(qk_indptr.data_ptr()), + /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, + static_cast(o.data_ptr()), + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithRaggedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); + }); + }); + }); + }); + }); + + if (return_lse) { + return {o, lse}; + } else { + return {o}; + } +} diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu index 05118bd1..2653d913 100644 --- a/python/csrc/flashinfer_ops.cu +++ b/python/csrc/flashinfer_ops.cu @@ -22,6 +22,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Single-request decode with KV-Cache operator"); m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache, "Single-request prefill with KV-Cache operator, return logsumexp"); + m.def( + "single_prefill_with_kv_cache_custom_mask", &single_prefill_with_kv_cache_custom_mask, + "Single-request prefill with KV-Cache operator, user defined custom mask, return logsumexp"); m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); m.def("merge_state", &merge_state, "Merge two self-attention states"); m.def("merge_state_in_place", &merge_state_in_place, @@ -62,7 +65,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward) .def("update_page_locked_buffer_size", &BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) - .def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward); + .def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward) + .def("forward_custom_mask", &BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask); py::class_( m, "BatchPrefillWithRaggedKVCachePyTorchWrapper") .def(py::init()) @@ -70,5 +74,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("end_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward) .def("update_page_locked_buffer_size", &BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) - .def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward); + .def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward) + .def("forward_custom_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask); } diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index d826d71f..1dff6ba3 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -35,6 +35,11 @@ std::vector single_prefill_with_kv_cache( unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); +std::vector single_prefill_with_kv_cache_custom_mask( + torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor custom_mask, torch::Tensor tmp, + unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, + float sm_scale, float rope_scale, float rope_theta, bool return_lse); + void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, torch::Tensor append_indptr, torch::Tensor kv_data, torch::Tensor kv_indices, torch::Tensor kv_indptr, @@ -121,6 +126,12 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper { unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); + std::vector ForwardCustomMask( + torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, + torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + 
torch::Tensor paged_kv_last_page_len, torch::Tensor custom_mask, torch::Tensor qk_indptr, + unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, + float rope_scale, float rope_theta, bool return_lse); BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout, unsigned int max_workspace_size_in_bytes) : kv_layout_(flashinfer::QKVLayout(layout)), @@ -143,6 +154,13 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper { unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); + std::vector ForwardCustomMask(torch::Tensor q, torch::Tensor qo_indptr, + torch::Tensor k, torch::Tensor v, + torch::Tensor kv_indptr, torch::Tensor custom_mask, + torch::Tensor qk_indptr, + unsigned int pos_encoding_mode, + bool allow_fp16_qk_reduction, float sm_scale, + float rope_scale, float rope_theta, bool return_lse); BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout, unsigned int max_workspace_size_in_bytes) : kv_layout_(flashinfer::QKVLayout(layout)), diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index 21c93899..8d5a6952 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -134,8 +134,8 @@ using namespace flashinfer; _DISPATCH_SWITCH("allow_fp16_qk_reduction", expr, \ _DISPATCH_CASES_allow_fp16_qk_reduction(const_expr, __VA_ARGS__)) -#define DISPATCH_causal(expr, const_expr, ...) \ - _DISPATCH_SWITCH("causal", expr, _DISPATCH_CASES_causal(const_expr, __VA_ARGS__)) +#define DISPATCH_mask_mode(expr, const_expr, ...) \ + _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) inline void check_shape(const torch::Tensor& a, const torch::Tensor& b, const char* a_name, const char* b_name) { diff --git a/python/csrc/single_prefill.cu b/python/csrc/single_prefill.cu index e5212170..162713a7 100644 --- a/python/csrc/single_prefill.cu +++ b/python/csrc/single_prefill.cu @@ -54,24 +54,29 @@ std::vector single_prefill_with_kv_cache( lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32)); } + const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_causal(causal, CAUSAL, [&] { + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { return DISPATCH_allow_fp16_qk_reduction( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { return DISPATCH_pos_encoding_mode( PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = SinglePrefillWithKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, CAUSAL>( - static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, rope_theta, - torch_current_stream); + cudaError_t status = + SinglePrefillWithKVCacheDispatched( + static_cast(q.data_ptr()), + static_cast(k.data_ptr()), + static_cast(v.data_ptr()), + /*custom_mask=*/nullptr, static_cast(o.data_ptr()), + static_cast(tmp.data_ptr()), + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, rope_theta, + torch_current_stream); TORCH_CHECK(status == cudaSuccess, "SinglePrefillWithKVCache kernel launch failed, error: " + std::string(cudaGetErrorString(status))); @@ -90,3 +95,79 @@ std::vector single_prefill_with_kv_cache( return {o}; } } + +std::vector single_prefill_with_kv_cache_custom_mask( + torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor custom_mask, torch::Tensor tmp, + unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, + float sm_scale, float rope_scale, float rope_theta, bool return_lse) { + CHECK_INPUT(q); + CHECK_INPUT(k); + CHECK_INPUT(v); + CHECK_INPUT(custom_mask); + CHECK_DIM(3, q); + CHECK_DIM(3, k); + CHECK_DIM(3, v); + CHECK_DIM(2, custom_mask); + CHECK_SHAPE(k, v); + CHECK_EQ(q.size(2), k.size(2)); + unsigned int head_dim = q.size(2); + unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads; + QKVLayout kv_layout = static_cast(layout); + if (kv_layout == QKVLayout::kNHD) { + kv_len = k.size(0); + qo_len = q.size(0); + num_kv_heads = k.size(1); + num_qo_heads = q.size(1); + } else { + kv_len = k.size(1); + qo_len = q.size(1); + num_kv_heads = k.size(0); + num_qo_heads = q.size(0); + } + CHECK_EQ(custom_mask.size(0), qo_len); + CHECK_EQ(custom_mask.size(1), kv_len); + CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); + auto o = torch::empty_like(q, q.options()); + torch::Tensor lse = torch::empty({0}); + if (return_lse) { + lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32)); + } + + constexpr MaskMode MASK_MODE = MaskMode::kCustom; + + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = SinglePrefillWithKVCacheDispatched< + GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE>( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), + static_cast(custom_mask.data_ptr()), + static_cast(o.data_ptr()), static_cast(tmp.data_ptr()), + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, rope_theta, + torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SinglePrefillWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + }); + }); + }); + }); + }); + + if (return_lse) { + return {o, lse}; + } else { + return {o}; + } +} diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py index 52dd3161..d0504d95 100644 --- a/python/flashinfer/cascade.py +++ b/python/flashinfer/cascade.py @@ -583,6 +583,7 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper: ... num_qo_heads, ... num_kv_heads, ... head_dim, + ... page_size, ... 
) >>> outputs = [] >>> for i in range(num_layers): @@ -646,6 +647,7 @@ def begin_forward( num_qo_heads: int, num_kv_heads: int, head_dim: int, + page_size: int, ): r"""Create auxiliary data structures for shared-prefix batch prefill/append attention for multiple forward calls within the same prefill/append step. @@ -667,6 +669,8 @@ def begin_forward( The number of key/value heads. head_dim : int The dimension of the heads. + page_size : int + The page size of the paged kv-cache. Notes ----- @@ -687,6 +691,7 @@ def begin_forward( num_qo_heads, num_kv_heads, head_dim, + page_size, ) def end_forward(self): diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index f024a1ac..c05f4f53 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -57,6 +57,7 @@ def single_prefill_with_kv_cache( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + custom_mask: Optional[torch.Tensor] = None, causal: bool = False, kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", @@ -80,8 +81,11 @@ def single_prefill_with_kv_cache( The key tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is ``HND``. + custom_mask : Optional[torch.Tensor] + The custom mask tensor, shape: ``[qo_len, kv_len]``. causal : bool Whether to apply causal mask to the attention matrix. + This is only effective when :attr:`custom_mask` is not provided. kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. pos_encoding_mode : str @@ -135,26 +139,43 @@ def single_prefill_with_kv_cache( rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 - return _kernels.single_prefill_with_kv_cache( - q, - k, - v, - tmp, - causal, - TensorLayout[kv_layout].value, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - sm_scale, - rope_scale, - rope_theta, - False, - )[0] + if custom_mask is not None: + return _kernels.single_prefill_with_kv_cache_custom_mask( + q, + k, + v, + tmp, + custom_mask, + TensorLayout[kv_layout].value, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + False, + )[0] + else: + return _kernels.single_prefill_with_kv_cache( + q, + k, + v, + tmp, + causal, + TensorLayout[kv_layout].value, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + False, + )[0] def single_prefill_with_kv_cache_return_lse( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + custom_mask: Optional[torch.Tensor] = None, causal: bool = False, kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", @@ -178,8 +199,11 @@ def single_prefill_with_kv_cache_return_lse( The key tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, or ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is ``HND``. + custom_mask : Optional[torch.Tensor] + The custom_mask tensor, shape: ``[qo_len, kv_len]``. causal : bool Whether to apply causal mask to the attention matrix. + This is only effective when :attr:`custom_mask` is not provided. kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. 
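A usage sketch of the new custom_mask argument of single_prefill_with_kv_cache added above (shapes, device, NHD layout, and the top-level flashinfer import are illustrative assumptions, not prescribed by this diff): an additive float mask of shape [qo_len, kv_len] with zeros on allowed positions and a large negative value elsewhere should reproduce causal attention.

import torch
import flashinfer

qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim = 128, 2048, 32, 8, 128
q = torch.randn(qo_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda")
k = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
v = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")

# Additive float mask: 0 keeps a position, a large negative value suppresses it.
# This particular mask mirrors the kernel's causal predicate.
mask = torch.full((qo_len, kv_len), -5e4, device="cuda")
mask = torch.triu(mask, diagonal=kv_len - qo_len + 1)  # zero at/below the causal diagonal

o_custom = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask)
o_causal = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True)
# o_custom and o_causal should agree up to numerical tolerance.
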
pos_encoding_mode : str @@ -249,20 +273,59 @@ def single_prefill_with_kv_cache_return_lse( q = q.to(torch.float16) k = k.to(torch.float16) v = v.to(torch.float16) - return _kernels.single_prefill_with_kv_cache( - q, - k, - v, - tmp, - causal, - TensorLayout[kv_layout].value, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - sm_scale, - rope_scale, - rope_theta, - True, + if custom_mask is not None: + return _kernels.single_prefill_with_kv_cache_custom_mask( + q, + k, + v, + tmp, + custom_mask, + TensorLayout[kv_layout].value, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + True, + ) + else: + return _kernels.single_prefill_with_kv_cache( + q, + k, + v, + tmp, + causal, + TensorLayout[kv_layout].value, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + True, + ) + + +def _compute_page_qk_indptr( + qo_indptr: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + page_size: int, +): + if len(qo_indptr) != len(paged_kv_indptr): + raise ValueError( + "The length of qo_indptr and paged_kv_indptr should be the same." + ) + qk_indptr = torch.empty_like(qo_indptr) + qk_indptr[0] = 0 + qk_indptr[1:] = torch.cumsum( + (qo_indptr[1:] - qo_indptr[:-1]) + * ( + (paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1) * page_size + + paged_kv_last_page_len + ), + 0, ) + return qk_indptr class BatchPrefillWithPagedKVCacheWrapper: @@ -312,7 +375,8 @@ class BatchPrefillWithPagedKVCacheWrapper: ... paged_kv_last_page_len, ... num_qo_heads, ... num_kv_heads, - ... head_dim + ... head_dim, + ... page_size, ... ) >>> outputs = [] >>> for i in range(num_layers): @@ -360,6 +424,8 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): self._paged_kv_indptr = None self._paged_kv_indices = None self._paged_kv_last_page_len = None + self._custom_mask = None + self._qk_indptr = None def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor): r"""Reset the workspace buffer. @@ -381,6 +447,8 @@ def begin_forward( num_qo_heads: int, num_kv_heads: int, head_dim: int, + page_size: int, + custom_mask: Optional[torch.Tensor] = None, ): r"""Create auxiliary data structures for batch prefill/append attention for multiple forward calls within the same prefill/append step. @@ -402,6 +470,11 @@ def begin_forward( The number of key/value heads. head_dim : int The dimension of the heads. + page_size : int + The size of each page in the paged kv-cache. + custom_mask : Optional[torch.Tensor] + The flattened mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size))``. + The mask tensor will be applied to the attention matrix before softmax if provided. Notes ----- @@ -418,6 +491,14 @@ def begin_forward( self._paged_kv_indptr = paged_kv_indptr self._paged_kv_indices = paged_kv_indices self._paged_kv_last_page_len = paged_kv_last_page_len + if custom_mask is not None: + self._custom_mask = custom_mask + self._qk_indptr = _compute_page_qk_indptr( + qo_indptr, + paged_kv_indptr, + paged_kv_last_page_len, + page_size, + ) self._wrapper.begin_forward( self._workspace_buffer, qo_indptr, @@ -433,6 +514,8 @@ def end_forward(self): self._paged_kv_indptr = None self._paged_kv_indices = None self._paged_kv_last_page_len = None + self._custom_mask = None + self._qk_indptr = None self._wrapper.end_forward() def forward( @@ -460,6 +543,8 @@ def forward( if :attr:`kv_layout` is ``HND``. 
causal : bool Whether to apply causal mask to the attention matrix. + This is only effective when :attr:`custom_mask` is not provided in + :meth:`begin_forward`. pos_encoding_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. @@ -496,21 +581,39 @@ def forward( paged_kv_data = paged_kv_data.to(torch.float16) paged_kv_data = expand_5d(paged_kv_data, self._kv_layout) - return self._wrapper.forward( - q, - self._qo_indptr, - paged_kv_data, - self._paged_kv_indptr, - self._paged_kv_indices, - self._paged_kv_last_page_len, - causal, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - sm_scale, - rope_scale, - rope_theta, - False, - )[0] + if self._custom_mask is None: + return self._wrapper.forward( + q, + self._qo_indptr, + paged_kv_data, + self._paged_kv_indptr, + self._paged_kv_indices, + self._paged_kv_last_page_len, + causal, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + False, + )[0] + else: + return self._wrapper.forward_custom_mask( + q, + self._qo_indptr, + paged_kv_data, + self._paged_kv_indptr, + self._paged_kv_indices, + self._paged_kv_last_page_len, + self._custom_mask, + self._qk_indptr, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + False, + )[0] def forward_return_lse( self, @@ -576,21 +679,51 @@ def forward_return_lse( paged_kv_data = paged_kv_data.to(torch.float16) paged_kv_data = expand_5d(paged_kv_data, self._kv_layout) - return self._wrapper.forward( - q, - self._qo_indptr, - paged_kv_data, - self._paged_kv_indptr, - self._paged_kv_indices, - self._paged_kv_last_page_len, - causal, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - sm_scale, - rope_scale, - rope_theta, - True, - ) + if self._custom_mask is None: + return self._wrapper.forward( + q, + self._qo_indptr, + paged_kv_data, + self._paged_kv_indptr, + self._paged_kv_indices, + self._paged_kv_last_page_len, + causal, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + True, + ) + else: + return self._wrapper.forward( + q, + self._qo_indptr, + paged_kv_data, + self._paged_kv_indptr, + self._paged_kv_indices, + self._paged_kv_last_page_len, + self._custom_mask, + self._qk_indptr, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + True, + ) + + +def _compute_qk_indptr(qo_indptr: torch.Tensor, kv_indptr: torch.Tensor): + if len(qo_indptr) != len(kv_indptr): + raise ValueError("The length of qo_indptr and kv_indptr should be the same.") + qk_indptr = torch.empty_like(qo_indptr) + qk_indptr[0] = 0 + qk_indptr[1:] = torch.cumsum( + (qo_indptr[1:] - qo_indptr[:-1]) * (kv_indptr[1:] - kv_indptr[:-1]), + 0, + ) + return qk_indptr class BatchPrefillWithRaggedKVCacheWrapper: @@ -672,6 +805,8 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): ) self._qo_indptr = None self._kv_indptr = None + self._custom_mask = None + self._qk_indptr = None def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor): r"""Reset the workspace buffer. @@ -691,6 +826,7 @@ def begin_forward( num_qo_heads: int, num_kv_heads: int, head_dim: int, + custom_mask: Optional[torch.Tensor] = None, ): r"""Create auxiliary data structures for batch prefill/append attention for multiple forward calls within the same prefill/append step. 
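To make the flattened-mask layout used by the paged wrapper above concrete: request ``i`` owns the slice ``custom_mask[qk_indptr[i]:qk_indptr[i+1]]``, and its kv length is recovered from its page count and last-page length, which is exactly the arithmetic in ``_compute_page_qk_indptr``. A minimal host-side sketch of assembling such a mask (the helper name and the zero-initialized mask values are illustrative, not part of this patch):

import torch

def build_flat_paged_mask(qo_lens, paged_kv_indptr, paged_kv_last_page_len, page_size):
    # One [qo_len_i, kv_len_i] additive mask per request, flattened and concatenated
    # so that request i occupies qk_indptr[i]:qk_indptr[i+1] of the result.
    parts = []
    for i, qo_len in enumerate(qo_lens):
        num_pages = int(paged_kv_indptr[i + 1] - paged_kv_indptr[i])
        kv_len = (num_pages - 1) * page_size + int(paged_kv_last_page_len[i])
        mask_i = torch.zeros(qo_len, kv_len, dtype=torch.float32)  # 0 keeps the logit
        parts.append(mask_i.reshape(-1))
    return torch.cat(parts)  # total length equals qk_indptr[-1]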
@@ -707,6 +843,9 @@ def begin_forward( The number of key/value heads. head_dim : int The dimension of the heads. + custom_mask : Optional[torch.Tensor] + The flattened mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size))``. + The mask tensor will be added to the attention matrix before softmax. Notes ----- @@ -721,6 +860,9 @@ def begin_forward( batch_size = len(qo_indptr) - 1 self._qo_indptr = qo_indptr self._kv_indptr = kv_indptr + if custom_mask is not None: + self._custom_mask = custom_mask + self._qk_indptr = _compute_qk_indptr(qo_indptr, kv_indptr) self._wrapper.begin_forward( self._workspace_buffer, qo_indptr, @@ -734,6 +876,8 @@ def end_forward(self): r"""Clear the auxiliary data structures created by :meth:`begin_forward`.""" self._qo_indptr = None self._kv_indptr = None + self._custom_mask = None + self._qk_indptr = None self._wrapper.end_forward() def forward( @@ -761,6 +905,7 @@ def forward( The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]`` causal : bool Whether to apply causal mask to the attention matrix. + This argument is ignored if ``mask`` is provided in :meth:`begin_forward`. pos_encoding_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. @@ -796,20 +941,37 @@ def forward( q = q.to(torch.float16) k = k.to(torch.float16) v = v.to(torch.float16) - return self._wrapper.forward( - q, - self._qo_indptr, - k, - v, - self._kv_indptr, - causal, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - sm_scale, - rope_scale, - rope_theta, - False, - )[0] + if self._custom_mask is None: + return self._wrapper.forward( + q, + self._qo_indptr, + k, + v, + self._kv_indptr, + causal, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + False, + )[0] + else: + return self._wrapper.forward_custom_mask( + q, + self._qo_indptr, + k, + v, + self._kv_indptr, + self._custom_mask, + self._qk_indptr, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + False, + )[0] def forward_return_lse( self, @@ -836,6 +998,7 @@ def forward_return_lse( The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]`` causal : bool Whether to apply causal mask to the attention matrix. + This argument is ignored if ``mask`` is provided in :meth:`begin_forward`. pos_encoding_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. 
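The same additive convention applies to the ragged wrapper above and to ``single_prefill_with_kv_cache(..., custom_mask=...)``: the mask is added to the attention logits before softmax, so a causal pattern corresponds to a large negative fill above the diagonal, which is what the new unit tests in this patch rely on. A rough usage sketch (the helper name is illustrative; the ``-5e4`` fill and ``diagonal`` offset follow those tests):

import torch

def causal_like_flat_mask(qo_lens, kv_lens):
    # Build one [qo_len_i, kv_len_i] additive mask per request and flatten them
    # back to back, matching the layout expected by begin_forward(custom_mask=...).
    parts = []
    for qo_len, kv_len in zip(qo_lens, kv_lens):
        mask_i = torch.triu(
            torch.full((qo_len, kv_len), -5e4, dtype=torch.float32),
            diagonal=(kv_len - qo_len + 1),
        )
        parts.append(mask_i.reshape(-1))
    return torch.cat(parts)

# wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, "NHD")
# wrapper.begin_forward(qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim,
#                       causal_like_flat_mask(qo_lens, kv_lens).to(0))
# o = wrapper.forward(q, k, v)  # the causal flag is ignored once a custom mask is set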
@@ -873,17 +1036,34 @@ def forward_return_lse( q = q.to(torch.float16) k = k.to(torch.float16) v = v.to(torch.float16) - return self._wrapper.forward( - q, - self._qo_indptr, - k, - v, - self._kv_indptr, - causal, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - sm_scale, - rope_scale, - rope_theta, - True, - ) + if self._custom_mask is None: + return self._wrapper.forward( + q, + self._qo_indptr, + k, + v, + self._kv_indptr, + causal, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + True, + ) + else: + return self._wrapper.forward_custom_mask( + q, + self._qo_indptr, + k, + v, + self._kv_indptr, + self._custom_mask, + self._qk_indptr, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + sm_scale, + rope_scale, + rope_theta, + True, + ) diff --git a/python/generate_batch_paged_prefill_inst.py b/python/generate_batch_paged_prefill_inst.py index 3a061491..5d301dd2 100644 --- a/python/generate_batch_paged_prefill_inst.py +++ b/python/generate_batch_paged_prefill_inst.py @@ -18,6 +18,7 @@ import re import itertools from literal_map import ( + mask_mode_literal, kv_layout_literal, pos_encoding_mode_literal, dtype_literal, @@ -33,7 +34,7 @@ def get_cu_file_str( kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, dtype_in, dtype_out, idtype, @@ -41,10 +42,11 @@ def get_cu_file_str( num_frags_x_choices = [1, 2] insts = "\n".join( [ - """template cudaError_t BatchPrefillWithPagedKVCacheDispatched( + """template cudaError_t BatchPrefillWithPagedKVCacheDispatched( {dtype_in}* q, {idtype}* request_indices, {idtype}* tile_indices, {idtype}* qo_indptr, {idtype}* q_offset, paged_kv_t paged_kv, + float* custom_mask, {idtype}* qk_indptr, {dtype_out}* o, float* tmp, float* lse, uint32_t num_qo_tiles, float sm_scale, float rope_scale, @@ -57,7 +59,7 @@ def get_cu_file_str( head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], allow_fp16_qk_reduction=allow_fp16_qk_reduction, - causal=causal, + mask_mode=mask_mode_literal[int(mask_mode)], dtype_in=dtype_literal[dtype_in], dtype_out=dtype_literal[dtype_out], idtype=idtype_literal[idtype], @@ -81,7 +83,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( r"batch_paged_prefill_group_([0-9]+)_page_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" - r"fp16qkred_([a-z]+)_causal_([a-z]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" + r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) path = Path(sys.argv[1]) diff --git a/python/generate_batch_ragged_prefill_inst.py b/python/generate_batch_ragged_prefill_inst.py index 65f862f4..7eeab91e 100644 --- a/python/generate_batch_ragged_prefill_inst.py +++ b/python/generate_batch_ragged_prefill_inst.py @@ -17,6 +17,7 @@ import sys import re from literal_map import ( + mask_mode_literal, kv_layout_literal, pos_encoding_mode_literal, dtype_literal, @@ -31,7 +32,7 @@ def get_cu_file_str( kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, dtype_in, dtype_out, idtype, @@ -39,9 +40,11 @@ def get_cu_file_str( num_frags_x_choices = [1, 2] insts = "\n".join( [ - """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{num_frags_x}, {group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {causal}, {dtype_in}, {dtype_out}, {idtype}>( + """template cudaError_t 
BatchPrefillWithRaggedKVCacheDispatched<{num_frags_x}, {group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>( {dtype_in}* q, {idtype}* request_indices, {idtype}* tile_indices, {idtype}* qo_indptr, - {dtype_in}* k, {dtype_in}* v, {idtype}* kv_indptr, {idtype}* q_offset, {idtype}* k_rope_pos_offset, + {dtype_in}* k, {dtype_in}* v, {idtype}* kv_indptr, + float* custom_mask, {idtype}* qk_indptr, + {idtype}* q_offset, {idtype}* k_rope_pos_offset, {dtype_out}* o, float* tmp, float* lse, uint32_t batch_size, uint32_t num_qo_tiles, uint32_t num_kv_heads, float sm_scale, float rope_scale, @@ -53,7 +56,7 @@ def get_cu_file_str( head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], allow_fp16_qk_reduction=allow_fp16_qk_reduction, - causal=causal, + mask_mode=mask_mode_literal[int(mask_mode)], dtype_in=dtype_literal[dtype_in], dtype_out=dtype_literal[dtype_out], idtype=idtype_literal[idtype], @@ -76,7 +79,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( r"batch_ragged_prefill_group_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" - r"fp16qkred_([a-z]+)_causal_([a-z]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" + r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) path = Path(sys.argv[1]) diff --git a/python/generate_dispatch_inc.py b/python/generate_dispatch_inc.py index 1a2b104b..03ec819d 100644 --- a/python/generate_dispatch_inc.py +++ b/python/generate_dispatch_inc.py @@ -16,7 +16,12 @@ import argparse from pathlib import Path -from literal_map import kv_layout_literal, pos_encoding_mode_literal, bool_literal +from literal_map import ( + kv_layout_literal, + pos_encoding_mode_literal, + bool_literal, + mask_mode_literal, +) def get_dispatch_inc_str(args: argparse.Namespace) -> str: @@ -90,15 +95,17 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: {dispatch_allow_fp16_qk_reduction_entries} // EOL """ - # causal - dispatch_causal_entries = "\n".join( + # mask_mode + dispatch_mask_mode_entries = "\n".join( [ - " _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(bool_literal[_]) - for _ in args.causals + " _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format( + mask_mode_literal[_] + ) + for _ in args.mask_modes ] ) - dispatch_causal_str = f"""#define _DISPATCH_CASES_causal(case_var, ...) \\ -{dispatch_causal_entries} + dispatch_mask_mode_str = f"""#define _DISPATCH_CASES_mask_mode(case_var, ...) 
\\ +{dispatch_mask_mode_entries} // EOL """ @@ -110,7 +117,7 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: dispatch_kv_layouts_str, dispatch_pos_encoding_modes_str, dispatch_allow_fp16_qk_reductions_str, - dispatch_causal_str, + dispatch_mask_mode_str, ] ) @@ -151,11 +158,11 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: help="Allow fp16 qk reductions", ) parser.add_argument( - "--causals", - type=lambda x: x if isinstance(x, int) else x.lower() == "true", + "--mask_modes", + type=int, required=True, nargs="+", - help="Causals", + help="Mask modes", ) args = parser.parse_args() print(args) diff --git a/python/generate_single_prefill_inst.py b/python/generate_single_prefill_inst.py index d3018375..7ffad989 100644 --- a/python/generate_single_prefill_inst.py +++ b/python/generate_single_prefill_inst.py @@ -16,7 +16,12 @@ import sys import re -from literal_map import kv_layout_literal, pos_encoding_mode_literal, dtype_literal +from literal_map import ( + kv_layout_literal, + pos_encoding_mode_literal, + dtype_literal, + mask_mode_literal, +) from pathlib import Path @@ -26,7 +31,7 @@ def get_cu_file_str( kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, dtype_in, dtype_out, ): @@ -35,8 +40,8 @@ def get_cu_file_str( namespace flashinfer {{ -template cudaError_t SinglePrefillWithKVCacheDispatched<{group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {causal}, {dtype_in}, {dtype_out}>( - {dtype_in}* q, {dtype_in}* k, {dtype_in}* v, {dtype_out}* o, +template cudaError_t SinglePrefillWithKVCacheDispatched<{group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}>( + {dtype_in}* q, {dtype_in}* k, {dtype_in}* v, float* custom_mask, {dtype_out}* o, float* tmp, float* lse, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); @@ -48,7 +53,7 @@ def get_cu_file_str( head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], allow_fp16_qk_reduction=allow_fp16_qk_reduction, - causal=causal, + mask_mode=mask_mode_literal[int(mask_mode)], dtype_in=dtype_literal[dtype_in], dtype_out=dtype_literal[dtype_out], ) @@ -58,7 +63,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( r"single_prefill_group_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" - r"fp16qkred_([a-z]+)_causal_([a-z]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" + r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) diff --git a/python/literal_map.py b/python/literal_map.py index 3587572e..bf4ac679 100644 --- a/python/literal_map.py +++ b/python/literal_map.py @@ -14,6 +14,12 @@ limitations under the License. 
""" +mask_mode_literal = { + 0: "MaskMode::kNone", + 1: "MaskMode::kCausal", + 2: "MaskMode::kCustom", +} + kv_layout_literal = { 0: "QKVLayout::kNHD", 1: "QKVLayout::kHND", diff --git a/python/setup.py b/python/setup.py index 3eec45df..fd524f63 100644 --- a/python/setup.py +++ b/python/setup.py @@ -73,7 +73,7 @@ def get_instantiation_cu() -> List[str]: allow_fp16_qk_reduction_options = os.environ.get( "FLASHINFER_ALLOW_FP16_QK_REDUCTION_OPTIONS", "0" ).split(",") - causal_options = os.environ.get("FLASHINFER_CAUSAL_OPTIONS", "0,1").split(",") + mask_modes = os.environ.get("FLASHINFER_MASK_MODES", "0,1,2").split(",") # dispatch.inc path = root / prefix / "dispatch.inc" write_if_different( @@ -86,7 +86,7 @@ def get_instantiation_cu() -> List[str]: kv_layouts=map(int, kv_layouts), pos_encoding_modes=map(int, pos_encoding_modes), allow_fp16_qk_reductions=map(int, allow_fp16_qk_reduction_options), - causals=map(int, causal_options), + mask_modes=map(int, mask_modes), ) ), ) @@ -217,17 +217,17 @@ def get_instantiation_cu() -> List[str]: kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, ) in itertools.product( group_sizes, head_dims, kv_layouts, pos_encoding_modes, allow_fp16_qk_reduction_options, - causal_options, + mask_modes, ): for dtype in prefill_dtypes: - fname = f"single_prefill_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_causal_{causal}_dtypein_{dtype}_dtypeout_{dtype}.cu" + fname = f"single_prefill_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}.cu" files.append(prefix + "/" + fname) content = generate_single_prefill_inst.get_cu_file_str( group_size, @@ -235,7 +235,7 @@ def get_instantiation_cu() -> List[str]: kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, dtype, dtype, ) @@ -249,7 +249,7 @@ def get_instantiation_cu() -> List[str]: kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, idtype, ) in itertools.product( group_sizes, @@ -258,11 +258,11 @@ def get_instantiation_cu() -> List[str]: kv_layouts, pos_encoding_modes, allow_fp16_qk_reduction_options, - causal_options, + mask_modes, idtypes, ): for dtype in prefill_dtypes: - fname = f"batch_paged_prefill_group_{group_size}_page_{page_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_causal_{causal}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" + fname = f"batch_paged_prefill_group_{group_size}_page_{page_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" files.append(prefix + "/" + fname) content = generate_batch_paged_prefill_inst.get_cu_file_str( group_size, @@ -271,7 +271,7 @@ def get_instantiation_cu() -> List[str]: kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, dtype, dtype, idtype, @@ -285,7 +285,7 @@ def get_instantiation_cu() -> List[str]: kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, idtype, ) in itertools.product( group_sizes, @@ -293,11 +293,11 @@ def get_instantiation_cu() -> List[str]: kv_layouts, pos_encoding_modes, allow_fp16_qk_reduction_options, - causal_options, + mask_modes, idtypes, ): for dtype in prefill_dtypes: - fname = 
f"batch_ragged_prefill_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_causal_{causal}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" + fname = f"batch_ragged_prefill_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" files.append(prefix + "/" + fname) content = generate_batch_ragged_prefill_inst.get_cu_file_str( group_size, @@ -305,7 +305,7 @@ def get_instantiation_cu() -> List[str]: kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, - causal, + mask_mode, dtype, dtype, idtype, diff --git a/python/tests/test_batch_prefill_kernels.py b/python/tests/test_batch_prefill_kernels.py index 112cbea4..a7b0754e 100644 --- a/python/tests/test_batch_prefill_kernels.py +++ b/python/tests/test_batch_prefill_kernels.py @@ -72,6 +72,7 @@ def test_batch_prefill_with_paged_kv_cache( num_qo_heads, num_kv_heads, head_dim, + page_size, ) o = wrapper.forward(q, kv_data, causal=causal, pos_encoding_mode=pos_encoding_mode) @@ -117,6 +118,90 @@ def test_batch_prefill_with_paged_kv_cache( numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) +@pytest.mark.parametrize("batch_size", [12, 17]) +@pytest.mark.parametrize("kv_len", [54, 97]) +@pytest.mark.parametrize("qo_len", [37, 17]) +@pytest.mark.parametrize("page_size", [1, 16]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) +@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) +def test_batch_prefill_with_paged_kv_cache_custom_mask( + batch_size, + kv_len, + qo_len, + page_size, + num_kv_heads, + num_qo_heads, + head_dim, + kv_layout, + pos_encoding_mode, +): + q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() + q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_data = ( + torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0).half() + if kv_layout == "HND" + else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim) + .to(0) + .half() + ) + kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * num_pages_per_seq + kv_indices = torch.arange(0, total_num_pages).to(0).int() + kv_last_page_len = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + ).to(0) + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout + ) + custom_mask = ( + torch.triu( + torch.full((batch_size, qo_len, kv_len), -5e4, dtype=torch.float32), + diagonal=(kv_len - qo_len + 1), + ) + .reshape(-1) + .to(0) + ) + + # use custom mask + wrapper.begin_forward( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + custom_mask, + ) + o_custom = wrapper.forward(q, kv_data, pos_encoding_mode=pos_encoding_mode) + wrapper.end_forward() + + # use causal + wrapper.begin_forward( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + page_size, + head_dim, + ) + o_causal = wrapper.forward( + q, kv_data, causal=True, pos_encoding_mode=pos_encoding_mode + ) + 
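# Why the assertion below should hold: the custom mask fills every future position
# with -5e4, which is effectively -inf for fp16 attention logits, so the
# MaskMode::kCustom path is expected to reproduce the built-in causal mask, and the
# two outputs are compared within the usual fp16 tolerances (rtol/atol = 1e-3).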
numpy.testing.assert_allclose( + o_custom.cpu().numpy(), o_causal.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + + @pytest.mark.parametrize("batch_size", [12, 17]) @pytest.mark.parametrize("kv_len", [54, 97]) @pytest.mark.parametrize("qo_len", [37, 17]) @@ -169,6 +254,61 @@ def test_batch_prefill_with_ragged_kv_cache( numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) +@pytest.mark.parametrize("batch_size", [12, 17]) +@pytest.mark.parametrize("kv_len", [54, 97]) +@pytest.mark.parametrize("qo_len", [37, 17]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) +def test_batch_prefill_with_ragged_kv_cache_custom_mask( + batch_size, + kv_len, + qo_len, + num_kv_heads, + num_qo_heads, + head_dim, + pos_encoding_mode, +): + kv_layout = "NHD" + q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() + q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len + + k = torch.randn(batch_size * kv_len, num_kv_heads, head_dim).to(0).half() + v = torch.randn(batch_size * kv_len, num_kv_heads, head_dim).to(0).half() + kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) + wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, kv_layout + ) + + custom_mask = ( + torch.triu( + torch.full((batch_size, qo_len, kv_len), -5e4, dtype=torch.float32), + diagonal=(kv_len - qo_len + 1), + ) + .reshape(-1) + .to(0) + ) + + # use custom mask + wrapper.begin_forward( + q_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim, custom_mask + ) + o_custom = wrapper.forward(q, k, v, pos_encoding_mode=pos_encoding_mode) + wrapper.end_forward() + + # use causal + wrapper.begin_forward(q_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim) + o_causal = wrapper.forward( + q, k, v, causal=True, pos_encoding_mode=pos_encoding_mode + ) + numpy.testing.assert_allclose( + o_custom.cpu().numpy(), o_causal.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + + if __name__ == "__main__": test_batch_prefill_with_paged_kv_cache( 12, 54, 37, 8, 8, 8, 128, True, "HND", "NONE" @@ -176,4 +316,8 @@ def test_batch_prefill_with_ragged_kv_cache( test_batch_prefill_with_paged_kv_cache( 12, 54, 37, 1, 8, 8, 128, True, "HND", "NONE" ) + test_batch_prefill_with_paged_kv_cache_custom_mask( + 12, 137, 137, 1, 8, 8, 128, "HND", "NONE" + ) test_batch_prefill_with_ragged_kv_cache(12, 54, 37, 8, 8, 128, True, "NONE") + test_batch_prefill_with_ragged_kv_cache_custom_mask(12, 137, 137, 8, 8, 128, "NONE") diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index 1924f20e..dc48fb90 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -56,21 +56,23 @@ cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOu cudaStream_t stream = nullptr) { const uint32_t group_size = num_qo_heads / num_kv_heads; const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); + const MaskMode mask_mode = causal ? 
MaskMode::kCausal : MaskMode::kNone; DISPATCH_allow_fp16_qk_reduction( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {DISPATCH_group_size( group_size, GROUP_SIZE, - {DISPATCH_causal( - causal, CAUSAL, + {DISPATCH_mask_mode( + mask_mode, MASK_MODE, {DISPATCH_head_dim(head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode( pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { return SinglePrefillWithKVCacheDispatched< GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, CAUSAL>( - q, k, v, o, tmp, lse, num_kv_heads, qo_len, kv_len, - sm_scale, rope_scale, rope_theta, stream); + ALLOW_FP16_QK_REDUCTION, MASK_MODE>( + q, k, v, /*custom_mask=*/nullptr, o, tmp, lse, + num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, + rope_theta, stream); })})})})})}); return cudaSuccess; } @@ -85,24 +87,25 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( bool allow_fp16_qk_reduction = false, std::optional maybe_sm_scale = std::nullopt, const float rope_scale = 1.f, const float rope_theta = 1e4, cudaStream_t stream = nullptr) { const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); + const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; DISPATCH_kv_layout( kv_layout, KV_LAYOUT, {DISPATCH_group_size( num_qo_heads / num_kv_heads, GROUP_SIZE, {DISPATCH_head_dim( head_dim, HEAD_DIM, - {DISPATCH_causal( - causal, CAUSAL, + {DISPATCH_mask_mode( + mask_mode, MASK_MODE, {DISPATCH_pos_encoding_mode( pos_encoding_mode, pos_encoding_mode, {DISPATCH_allow_fp16_qk_reduction( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, { return BatchPrefillWithRaggedKVCacheWrapperDispatched< GROUP_SIZE, HEAD_DIM, KV_LAYOUT, pos_encoding_mode, - ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( - handler, q, qo_indptr, k, v, kv_indptr, q_offset, k_rope_pos_offset, - o, lse, batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, - stream); + ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( + handler, q, qo_indptr, k, v, kv_indptr, /*custom_mask=*/nullptr, + /*qk_indptr=*/nullptr, q_offset, k_rope_pos_offset, o, lse, + batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, stream); })})})})})}); return cudaSuccess; } @@ -119,23 +122,25 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper( const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim))); const uint32_t num_kv_heads = paged_kv.num_heads; const uint32_t head_dim = paged_kv.head_dim; + const MaskMode mask_mode = causal ? 
MaskMode::kCausal : MaskMode::kNone; DISPATCH_group_size( num_qo_heads / num_kv_heads, GROUP_SIZE, {DISPATCH_head_dim( head_dim, HEAD_DIM, - {DISPATCH_causal(causal, CAUSAL, - {DISPATCH_pos_encoding_mode( - pos_encoding_mode, pos_encoding_mode, - {DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, - {DISPATCH_page_size(paged_kv.page_size, PAGE_SIZE, { - return BatchPrefillWithPagedKVCacheWrapperDispatched< - page_storage, kv_layout, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, - pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, CAUSAL, - DTypeIn, DTypeOut, IdType>(handler, q, qo_indptr, q_offset, - paged_kv, o, lse, sm_scale, - rope_scale, rope_theta, stream); - })})})})})}); + {DISPATCH_mask_mode(mask_mode, MASK_MODE, + {DISPATCH_pos_encoding_mode( + pos_encoding_mode, pos_encoding_mode, + {DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, + {DISPATCH_page_size(paged_kv.page_size, PAGE_SIZE, { + return BatchPrefillWithPagedKVCacheWrapperDispatched< + page_storage, kv_layout, PAGE_SIZE, GROUP_SIZE, + HEAD_DIM, pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, + MASK_MODE, DTypeIn, DTypeOut, IdType>( + handler, q, qo_indptr, q_offset, paged_kv, + /*custom_mask=*/nullptr, /*qk_indptr=*/nullptr, o, lse, + sm_scale, rope_scale, rope_theta, stream); + })})})})})}); return cudaSuccess; } diff --git a/src/utils.h b/src/utils.h index b1570450..d77ebecb 100644 --- a/src/utils.h +++ b/src/utils.h @@ -68,8 +68,8 @@ _DISPATCH_SWITCH("allow_fp16_qk_reduction", expr, \ _DISPATCH_CASES_allow_fp16_qk_reduction(const_expr, __VA_ARGS__)) -#define DISPATCH_causal(expr, const_expr, ...) \ - _DISPATCH_SWITCH("causal", expr, _DISPATCH_CASES_causal(const_expr, __VA_ARGS__)) +#define DISPATCH_mask_mode(expr, const_expr, ...) \ + _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) namespace utils {
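To make the new mask-mode dispatch plumbing concrete: for ``--mask_modes 0 1 2`` the generator above emits one ``_DISPATCH_CASE`` per entry of ``mask_mode_literal``, and the ``DISPATCH_mask_mode`` macro switches over those cases. A standalone approximation of that expansion (simplified; the real ``generate_dispatch_inc.py`` also emits the other dispatch axes):

mask_mode_literal = {
    0: "MaskMode::kNone",
    1: "MaskMode::kCausal",
    2: "MaskMode::kCustom",
}

def dispatch_mask_mode_macro(mask_modes):
    # Mirrors the _DISPATCH_CASES_mask_mode block produced by the generator.
    entries = "\n".join(
        "  _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(mask_mode_literal[m])
        for m in mask_modes
    )
    return (
        "#define _DISPATCH_CASES_mask_mode(case_var, ...) \\\n" + entries + "\n// EOL"
    )

print(dispatch_mask_mode_macro([0, 1, 2]))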