Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: support custom attention mask in prefill/append attention kernels #266

Merged
merged 14 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ set (HEAD_DIMS ${FLASHINFER_GEN_HEAD_DIMS})
set (KV_LAYOUTS ${FLASHINFER_GEN_KV_LAYOUTS})
set (POS_ENCODING_MODES ${FLASHINFER_GEN_POS_ENCODING_MODES})
set (ALLOW_FP16_QK_REDUCTIONS ${FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS})
set (CAUSALS ${FLASHINFER_GEN_CASUALS})
set (MASK_MODES ${FLASHINFER_GEN_MASK_MODES})
set (DECODE_DTYPES "f16")
set (PREFILL_DTYPES "f16")
set (DECODE_F8_DTYPES)
Expand All @@ -104,14 +104,14 @@ message(STATUS "FLASHINFER_HEAD_DIMS=${HEAD_DIMS}")
message(STATUS "FLASHINFER_KV_LAYOUTS=${KV_LAYOUTS}")
message(STATUS "FLASHINFER_POS_ENCODING_MODES=${POS_ENCODING_MODES}")
message(STATUS "FLASHINFER_ALLOW_FP16_QK_REDUCTIONS=${ALLOW_FP16_QK_REDUCTIONS}")
message(STATUS "FLASHINFER_CAUSALS=${CAUSALS}")
message(STATUS "FLASHINFER_MASK_MODES=${MASK_MODES}")

file(MAKE_DIRECTORY ${PROJECT_SOURCE_DIR}/src/generated)

set(dispatch_inc_file ${PROJECT_SOURCE_DIR}/src/dispatch.inc)
add_custom_command(
OUTPUT ${dispatch_inc_file}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --page_sizes ${FLASHINFER_GEN_PAGE_SIZES} --group_sizes ${GROUP_SIZES} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --causals ${CAUSALS}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --page_sizes ${FLASHINFER_GEN_PAGE_SIZES} --group_sizes ${GROUP_SIZES} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES}
DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py
COMMENT "Generating additional source file ${generated_dispatch_inc}"
VERBATIM
Expand Down Expand Up @@ -225,9 +225,9 @@ foreach(group_size IN LISTS GROUP_SIZES)
foreach(kv_layout IN LISTS KV_LAYOUTS)
foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
foreach(causal IN LISTS CAUSALS)
foreach(mask_mode IN LISTS MASK_MODES)
foreach(dtype IN LISTS PREFILL_DTYPES)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_causal_${causal}_dtypein_${dtype}_dtypeout_${dtype}.cu)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src}
Expand All @@ -237,7 +237,7 @@ foreach(group_size IN LISTS GROUP_SIZES)
)
list(APPEND single_prefill_kernels_src ${generated_kernel_src})
endforeach(dtype)
endforeach(causal)
endforeach(mask_mode)
endforeach(allow_fp16_qk_reduction)
endforeach(pos_encoding_mode)
endforeach(kv_layout)
Expand All @@ -251,10 +251,10 @@ foreach(group_size IN LISTS GROUP_SIZES)
foreach(kv_layout IN LISTS KV_LAYOUTS)
foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
foreach(causal IN LISTS CAUSALS)
foreach(mask_mode IN LISTS MASK_MODES)
foreach(dtype IN LISTS PREFILL_DTYPES)
foreach(idtype IN LISTS IDTYPES)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_group_${group_size}_page_${page_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_causal_${causal}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_group_${group_size}_page_${page_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
Expand All @@ -265,7 +265,7 @@ foreach(group_size IN LISTS GROUP_SIZES)
list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src})
endforeach(idtype)
endforeach(dtype)
endforeach(causal)
endforeach(mask_mode)
endforeach(allow_fp16_qk_reduction)
endforeach(pos_encoding_mode)
endforeach(kv_layout)
Expand All @@ -279,10 +279,10 @@ foreach(group_size IN LISTS GROUP_SIZES)
foreach(kv_layout IN LISTS KV_LAYOUTS)
foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
foreach(causal IN LISTS CAUSALS)
foreach(mask_mode IN LISTS MASK_MODES)
foreach(dtype IN LISTS PREFILL_DTYPES)
foreach(idtype IN LISTS IDTYPES)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_causal_${causal}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
Expand All @@ -293,7 +293,7 @@ foreach(group_size IN LISTS GROUP_SIZES)
list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src})
endforeach(idtype)
endforeach(dtype)
endforeach(causal)
endforeach(mask_mode)
endforeach(allow_fp16_qk_reduction)
endforeach(pos_encoding_mode)
endforeach(kv_layout)
Expand Down
2 changes: 1 addition & 1 deletion cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2)
set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false" "true")
set(FLASHINFER_GEN_CASUALS "false" "true")
set(FLASHINFER_GEN_MASK_MODES 0 1)

# Set target cuda architectures for tests/benchmarks, defaults to native.
# "native" is a special value for CMAKE_CUDA_ARCHITECTURES which means use the architectures of the host's GPU.
Expand Down
29 changes: 29 additions & 0 deletions include/flashinfer/attention/mask.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_ATTENTION_MASK_CUH_
#define FLASHINFER_ATTENTION_MASK_CUH_

namespace flashinfer {

// Selects which attention mask variant a prefill/append attention kernel is
// compiled for. The numeric values are part of the build contract: they are
// fed through FLASHINFER_GEN_MASK_MODES (cmake/config.cmake) to the Python
// kernel-instantiation generators via --mask_modes, so they must not be
// renumbered. NOTE(review): the default generation list is "0 1" only;
// kCustom (2) appears to require opting in — confirm against the generators.
enum class MaskMode {
kNone = 0U,    // No mask applied
kCausal = 1U,  // Causal mask
kCustom = 2U,  // Custom (caller-supplied) mask — TODO confirm how the mask is provided at runtime
};

}  // namespace flashinfer

#endif  // FLASHINFER_ATTENTION_MASK_CUH_
Loading