diff --git a/.pipelines/nuget_config/x64/packages.config b/.pipelines/nuget_config/x64/packages.config
index 7bf8181b1f838..96bb053a13f29 100644
--- a/.pipelines/nuget_config/x64/packages.config
+++ b/.pipelines/nuget_config/x64/packages.config
@@ -1,6 +1,6 @@
-  <package id="Microsoft.AI.DirectML" version="1.15.0" targetFramework="native" />
+  <package id="Microsoft.AI.DirectML" version="1.15.1" targetFramework="native" />
diff --git a/.pipelines/nuget_config/x86/packages.config b/.pipelines/nuget_config/x86/packages.config
index 30f7862a11078..6bf842ac18037 100644
--- a/.pipelines/nuget_config/x86/packages.config
+++ b/.pipelines/nuget_config/x86/packages.config
@@ -1,6 +1,6 @@
-  <package id="Microsoft.AI.DirectML" version="1.15.0" targetFramework="native" />
+  <package id="Microsoft.AI.DirectML" version="1.15.1" targetFramework="native" />
diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake
index 54e361ffdb3ae..8b5f602643c0b 100644
--- a/cmake/external/dml.cmake
+++ b/cmake/external/dml.cmake
@@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
- set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.15.0)
+ set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.15.1)
# Restore nuget packages, which will pull down the DirectML redist package.
add_custom_command(
diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake
index bdb4b00b02a35..927b4ac84b037 100644
--- a/cmake/onnxruntime.cmake
+++ b/cmake/onnxruntime.cmake
@@ -38,10 +38,14 @@ function(get_c_cxx_api_headers HEADERS_VAR)
# need to add header files for enabled EPs
foreach(f ${ONNXRUNTIME_PROVIDER_NAMES})
- file(GLOB _provider_headers CONFIGURE_DEPENDS
- "${REPO_ROOT}/include/onnxruntime/core/providers/${f}/*.h"
- )
- list(APPEND _headers ${_provider_headers})
+    # The header files in the include/onnxruntime/core/providers/cuda directory cannot be flattened into the same
+    # directory as onnxruntime_c_api.h. Most other EPs probably do not work that way either.
+ if((NOT f STREQUAL cuda) AND (NOT f STREQUAL rocm))
+ file(GLOB _provider_headers CONFIGURE_DEPENDS
+ "${REPO_ROOT}/include/onnxruntime/core/providers/${f}/*.h"
+ )
+ list(APPEND _headers ${_provider_headers})
+ endif()
endforeach()
set(${HEADERS_VAR} ${_headers} PARENT_SCOPE)
diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 66f4aea606ef5..c02ac2096db2e 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -555,8 +555,17 @@ else()
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
)
- set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")
+message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
+message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}")
+
+if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "10")
+ message(STATUS "Using -mavx2 -mfma -mavxvnni flags")
+ set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni")
+else()
+ message(STATUS "Using -mavx2 -mfma flags")
+ set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")
+endif()
set(mlas_platform_srcs_avx512f
${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S
${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S
@@ -575,7 +584,7 @@ else()
${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp
)
- set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mavx512bw -mavx512dq -mavx512vl")
+ set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl")
set(mlas_platform_srcs_avx512vnni
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp
diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake
index d2afe19f36691..bbcc709b144a0 100644
--- a/cmake/onnxruntime_providers_cpu.cmake
+++ b/cmake/onnxruntime_providers_cpu.cmake
@@ -219,6 +219,7 @@ if (onnxruntime_ENABLE_TRAINING)
endif()
install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/cpu/cpu_provider_factory.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/)
+install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/resource.h ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/custom_op_context.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers)
set_target_properties(onnxruntime_providers PROPERTIES LINKER_LANGUAGE CXX)
set_target_properties(onnxruntime_providers PROPERTIES FOLDER "ONNXRuntime")
diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake
index 82c31ce6b6b4d..0829be05a3ab0 100644
--- a/cmake/onnxruntime_providers_cuda.cmake
+++ b/cmake/onnxruntime_providers_cuda.cmake
@@ -289,8 +289,15 @@
config_cuda_provider_shared_module(onnxruntime_providers_cuda_obj)
endif()
config_cuda_provider_shared_module(onnxruntime_providers_cuda)
-
+  # Cannot use a glob here because cuda_provider_options.h should not be exposed.
+ set(ONNXRUNTIME_CUDA_PROVIDER_PUBLIC_HEADERS
+ "${REPO_ROOT}/include/onnxruntime/core/providers/cuda/cuda_context.h"
+ "${REPO_ROOT}/include/onnxruntime/core/providers/cuda/cuda_resource.h"
+ )
+ set_target_properties(onnxruntime_providers_cuda PROPERTIES
+ PUBLIC_HEADER "${ONNXRUNTIME_CUDA_PROVIDER_PUBLIC_HEADERS}")
install(TARGETS onnxruntime_providers_cuda
+ PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers/cuda
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
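Note on the PUBLIC_HEADER changes above: consumers of the installed package can now include the CUDA context/resource headers from the install tree. A minimal sketch, assuming the default CMAKE_INSTALL_INCLUDEDIR of include and that both <prefix>/include and <prefix>/include/onnxruntime are on the include path so the headers' own relative includes resolve; nothing below is part of the diff.

    // Hypothetical consumer translation unit (illustration only).
    #include <onnxruntime/core/providers/cuda/cuda_context.h>   // installed by this change
    #include <onnxruntime/core/providers/cuda/cuda_resource.h>  // installed by this change

    int main() { return 0; }  // a real custom op would actually use these headers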
diff --git a/cmake/onnxruntime_providers_rocm.cmake b/cmake/onnxruntime_providers_rocm.cmake
index 71692ddb9391f..559204bd0df88 100644
--- a/cmake/onnxruntime_providers_rocm.cmake
+++ b/cmake/onnxruntime_providers_rocm.cmake
@@ -223,8 +223,13 @@
if (onnxruntime_ENABLE_ATEN)
target_compile_definitions(onnxruntime_providers_rocm PRIVATE ENABLE_ATEN)
endif()
-
+ file(GLOB ONNXRUNTIME_ROCM_PROVIDER_PUBLIC_HEADERS CONFIGURE_DEPENDS
+ "${REPO_ROOT}/include/onnxruntime/core/providers/rocm/*.h"
+ )
+ set_target_properties(onnxruntime_providers_rocm PROPERTIES
+ PUBLIC_HEADER "${ONNXRUNTIME_ROCM_PROVIDER_PUBLIC_HEADERS}")
install(TARGETS onnxruntime_providers_rocm
+ PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers/rocm
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/targets/net8.0-ios/targets.xml b/csharp/src/Microsoft.ML.OnnxRuntime/targets/net8.0-ios/targets.xml
index 3eb9720af511f..c6dbba8dfda76 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/targets/net8.0-ios/targets.xml
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/targets/net8.0-ios/targets.xml
@@ -1,7 +1,7 @@
-
+
Static
True
True
@@ -10,4 +10,4 @@
CoreML
-
\ No newline at end of file
+
diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h
index 0bb5c7432f0a7..6cff153c336f0 100644
--- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h
+++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h
@@ -3,12 +3,15 @@
#pragma once
+#include
#include
+#include
#include
#include
#include "core/common/inlined_containers.h"
#include "core/framework/session_options.h"
+#include "core/framework/tensor.h"
#include "core/optimizer/graph_transformer.h"
#include "core/platform/threadpool.h"
@@ -51,7 +54,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
const SessionOptions& session_options,
const IExecutionProvider& execution_provider /*required by constant folding*/,
    const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
- concurrency::ThreadPool* intra_op_thread_pool = nullptr);
+ concurrency::ThreadPool* intra_op_thread_pool = nullptr,
+    std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);
#endif // !defined(ORT_MINIMAL_BUILD)
@@ -81,7 +85,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
    const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
- concurrency::ThreadPool* intra_op_thread_pool = nullptr);
+ concurrency::ThreadPool* intra_op_thread_pool = nullptr,
+    std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
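The new trailing parameter threads an optional map of pre-buffered tensors through transformer generation. A minimal sketch of the container a caller would own and pass; the alias name and ownership comment are illustrative, only the element type mirrors the declaration above.

    #include <memory>
    #include <string>
    #include <unordered_map>

    #include "core/framework/tensor.h"

    // Owned by the caller (for example the inference session) and passed as the new
    // trailing argument; passing nullptr keeps the previous behavior.
    using BufferedTensors = std::unordered_map<std::string, std::unique_ptr<onnxruntime::Tensor>>;

    BufferedTensors buffered_tensors;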
diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
index 816eaaf9bc71a..ec9be80a63574 100644
--- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
+++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -19,7 +19,7 @@ struct OrtTensorRTProviderOptionsV2 {
// can be updated using: UpdateTensorRTProviderOptionsWithValue
int trt_max_partition_iterations{1000}; // maximum iterations for TensorRT parser to get capability
int trt_min_subgraph_size{1}; // minimum size of TensorRT subgraphs
- size_t trt_max_workspace_size{1 << 30}; // maximum workspace size for TensorRT.
+  size_t trt_max_workspace_size{0};                      // maximum workspace size for TensorRT. Default 0 means use the maximum device memory
int trt_fp16_enable{0}; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true
int trt_int8_enable{0}; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true
const char* trt_int8_calibration_table_name{nullptr}; // TensorRT INT8 calibration table name.
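Because the default workspace size changes from 1 GiB to 0 (use up to the device's memory), callers that relied on the old cap should set it explicitly. A hedged sketch using the struct from this header; only the struct and field names come from the diff, the helper function is illustrative.

    #include "core/providers/tensorrt/tensorrt_provider_options.h"

    // Recreate the pre-change behavior by pinning the workspace cap to 1 GiB.
    OrtTensorRTProviderOptionsV2 MakeLegacyTrtOptions() {
      OrtTensorRTProviderOptionsV2 opts{};       // other members keep their in-class defaults
      opts.trt_max_workspace_size = 1ULL << 30;  // previous default; 0 now means "max device memory"
      return opts;
    }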
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
index 515a967aa2386..f7d8fedc734e4 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
@@ -258,7 +258,6 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
output_parameters->scale = scale_;
output_parameters->mask_type = mask_type;
output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias;
- output_parameters->pass_past_in_kv = false;
output_parameters->qkv_format = Q_K_V_BNSH;
}
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index 55292b35e1e38..88127387d08ea 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -6,6 +6,12 @@
namespace onnxruntime {
namespace contrib {
+enum AttentionType {
+ kAttention,
+ kMultiHeadAttention,
+ kDecoderMaskedMultiHeadAttention,
+};
+
enum AttentionMaskType {
MASK_NONE, // No mask
MASK_1D_KEY_SEQ_LEN, // [batch_size], key sequence length
@@ -24,10 +30,12 @@ enum AttentionQkvFormat {
UNKNOWN, // enum value not set, or depends on qkv projection implementation details
Q_K_V_BNSH, // for non-packed qkv, permuted
Q_K_V_BSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention
- QKV_BSN3H, // for TRT fused attention, qkv are packed
+ Q_K_V_BSNH_BNSH_BNSH, // for cross attention, k and v are permuted
Q_K_V_BNSH_QKV_BS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH)
- Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed
Q_K_V_TNH, // for memory efficient attention, qkv are not packed, and paddings are removed.
+ Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed
+ QKV_BSN3H, // for TRT fused attention, qkv are packed
+ QKV_BS3NH, // for DecoderMaskedMultiHeadAttention, qkv are packed
QKV_TN3H, // for TRT fused attention, qkv are packed and paddings are removed
};
@@ -61,7 +69,6 @@ struct AttentionParameters {
bool past_present_share_buffer;
bool do_rotary;
bool broadcast_res_pos_bias;
- bool pass_past_in_kv;
float mask_filter_value;
float scale;
bool use_tf32;
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
index 9677c30f22d8a..0d77376779230 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -85,7 +85,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const {
scale_,
is_unidirectional_,
past_present_share_buffer,
- false));
+ kMultiHeadAttention));
const int batch_size = parameters.batch_size;
const int q_sequence_length = parameters.sequence_length;
@@ -121,20 +121,13 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
- // For each of Q/K/V, there are multiple scenarios:
- // 1) Combined QKV bias is null
- // a) Q/K/V is (B, S, D)
- // b) Q/K/V is (B, S, N, H)
- // 2) No packed QKV in Q
- // a) Q/K/V has seq_len = 1
- // b) Q/K/V has seq_len > 1
-
OrtValue Q;
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias(
context, allocator, batch_size, num_heads_, q_sequence_length, qk_head_size, query, bias, q_bias_offset, Q));
- if (parameters.pass_past_in_kv) { // key and value in BNSH format
- assert(bias == nullptr);
+ if (parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH) {
+    // For cross attention with k and v in BNSH format, we assume the biases for key and value are zero,
+    // so there is no need to add them here.
assert(past_key == nullptr);
assert(past_value == nullptr);
    return ApplyAttention(Q.GetMutable<Tensor>()->MutableData<T>(),
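The boolean dmmha_packing argument is replaced by the AttentionType enum throughout; the CPU MultiHeadAttention call above now passes kMultiHeadAttention. A tiny self-contained illustration of the semantics the helper enforces (3D packed qkv is only accepted for DecoderMaskedMultiHeadAttention); the function below is illustrative, not code from the diff.

    #include "contrib_ops/cpu/bert/attention_common.h"

    // Mirrors the checks in CheckInputs: QKV_BS3NH packing is rejected for MultiHeadAttention
    // and dmmha_packing is derived only for DecoderMaskedMultiHeadAttention.
    bool AllowsBS3NHPackedQkv(onnxruntime::contrib::AttentionType op_type) {
      return op_type == onnxruntime::contrib::kDecoderMaskedMultiHeadAttention;
    }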
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
index bd7ab09659170..cfb8d36843777 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -11,6 +11,232 @@ namespace onnxruntime {
namespace contrib {
namespace multihead_attention_helper {
+template <typename T>
+Status Check_QKV(const T* packed_qkv, AttentionQkvFormat& qkv_format) {
+ const auto& query_dims = packed_qkv->Shape().GetDims();
+ if (query_dims.size() == 3) {
+ // Packed qkv used by DecoderMaskedMultiHeadAttention. Query shape is (B, S, 3D), no key and value.
+ qkv_format = AttentionQkvFormat::QKV_BS3NH;
+ } else {
+ assert(query_dims.size() == 5);
+ if (static_cast(query_dims[3]) != 3) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "Expect 'query' shape (batch_size, sequence_length, num_heads, 3, head_size) for packed qkv");
+ }
+
+ qkv_format = AttentionQkvFormat::QKV_BSN3H;
+ }
+
+ return Status::OK();
+}
+
+template <typename T>
+Status Check_Q_KV(const T* query, const T* packed_kv, int num_heads, int head_size,
+ AttentionQkvFormat& qkv_format, int& kv_sequence_length) {
+ const auto& query_dims = query->Shape().GetDims();
+ const auto& key_dims = packed_kv->Shape().GetDims();
+ if (query_dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Expect rank of query be 3 for packed kv");
+ }
+
+ if (key_dims.size() != 5) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Expect rank of key be 5 for packed kv");
+ }
+
+ if (key_dims[0] != query_dims[0] ||
+ static_cast(key_dims[2]) != num_heads ||
+ static_cast(key_dims[3]) != 2 ||
+ static_cast(key_dims[4]) != head_size) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv");
+ }
+
+ qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
+ kv_sequence_length = static_cast(key_dims[1]);
+ return Status::OK();
+}
+
+template <typename T>
+Status Check_Q_K_V(const T* query, const T* key, const T* value, int num_heads, int head_size,
+ AttentionQkvFormat& qkv_format, int& kv_sequence_length, int& v_hidden_size) {
+ const auto& query_dims = query->Shape().GetDims();
+ const auto& key_dims = key->Shape().GetDims();
+ const auto& value_dims = value->Shape().GetDims();
+ if (query_dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Expect rank of query to be 3 when key and value are given");
+ }
+
+ if (key_dims.size() != value_dims.size() || (key_dims.size() != 3 && value_dims.size() != 4)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Expect rank of key and value be same, and either 3 or 4");
+ }
+
+ if (key_dims[0] != query_dims[0] || value_dims[0] != query_dims[0]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query', 'key' and 'value' shall have same dim 0 (batch_size)");
+ }
+
+ if (key_dims.size() == 3) {
+ if (key_dims[2] != query_dims[2]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'key' shall have same dim 2 (hidden_size)");
+ }
+
+ if (key_dims[1] != value_dims[1]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'key' and 'value' shall have same dim 1 (kv_sequence_length)");
+ }
+
+ qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+ kv_sequence_length = static_cast(key_dims[1]);
+ v_hidden_size = static_cast(value_dims[2]);
+ } else { // key_dims.size() == 4
+ if (value->Shape() != key->Shape() ||
+ static_cast(key_dims[1]) != num_heads ||
+ static_cast(key_dims[3]) != head_size) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'key' and 'value' shall have same shape (batch_size, num_heads, kv_sequence_length, head_size)");
+ }
+
+ qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
+ kv_sequence_length = static_cast(key_dims[2]);
+ v_hidden_size = static_cast(value_dims[1]) * static_cast(value_dims[3]);
+ }
+
+ return Status::OK();
+}
+
+template <typename T>
+Status CheckPast(const T* past_key, const T* past_value, const T* past_seq_len,
+ int batch_size, int num_heads, int head_size, bool past_present_share_buffer,
+ int& past_sequence_length, int& max_sequence_length) {
+ const auto& past_key_dims = past_key->Shape().GetDims();
+ const auto& past_value_dims = past_value->Shape().GetDims();
+
+ if (past_key_dims.size() != 4) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_key' is expected to have 4 dimensions, got ",
+ past_key_dims.size());
+ }
+ if (past_value_dims.size() != 4) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_value' is expected to have 4 dimensions, got ",
+ past_value_dims.size());
+ }
+
+ if (past_key_dims[0] != batch_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_key' dimension 0 should be batch_size, got ",
+ past_key_dims[0]);
+ }
+ if (past_value_dims[0] != batch_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_value' dimension 0 should be batch_size, got ",
+ past_value_dims[0]);
+ }
+
+ if (past_key_dims[1] != num_heads) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_key' dimension 1 should be same as number of heads, got ",
+ past_key_dims[1]);
+ }
+ if (past_value_dims[1] != num_heads) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_value' dimension 1 should be same as number of heads, got ",
+ past_value_dims[1]);
+ }
+ if (past_key_dims[2] != past_value_dims[2]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length). ",
+ past_key_dims[2], " vs ", past_value_dims[2]);
+ }
+ if (past_key_dims[3] != head_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_key' dimension 3 should be same as head_size, got ",
+ past_key_dims[3]);
+ }
+ if (past_value_dims[3] != head_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_value' dimension 3 should be same as head_size, got ",
+ past_value_dims[3]);
+ }
+ past_sequence_length = static_cast(past_key_dims[2]);
+ if (past_present_share_buffer) {
+ max_sequence_length = static_cast(past_key_dims[2]);
+ if (past_seq_len == nullptr || !onnxruntime::IsScalarOr1ElementVector(past_seq_len)) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "past_sequence_length tensor must be of one element when past_present_share_buffer is set");
+ }
+ past_sequence_length = *((*past_seq_len).template Data());
+ }
+ return Status::OK();
+}
+
+template <typename T>
+Status CheckRelativePositionBias(
+ const T* relative_position_bias, int batch_size, int num_heads, int sequence_length, int total_sequence_length,
+ bool& broadcast_res_pos_bias) {
+ const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims();
+
+ if (relative_position_bias_dims.size() != 4) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'relative_position_bias' is expected to have 4 dimensions, got ",
+ relative_position_bias_dims.size());
+ }
+ if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ",
+ relative_position_bias_dims[0]);
+ }
+ if (relative_position_bias_dims[0] == 1) {
+ broadcast_res_pos_bias = true;
+ }
+ if (relative_position_bias_dims[1] != num_heads) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ",
+ relative_position_bias_dims[1]);
+ }
+ if (relative_position_bias_dims[2] != sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ",
+ relative_position_bias_dims[2]);
+ }
+ if (relative_position_bias_dims[3] != total_sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ",
+ relative_position_bias_dims[3]);
+ }
+ return Status::OK();
+}
+
+template <typename T>
+AttentionMaskType GetMaskType(const T* key_padding_mask, int batch_size, int sequence_length, int total_sequence_length) {
+ AttentionMaskType mask_type = AttentionMaskType::MASK_UNKNOWN;
+ const auto& mask_dims = key_padding_mask->Shape().GetDims();
+ if (mask_dims.size() == 1) {
+ if (mask_dims[0] == static_cast(batch_size)) {
+ mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
+ } else if (mask_dims[0] == static_cast(3) * static_cast(batch_size) + static_cast(2)) {
+ mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START;
+ }
+ } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) &&
+ mask_dims[1] == static_cast(total_sequence_length)) {
+ mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
+ } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast(batch_size) &&
+ mask_dims[1] == static_cast(sequence_length) &&
+ mask_dims[2] == static_cast(total_sequence_length)) {
+ mask_type = AttentionMaskType::MASK_3D_ATTENTION;
+ }
+ return mask_type;
+}
+
template
Status CheckInputs(const T* query,
const T* key,
@@ -27,176 +253,128 @@ Status CheckInputs(const T* query,
float scale,
bool is_unidirectional,
bool past_present_share_buffer,
- bool dmmha_packing) {
- // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None
- // relative_position_bias : (B, 1, S, L)
- // past_key : (B, N, S*, H)
- // past_value : (B, N, S*, H)
- // When no packing for q/k/v:
+ AttentionType operator_type) {
+ // ---------------------------------------------------------------
+ // Notations:
+ // B: batch_size
+ // N: num_heads
+ // H: head_size (V might have different head size than Q and K)
+ // D: hidden_size = N * H
+ // S: q_sequence_length
+ // P: past_sequence_length
+ // L: kv_sequence_length
+ // T: total_sequence_length = P + L
+ // M: max_sequence_length
+ // ---------------------------------------------------------------
+ // MultiHeadAttention inputs:
+ // ---------------------------------------------------------------
+ // Q_K_V_BSNH - no packing:
// query (Q) : (B, S, D)
- // key (K) : (B, L, D) or (B, N, S*, H)
- // value (V) : (B, L, D_v) or (B, N, S*, H)
- // bias (Q/K/V) : (D + D + D_v)
- // When packed kv is used:
+ // key (K) : (B, L, D)
+ // value (V) : (B, L, D_v)
+ // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache is not used, L == T, D == D_v):
// query (Q) : (B, S, D)
- // key (K) : (B, L, N, 2, H)
- // value (V) : None
- // bias (Q/K/V) : None
- // When packed qkv is used:
- // query (Q) : (B, L, N, 3, H) or (B, S, 3*D)
+ // key (K) : (B, N, L, H)
+ // value (V) : (B, N, L, H)
+ // Q_KV_BSNH_BSN2H - packed kv (kv cache is not used, bias is not allowed for packed kv):
+ // query (Q) : (B, S, D)
+ // key (K/V) : (B, L, N, 2, H)
+ // value : None
+ // QKV_BSN3H - packed qkv (kv cache is not used, S == L, D == D_v):
+ // query (Q/K/V) : (B, S, N, 3, H)
+ // key : None
+ // value : None
+ //
+ // Other inputs:
+ // bias (Q/K/V) : None or (D + D + D_v)
+ // key_padding_mask (K/V) : (B) or (3 * B + 2) or (B, T) or (B, S, T)
+ // relative_position_bias : (B, N, S, T) or (1, N, S, T)
+ // past_key : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH.
+ // past_value : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH.
+ // ---------------------------------------------------------------
+ // DecoderMaskedMultiHeadAttention inputs (S == 1, D == D_v):
+ // ---------------------------------------------------------------
+ // Q_K_V_BSNH - no packing:
+ // query (Q) : (B, S, D)
+ // key (K) : (B, L, D)
+ // value (V) : (B, L, D)
+ // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache and relative_position_bias are not used. L == T):
+ // query (Q) : (B, S, D)
+ // key (K) : (B, N, L, H)
+ // value (V) : (B, N, L, H)
+ // QKV_BS3NH - packed qkv (S == L):
+ // query (Q) : (B, S, 3 * D)
// key (K) : None
// value (V) : None
- // bias (Q/K/V) : None or (D + D + D_v)
-
- AttentionQkvFormat qkv_format;
+ //
+ // Other inputs:
+ // bias (Q/K/V) : None or (3 * D)
+ // key_padding_mask (K/V) : None or (B, T)
+ // relative_position_bias : (1, N, S, T), or (B, N, S, T) where only 1 x N x S x T data is used in CUDA.
+ //
+ // The following inputs are not used in cross attention (so they are None for cross attention):
+ // past_key : (B, N, P, H), or (B, N, M, H) when past_present_share_buffer is True.
+ // For CUDA, past_present_share_buffer is always True. ROCm supports both.
+ // past_value : (B, N, P, H), or (B, N, M, H) when past_present_share_buffer is True.
+ // For CUDA, past_present_share_buffer is always True. ROCm supports both.
+ // past_sequence_length : scalar (1) when past_present_share_buffer is True.
+ // CUDA version has extra inputs (beam_width, cache_indirection) that are not checked in the class.
+ // For ROCm, see contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh for more details.
+ // ---------------------------------------------------------------
+ AttentionQkvFormat qkv_format = UNKNOWN;
const auto& query_dims = query->Shape().GetDims();
- if (query_dims.size() != 3 && query_dims.size() != 5) {
+
+ int query_rank = static_cast(query_dims.size());
+ if (query_rank != 3 && query_rank != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 or 5 dimensions, got ",
- query_dims.size());
+ query_rank);
}
int batch_size = static_cast(query_dims[0]);
int sequence_length = static_cast(query_dims[1]);
- int hidden_size = (query_dims.size() == 3)
+ bool dmmha_packing = operator_type == kDecoderMaskedMultiHeadAttention && key == nullptr && value == nullptr;
+ int hidden_size = (query_rank == 3)
? (dmmha_packing ? (static_cast(query_dims[2]) / 3) : static_cast(query_dims[2]))
: (num_heads * static_cast(query_dims[4]));
int head_size = static_cast(hidden_size) / num_heads;
int kv_sequence_length = sequence_length;
+ int v_hidden_size = hidden_size;
+ if (key != nullptr) {
+ if (value == nullptr) {
+ ORT_RETURN_IF_ERROR(Check_Q_KV(query, key, num_heads, head_size, qkv_format, kv_sequence_length));
+ } else {
+ ORT_RETURN_IF_ERROR(Check_Q_K_V(query, key, value, num_heads, head_size,
+ qkv_format, kv_sequence_length, v_hidden_size));
+ }
+ } else if (value == nullptr) { // no key and value
+ ORT_RETURN_IF_ERROR(Check_QKV(query, qkv_format));
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Input 'value' shall be absent when 'key' is absent");
+ }
+
int past_sequence_length = 0;
int max_sequence_length = 0;
if (past_key != nullptr && past_value != nullptr) {
- const auto& past_key_dims = past_key->Shape().GetDims();
- const auto& past_value_dims = past_value->Shape().GetDims();
-
- if (past_key_dims.size() != 4) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_key' is expected to have 4 dimensions, got ",
- past_key_dims.size());
- }
- if (past_value_dims.size() != 4) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_value' is expected to have 4 dimensions, got ",
- past_value_dims.size());
- }
-
- if (past_key_dims[0] != batch_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_key' dimension 0 should be batch_size, got ",
- past_key_dims[0]);
- }
- if (past_value_dims[0] != batch_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_value' dimension 0 should be batch_size, got ",
- past_value_dims[0]);
- }
-
- if (past_key_dims[1] != num_heads) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_key' dimension 1 should be same as number of heads, got ",
- past_key_dims[1]);
- }
- if (past_value_dims[1] != num_heads) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_value' dimension 1 should be same as number of heads, got ",
- past_value_dims[1]);
- }
- if (past_key_dims[2] != past_value_dims[2]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length). ",
- past_key_dims[2], " vs ", past_value_dims[2]);
- }
- if (past_key_dims[3] != head_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_key' dimension 3 should be same as head_size, got ",
- past_key_dims[3]);
- }
- if (past_value_dims[3] != head_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_value' dimension 3 should be same as head_size, got ",
- past_value_dims[3]);
- }
- past_sequence_length = static_cast(past_key_dims[2]);
- max_sequence_length = static_cast(past_key_dims[2]);
- if (past_present_share_buffer) {
- if (past_seq_len == nullptr || !onnxruntime::IsScalarOr1ElementVector(past_seq_len)) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "past_sequence_length tensor must be of one element when past_present_share_buffer is set");
- }
- past_sequence_length = *((*past_seq_len).template Data());
- }
+ ORT_RETURN_IF_ERROR(CheckPast(past_key, past_value, past_seq_len,
+ batch_size, num_heads, head_size, past_present_share_buffer,
+ past_sequence_length, max_sequence_length));
} else if (past_key != nullptr || past_value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' and 'past_value' shall be both present or both absent");
}
- if (key != nullptr) {
- if (query_dims.size() != 3) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions when key is given, got ",
- query_dims.size());
- }
-
- const auto& key_dims = key->Shape().GetDims();
- if (key_dims.size() != 3 && key_dims.size() != 4 && key_dims.size() != 5) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3, 4, or 5 dimensions, got ",
- key_dims.size());
- }
- if (query_dims[0] != key_dims[0]) {
+ if (operator_type == kMultiHeadAttention) {
+ if (qkv_format == AttentionQkvFormat::QKV_BS3NH) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query' and 'key' shall have same dim 0 (batch size)");
+                             "Packed qkv of 3D BS3NH format is not supported by MultiHeadAttention");
}
- if (key_dims.size() == 3) {
- if (key_dims[2] != query_dims[2]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query' and 'key' shall have same dim 2 (hidden_size)");
- }
-
- qkv_format = Q_K_V_BSNH;
- kv_sequence_length = static_cast(key_dims[1]);
- } else if (key_dims.size() == 5) {
- if (static_cast(key_dims[2]) != num_heads || static_cast(key_dims[3]) != 2 || static_cast(key_dims[4]) != head_size) {
- return ORT_MAKE_STATUS(
- ONNXRUNTIME, INVALID_ARGUMENT,
- "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv");
- }
- if (value != nullptr) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format.");
- }
-
- qkv_format = Q_KV_BSNH_BSN2H;
- kv_sequence_length = static_cast(key_dims[1]);
- } else { // key_dims.size() == 4 (cross-attention with past_key)
- if (static_cast(key_dims[1]) != num_heads || static_cast(key_dims[3]) != head_size) {
- return ORT_MAKE_STATUS(
- ONNXRUNTIME, INVALID_ARGUMENT,
- "Expect 'key' shape (batch_size, num_heads, kv_sequence_length, head_size)");
- }
-
- if (value == nullptr || value->Shape().GetDims().size() != 4) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' shall be 4D when 'key' is 4D");
- }
-
- if (bias != nullptr) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' shall be empty when 'key' is 4D");
- }
-
- qkv_format = UNKNOWN;
- kv_sequence_length = static_cast(key_dims[2]);
- }
- } else { // packed QKV
- if (query_dims.size() != 3 && query_dims.size() != 5) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 or 5 dimensions when key is empty, got ",
- query_dims.size());
- }
- if (query_dims.size() == 5 && (static_cast(query_dims[2]) != num_heads || static_cast(query_dims[3]) != 3)) {
- return ORT_MAKE_STATUS(
- ONNXRUNTIME, INVALID_ARGUMENT,
- "Expect 'query' shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv");
+ if (qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H && bias != nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' shall be empty when packed kv is used");
}
-
- qkv_format = QKV_BSN3H;
}
if (bias != nullptr) {
@@ -206,116 +384,31 @@ Status CheckInputs(const T* query,
bias_dims.size());
}
- if (value == nullptr) {
- // Currently, bias is not allowed for packed KV. This constraint can be removed later.
- // Here we assume that fusion tool will not include bias for packed KV.
- if (query_dims.size() == 5 && query_dims[3] == 2) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed kv. ");
- }
+ int expected_bias_length = 2 * hidden_size + v_hidden_size;
+    if (bias_dims[0] != static_cast<int64_t>(expected_bias_length)) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' length is expected to be 2 * hidden_size + hidden_size_v, got ",
+                             bias_dims[0]);
}
}
int total_sequence_length = past_sequence_length + kv_sequence_length;
AttentionMaskType mask_type = AttentionMaskType::MASK_NONE;
if (key_padding_mask != nullptr) {
- mask_type = AttentionMaskType::MASK_UNKNOWN;
- const auto& mask_dims = key_padding_mask->Shape().GetDims();
- if (mask_dims.size() == 1) {
- if (mask_dims[0] == static_cast(batch_size)) {
- mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
- } else if (mask_dims[0] == static_cast(3) * static_cast(batch_size) + static_cast(2)) {
- mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START;
- }
- } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) &&
- mask_dims[1] == static_cast(kv_sequence_length)) {
- mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
- } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) &&
- mask_dims[1] == static_cast(total_sequence_length)) {
- mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
- } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast(batch_size) &&
- mask_dims[1] == static_cast(sequence_length) &&
- mask_dims[2] == static_cast(total_sequence_length)) {
- mask_type = AttentionMaskType::MASK_3D_ATTENTION;
- }
-
+ mask_type = GetMaskType(key_padding_mask, batch_size, sequence_length, total_sequence_length);
if (mask_type == AttentionMaskType::MASK_UNKNOWN) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'key_padding_mask' shape shall be 1D, 2D, or 3D");
- }
- }
-
- // NOTE: In Cross-Attention, we pass the past key and value to 'key' and 'value' instead of 'past_key' and 'past_value'.
- bool pass_past_in_kv = false;
- int v_hidden_size = hidden_size;
- if (value != nullptr) {
- const auto& value_dims = value->Shape().GetDims();
- if (value_dims.size() != 3 && value_dims.size() != 4) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 or 4 dimensions, got ",
- value_dims.size());
- }
-
- if (query_dims[0] != value_dims[0]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query' and 'value' shall have same dim 0 (batch_size)");
- }
-
- if (value_dims.size() == 3) {
- if (static_cast(kv_sequence_length) != value_dims[1]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
- }
- v_hidden_size = static_cast(value_dims[2]);
- } else { // value_dims.size() == 4
- if (static_cast(kv_sequence_length) != value_dims[2]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'key' and 'value' shall have the same dim 2 (kv_sequence_length)");
- }
-
- if (past_key != nullptr || past_value != nullptr) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_key' and 'past_value' shall be empty when 'value' is 4D");
- }
-
- v_hidden_size = static_cast(value_dims[1]) * static_cast(value_dims[3]);
- pass_past_in_kv = true;
+                             "Input 'key_padding_mask' shape is not supported.");
}
}
bool broadcast_res_pos_bias = false;
if (relative_position_bias != nullptr) {
- const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims();
-
- if (relative_position_bias_dims.size() != 4) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'relative_position_bias' is expected to have 4 dimensions, got ",
- relative_position_bias_dims.size());
- }
- if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ",
- relative_position_bias_dims[0]);
- }
- if (relative_position_bias_dims[0] == 1) {
- broadcast_res_pos_bias = true;
- }
- if (relative_position_bias_dims[1] != num_heads) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ",
- relative_position_bias_dims[1]);
- }
- if (relative_position_bias_dims[2] != sequence_length) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ",
- relative_position_bias_dims[2]);
- }
- if (relative_position_bias_dims[3] != total_sequence_length) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ",
- relative_position_bias_dims[3]);
- }
+ ORT_RETURN_IF_ERROR(CheckRelativePositionBias(
+ relative_position_bias, batch_size, num_heads, sequence_length, total_sequence_length, broadcast_res_pos_bias));
}
- // TODO: ORT_RETURN_IF(qkv_format == UNKNOWN, "Unrecognized QKV format");
+ assert(qkv_format != UNKNOWN);
+
if (parameters != nullptr) {
    AttentionParameters* output_parameters = reinterpret_cast<AttentionParameters*>(parameters);
output_parameters->batch_size = batch_size;
@@ -323,7 +416,7 @@ Status CheckInputs(const T* query,
output_parameters->past_sequence_length = past_sequence_length;
output_parameters->kv_sequence_length = kv_sequence_length;
output_parameters->total_sequence_length = total_sequence_length;
- output_parameters->max_sequence_length = max_sequence_length;
+ output_parameters->max_sequence_length = past_present_share_buffer ? max_sequence_length : total_sequence_length;
output_parameters->input_hidden_size = 0;
output_parameters->hidden_size = hidden_size;
output_parameters->v_hidden_size = v_hidden_size;
@@ -336,7 +429,6 @@ Status CheckInputs(const T* query,
output_parameters->mask_type = mask_type;
output_parameters->scale = scale;
output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias;
- output_parameters->pass_past_in_kv = pass_past_in_kv;
output_parameters->qkv_format = qkv_format;
}
@@ -359,7 +451,7 @@ Status CheckInputs(const T* query,
float scale,
bool is_unidirectional,
bool past_present_share_buffer,
- bool dmmha_packing,
+ AttentionType operator_type,
int max_threads_per_block) {
if (max_threads_per_block > 0 && num_heads > max_threads_per_block) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
@@ -367,7 +459,7 @@ Status CheckInputs(const T* query,
return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value,
past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional,
- past_present_share_buffer, dmmha_packing);
+ past_present_share_buffer, operator_type);
}
} // namespace multihead_attention_helper
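To make the shape conventions documented in CheckInputs concrete, here is a small self-contained example with made-up sizes for the Q_K_V_BSNH_BNSH_BNSH (cross attention) case; nothing below comes from the diff.

    #include <array>
    #include <cstdint>

    int main() {
      // Assumed example sizes: B=2, S=4, N=8, H=64, L=T=6.
      constexpr int64_t B = 2, S = 4, N = 8, H = 64, L = 6;
      constexpr int64_t D = N * H;                    // hidden_size = 512
      std::array<int64_t, 3> query_dims{B, S, D};     // query : (B, S, D)
      std::array<int64_t, 4> key_dims{B, N, L, H};    // key   : (B, N, L, H)
      std::array<int64_t, 4> value_dims{B, N, L, H};  // value : (B, N, L, H)
      // For these inputs Check_Q_K_V reports kv_sequence_length = 6 and v_hidden_size = 8 * 64 = 512.
      (void)query_dims; (void)key_dims; (void)value_dims;
      return 0;
    }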
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index 995babc857357..5fdd2b017b8a6 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -104,6 +104,8 @@ class MatMulNBits final : public OpKernel {
ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
+ const Tensor* tensor_zero_point = nullptr;
+ has_zp_input_ = info.TryGetConstantInput(3, &tensor_zero_point);
#ifdef ORT_NEURAL_SPEED
const Tensor* tensor_B = nullptr;
const Tensor* tensor_scale = nullptr;
@@ -139,6 +141,7 @@ class MatMulNBits final : public OpKernel {
IAllocatorUniquePtr packed_b_{};
size_t packed_b_size_{0};
+ bool has_zp_input_{false};
#if defined(ORT_NEURAL_SPEED)
bool is_asym_{false};
@@ -207,10 +210,10 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
is_packed = true;
}
-#else // defined(ORT_NEURAL_SPEED)
-
+#else // defined(ORT_NEURAL_SPEED)
+ ORT_UNUSED_PARAMETER(prepacked_weights);
+ const auto compute_type = static_cast(accuracy_level_);
if (input_idx == InputIndex::B) {
- const auto compute_type = static_cast(accuracy_level_);
if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) {
return Status::OK();
}
@@ -220,12 +223,20 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
}
auto qptr = tensor.DataRaw();
packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true);
- MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get());
- if (prepacked_weights) {
- prepacked_weights->buffers_.push_back(std::move(packed_b_));
- prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
- }
+ MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr);
is_packed = true;
+ } else if (compute_type == CompInt8) {
+#ifdef MLAS_TARGET_AMD64_IX86
+ if (input_idx == InputIndex::scales && packed_b_ != nullptr) {
+ auto sptr = tensor.Data();
+ MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr);
+ is_packed = false;
+ } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) {
+ auto zptr = tensor.Data();
+ MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr);
+ is_packed = false;
+ }
+#endif
}
#endif // defined(ORT_NEURAL_SPEED)
@@ -332,9 +343,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const auto* bias_data = bias == nullptr ? nullptr : bias->Data();
IAllocatorUniquePtr workspace{};
- if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
- nbits_, block_size_, compute_type);
- workspace_size > 0) {
+ const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(
+ M, N, K, batch_count, nbits_, block_size_, compute_type);
+ if (workspace_size > 0) {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
workspace = IAllocator::MakeUniquePtr(allocator, workspace_size);
@@ -344,14 +355,18 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
for (size_t i = 0; i < batch_count; ++i) {
data[i].A = a_data + helper.LeftOffsets()[i];
data[i].lda = lda;
- data[i].QuantBData = packed_b_.get();
+#ifdef MLAS_TARGET_AMD64_IX86
+ if (compute_type == CompInt8) {
+ data[i].QuantBDataWorkspace = packed_b_.get();
+ }
+#endif
+ data[i].PackedQuantBData = static_cast(packed_b_.get());
data[i].QuantBScale = scales_data;
data[i].QuantBZeroPoint = zero_points_data;
data[i].Bias = bias_data;
data[i].C = y_data + helper.OutputOffsets()[i];
data[i].ldc = N;
}
-
MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(),
thread_pool);
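A compact restatement of the new PrePack flow for the int8 compute path on x64, which may help when reading the #ifdef'd branches above. The enum and function below are a sketch only; the real logic lives in MatMulNBits::PrePack and assumes MlasIsSQNBitGemmAvailable succeeded for input B.

    // Which constant input triggers a call into MlasSQNBitGemmPackQuantBData when
    // compute_type == CompInt8 on MLAS_TARGET_AMD64_IX86 (illustrative sketch).
    enum class NBitsInput { A = 0, B = 1, scales = 2, zero_points = 3 };

    bool CallsPackQuantBData(NBitsInput idx, bool packed_b_ready) {
      if (idx == NBitsInput::B) return true;  // packs the quantized weights; is_packed = true
      // Scales and zero points are folded into the already-packed buffer, but PrePack
      // reports is_packed = false so the original tensors stay available at Compute time.
      return packed_b_ready && (idx == NBitsInput::scales || idx == NBitsInput::zero_points);
    }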
diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
index 9e6752b451868..62d6a723bf32c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
@@ -520,6 +520,39 @@ __global__ void AddBiasUnpack(int M, const T* input, const T* biases, T* output)
}
}
+template <typename T>
+__global__ void AddBiasTransposeUnpack(int M, const T* input, const T* biases, T* output) {
+ // Format 5 to unpack TRT packed input format to BNSH for unfused attention.
+ // Input: BxSxNxMxH
+ // Output: MxBxNxSxH
+ // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
+ int n = threadIdx.y;
+ int s = blockIdx.x;
+ int b = blockIdx.y;
+ int m = blockIdx.z; // matrix id
+
+ const int head_size = blockDim.x;
+ const int num_heads = blockDim.y;
+
+ const int sequence_length = gridDim.x;
+ const int batch_size = gridDim.y;
+ const int H = head_size;
+ const int NH = num_heads * head_size;
+ const int NHS = NH * sequence_length;
+
+ int in_offset = m * head_size + n * M * H + (s * NH + b * NHS) * M;
+ const int out_offset = (s + n * sequence_length) * head_size + (b + m * batch_size) * NHS;
+
+ const int h = threadIdx.x;
+ if (h < head_size) {
+ if (biases != nullptr) {
+ output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
+ } else {
+ output[out_offset + h] = input[in_offset + h];
+ }
+ }
+}
+
template
__global__ void AddBiasTransposeCutlass(int M, const T* input, const T* biases, T* output) {
// Format 3 for cutlass memory efficient attention
@@ -692,6 +725,8 @@ void InvokeAddBiasTranspose(
}
} else if (format == 4) { // format == 4
AddBiasUnpack<<>>(total_matrix_count, input, biases, output);
+ } else if (format == 5) { // format == 5
+    AddBiasTransposeUnpack<<<grid, block, 0, stream>>>(total_matrix_count, input, biases, output);
} else { // format == 0
AddBiasTranspose<<>>(input, biases, output);
}
@@ -716,6 +751,8 @@ void InvokeAddBiasTranspose(
}
} else if (format == 4) { // format == 4
ORT_THROW("AddBiasTranspose (format 4) not implemented for hidden_size > max_threads_per_block");
+ } else if (format == 5) { // format == 5
+ ORT_THROW("AddBiasTranspose (format 5) not implemented for hidden_size > max_threads_per_block");
} else { // format 0
AddBiasTransposeLarge<<>>(qk_head_size, input, biases, output);
}
@@ -904,6 +941,7 @@ void InvokeAddBias(
AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q);
}
}
+
// K
{
const dim3 grid(kv_sequence_length, batch_size, num_matrices);
@@ -1011,6 +1049,82 @@ void LaunchAddBias(
}
}
+template <typename T>
+void InvokeAddBias(
+ cudaStream_t stream, const int max_threads_per_block,
+ const int batch_size, const int sequence_length,
+ const int num_heads, const int head_size,
+ const T* biases, const T* query, T* q) {
+ assert(num_heads <= max_threads_per_block);
+ constexpr int num_matrices = 1;
+ const dim3 grid(sequence_length, batch_size, num_matrices);
+ if (head_size * num_heads <= max_threads_per_block) {
+ const dim3 block(head_size, num_heads, 1);
+    AddBiasTransposeTrt<<<grid, block, 0, stream>>>(query, biases, q);
+ } else {
+ const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
+    AddBiasTransposeTrtLarge<<<grid, block, 0, stream>>>(head_size, query, biases, q);
+ }
+}
+
+template <>
+void LaunchAddBias(
+ cudaStream_t stream, const int max_threads_per_block,
+ const int batch_size, const int sequence_length,
+ const int num_heads, const int head_size,
+ const float* biases, const float* query, float* q) {
+ if (0 == (head_size % 4)) {
+ const int H = head_size / 4;
+    const float4* query2 = reinterpret_cast<const float4*>(query);
+    const float4* biases2 = reinterpret_cast<const float4*>(biases);
+    float4* q2 = reinterpret_cast<float4*>(q);
+ InvokeAddBias(stream, max_threads_per_block,
+ batch_size, sequence_length, num_heads, H,
+ biases2, query2, q2);
+ } else if (0 == (head_size & 1)) {
+ const int H = head_size / 2;
+    const float2* query2 = reinterpret_cast<const float2*>(query);
+    const float2* biases2 = reinterpret_cast<const float2*>(biases);
+    float2* q2 = reinterpret_cast<float2*>(q);
+ InvokeAddBias(stream, max_threads_per_block,
+ batch_size, sequence_length, num_heads, H,
+ biases2, query2, q2);
+ } else {
+ InvokeAddBias(stream, max_threads_per_block,
+ batch_size, sequence_length, num_heads, head_size,
+ biases, query, q);
+ }
+}
+
+template <>
+void LaunchAddBias(
+ cudaStream_t stream, const int max_threads_per_block,
+ const int batch_size, const int sequence_length,
+ const int num_heads, const int head_size,
+ const half* biases, const half* query, half* q) {
+ if (0 == (head_size % 4)) {
+ const int H = head_size / 4;
+    const Half4* query2 = reinterpret_cast<const Half4*>(query);
+    const Half4* biases2 = reinterpret_cast<const Half4*>(biases);
+    Half4* q2 = reinterpret_cast<Half4*>(q);
+ InvokeAddBias(stream, max_threads_per_block,
+ batch_size, sequence_length, num_heads, H,
+ biases2, query2, q2);
+ } else if (0 == (head_size & 1)) {
+ const int H = head_size / 2;
+    const half2* query2 = reinterpret_cast<const half2*>(query);
+    const half2* biases2 = reinterpret_cast<const half2*>(biases);
+    half2* q2 = reinterpret_cast<half2*>(q);
+ InvokeAddBias(stream, max_threads_per_block,
+ batch_size, sequence_length, num_heads, H,
+ biases2, query2, q2);
+ } else {
+ InvokeAddBias(stream, max_threads_per_block,
+ batch_size, sequence_length, num_heads, head_size,
+ biases, query, q);
+ }
+}
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
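A hedged sketch of a call site for the new single-tensor LaunchAddBias overload added above; the wrapper function and its parameter names are illustrative, only the overload's signature comes from add_bias_transpose.h.

    #include "contrib_ops/cuda/bert/add_bias_transpose.h"

    // Adds the Q bias for a (batch_size, sequence_length, num_heads, head_size) query tensor.
    void AddQueryBias(cudaStream_t stream, int max_threads_per_block,
                      int batch_size, int sequence_length, int num_heads, int head_size,
                      const float* q_bias,  // (num_heads * head_size)
                      const float* query,   // (B, S, N, H)
                      float* q_out) {       // (B, S, N, H)
      onnxruntime::contrib::cuda::LaunchAddBias(
          stream, max_threads_per_block, batch_size, sequence_length,
          num_heads, head_size, q_bias, query, q_out);
    }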
diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
index efc31db43bcdb..bd4e123a272bc 100644
--- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
+++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
@@ -3,14 +3,15 @@
#pragma once
#include "core/providers/cuda/shared_inc/cuda_utils.h"
+#include "contrib_ops/cpu/bert/attention_common.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
-// Fused kernel of Add (bias) and Transpose.
+// Fused kernel of Add bias (optional, can be None) and Transpose.
// Shape of inputs and outputs:
-// biases: (num_matrices, num_heads * head_size)
+// biases: (num_matrices, num_heads * head_size) or None
// format 0: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3)
// input: (num_matrices, batch_size, sequence_length, num_heads, head_size)
// output: (num_matrices, batch_size, num_heads, sequence_length, head_size)
@@ -24,9 +25,12 @@ namespace cuda {
// format 3: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3)
// input: (batch_size, sequence_length, num_matrices, num_heads, head_size)
// output: (num_matrices, batch_size, sequence_length, num_heads, head_size)
-// format 4: (requires qk_head_size = v_head_size)
+// format 4: (requires qk_head_size == v_head_size)
// input: (batch_size, sequence_length, num_heads, num_matrices, head_size)
// output: (num_matrices, batch_size, sequence_length, num_heads, head_size)
+// format 5: (requires qk_head_size == v_head_size)
+// input: (batch_size, sequence_length, num_heads, num_matrices, head_size)
+// output: (num_matrices, batch_size, num_heads, sequence_length, head_size)
template
void LaunchAddBiasTranspose(
@@ -35,7 +39,7 @@ void LaunchAddBiasTranspose(
const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size, T* qkv_add_bias = nullptr,
int total_matrix_count = -1, bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0);
-// Add (bias) and Transpose for separated inputs of Q, K and V, and output Trt format.
+// Add bias (optional, can be None) and Transpose for separated inputs of Q, K and V, and output Trt format.
// For self attention:
// output: (batch_size, sequence_length, num_heads, 3, head_size)
// It assumes sequence_length == kv_sequence_length and head_size == v_head_size.
@@ -50,7 +54,7 @@ void LaunchAddBiasTransposeTrt(
const T* biases, const T* query, const T* key, const T* value, T* output,
bool is_cross_attention, int kv_sequence_length = -1);
-// Add (bias) for separated inputs of Q, K and V.
+// Add bias (required) for separated inputs of Q, K and V.
// Q: (batch_size, sequence_length, num_heads, head_size)
// K: (batch_size, kv_sequence_length, num_heads, head_size)
// V: (batch_size, kv_sequence_length, num_heads, v_head_size)
@@ -61,6 +65,46 @@ void LaunchAddBias(
const int num_heads, const int head_size, const int v_head_size,
const T* biases, const T* query, const T* key, const T* value, T* q, T* k, T* v);
+// Add bias (required) for Q: (batch_size, sequence_length, num_heads, head_size)
+template <typename T>
+void LaunchAddBias(
+ cudaStream_t stream, const int max_threads_per_block,
+ const int batch_size, const int sequence_length,
+ const int num_heads, const int head_size,
+ const T* biases, const T* query, T* q);
+
+// Add bias (optional, can be None) transpose kernel defined in packed_multihead_attention_impl.cu.
+// Support the following format transforms (for float and half only).
+// source_format => target_format:
+// Q_K_V_TNH => Q_K_V_BNSH (requires token_offset)
+// Q_K_V_TNH => Q_K_V_TNH
+// Q_K_V_TNH => QKV_TN3H
+// QKV_TN3H => Q_K_V_BNSH (requires token_offset)
+// QKV_TN3H => Q_K_V_TNH
+// QKV_TN3H => QKV_TN3H
+template <typename T>
+void AddBiasTransposePacked(
+ const T* query, const T* key, const T* value, const T* bias, T* output,
+ const int batch_size, const int sequence_length,
+ const int num_heads, const int qk_head_size, const int v_head_size,
+ AttentionQkvFormat source_format, AttentionQkvFormat target_format,
+ const int32_t* token_offset, int32_t token_count,
+ cudaStream_t stream);
+
+// Add bias (required) transpose kernel defined in packed_attention_impl.cu.
+// Support the following format transforms (for float and half only):
+// format transform
+// Q_K_V_BNSH: Tx3xNxH => 3xBxNxSxH (requires token_offset)
+// Q_K_V_BSNH: Tx3xNxH => 3xTxNxH
+// QKV_BSN3H: Tx3xNxH => TxNx3xH
+template <typename T>
+void AddBiasTransposePacked(
+ const T* input, const T* biases, T* output,
+ const int batch_size, const int sequence_length,
+ const int num_heads, const int qk_head_size, const int v_head_size,
+ AttentionQkvFormat format, const int32_t* token_offset, int32_t token_count,
+ cudaStream_t stream);
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
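For readers checking the format 5 description above against the AddBiasTransposeUnpack kernel, here is a host-side restatement of its index math (BxSxNxMxH to MxBxNxSxH). This is an illustration, not code from the diff; the formulas reduce to the kernel's in_offset/out_offset plus h.

    #include <cstdint>

    // Linear offset of element (b, s, n, m, h) in the BxSxNxMxH input.
    inline int64_t Format5InOffset(int64_t b, int64_t s, int64_t n, int64_t m, int64_t h,
                                   int64_t S, int64_t N, int64_t M, int64_t H) {
      return (((b * S + s) * N + n) * M + m) * H + h;
    }

    // Linear offset of the same element in the MxBxNxSxH output.
    inline int64_t Format5OutOffset(int64_t b, int64_t s, int64_t n, int64_t m, int64_t h,
                                    int64_t S, int64_t N, int64_t B, int64_t H) {
      return (((m * B + b) * N + n) * S + s) * H + h;
    }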
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc
index 3b7f980ba1881..5c0989bced70c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -260,7 +260,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
fused_runner,
use_flash_attention,
use_fused_cross_attention,
- use_memory_efficient_attention);
+ use_memory_efficient_attention,
+ false);
IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream());
typedef typename ToCudaType::MappedType CudaT;
@@ -281,6 +282,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
}
data.has_qkv_workspace = true;
data.workspace = reinterpret_cast(work_space.get());
+ data.workspace_bytes = workSpaceSize;
data.output = reinterpret_cast(output->MutableData());
if (nullptr != present) {
data.present = reinterpret_cast(present->MutableData());
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 997493acd9cb7..f9eabe27d97e4 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -58,31 +58,25 @@ size_t AlignSize(size_t bytes) {
return bytesAligned;
}
-void CumulatedSequenceLengthCache::Initialize(int32_t seq_length, cudaStream_t stream) {
- if (this->sequence_length != seq_length) {
- ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0);
- LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr,
- this->max_batch_size, seq_length, stream);
- this->sequence_length = seq_length;
+const int32_t* CumulatedSequenceLengthCache::TryGet(int batch_size, int32_t seq_len, cudaStream_t stream) {
+ if (this->sequence_length == 0 && seq_len > 0) {
+ // Initialize only once with sequence length in the first request.
+ std::call_once(init_once_flag_, [&]() {
+ ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0);
+      LaunchTrtSequenceOffset(reinterpret_cast<int32_t*>(buffer.get()), nullptr,
+ this->max_batch_size, seq_len, stream);
+      // Synchronize to ensure thread safety, since other threads will not wait for the above kernel to finish.
+      // Otherwise, the data might be consumed by other threads before it is ready, causing a data race.
+ cudaStreamSynchronize(stream);
+ this->sequence_length = seq_len;
+ });
}
-}
-int* GetCumulatedSequenceLength(CumulatedSequenceLengthCache* cache,
- const int* mask_index,
- int batch_size,
- int sequence_length,
- cudaStream_t stream,
- void* scratch_buffer) {
- if (mask_index == nullptr && cache != nullptr) {
- if (batch_size <= cache->max_batch_size) {
- cache->Initialize(sequence_length, stream);
- return reinterpret_cast(cache->buffer.get());
- }
+ if (this->sequence_length == seq_len && batch_size <= this->max_batch_size) {
+    return reinterpret_cast<const int32_t*>(buffer.get());
}
- int* sequence_offset = reinterpret_cast(scratch_buffer);
- LaunchTrtSequenceOffset(sequence_offset, mask_index, batch_size, sequence_length, stream);
- return sequence_offset;
+ return nullptr;
}
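The rewritten cache follows a lazy, initialize-once-then-read-only pattern. A small self-contained sketch of the same std::call_once idiom (class and member names here are illustrative, not taken from the code above):

#include <cstdint>
#include <mutex>
#include <vector>

// Illustrative lazy-initialized cache: the buffer is filled exactly once under
// std::call_once, and later callers only read it, mirroring the thread-safety
// idiom of CumulatedSequenceLengthCache::TryGet.
class OnceInitializedCache {
 public:
  const std::vector<int32_t>* TryGet(int32_t seq_len) {
    std::call_once(init_flag_, [&]() {
      offsets_.assign(static_cast<size_t>(max_batch_size_) + 1, 0);
      for (int32_t i = 0; i <= max_batch_size_; ++i) offsets_[i] = i * seq_len;
      sequence_length_ = seq_len;
    });
    // After call_once returns, the initialization is visible to this thread.
    return (sequence_length_ == seq_len) ? &offsets_ : nullptr;
  }

 private:
  int32_t max_batch_size_ = 128;
  int32_t sequence_length_ = 0;
  std::vector<int32_t> offsets_;
  std::once_flag init_flag_;
};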
size_t GetAttentionScratchSize(
@@ -114,10 +108,12 @@ size_t GetAttentionWorkspaceSize(
void* fused_runner,
bool use_flash_attention,
bool use_fused_cross_attention,
- bool use_memory_efficient_attention) {
+ bool use_memory_efficient_attention,
+ bool no_qkv_workspace) {
// Note that q, k and v might need alignment for fused attention kernels.
- const size_t qkv_bytes = element_size * batch_size * num_heads *
- ((sequence_length + kv_sequence_length) * qk_head_size + kv_sequence_length * v_head_size);
+ const size_t qkv_size = element_size * batch_size * num_heads *
+ ((sequence_length + kv_sequence_length) * qk_head_size + kv_sequence_length * v_head_size);
+ const size_t qkv_bytes = no_qkv_workspace ? 0 : qkv_size;
#if USE_FLASH_ATTENTION
if (use_flash_attention) {
@@ -162,39 +158,44 @@ Status FusedTrtCrossAttention(
// We only enable fused cross attention when there is no key padding mask.
// Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query.
assert(data.mask_index == nullptr);
-
+ assert(data.scratch != nullptr);
+ assert(data.q != nullptr);
+ assert(data.k != nullptr);
+
+#ifndef NDEBUG
+  char* scratch_end = reinterpret_cast<char*>(data.scratch) + 2 * GetSequenceOffsetSize(parameters.batch_size, false);
+  char* buffer_end = reinterpret_cast<char*>(data.workspace) + data.workspace_bytes;
+ assert(scratch_end <= buffer_end);
+#endif
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
- int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache,
- data.mask_index, batch_size,
- sequence_length, stream,
- data.scratch);
+  int32_t* q_sequence_offset = const_cast<int32_t*>(data.cumulated_sequence_length_q_cache);
+  if (q_sequence_offset == nullptr) {
+    q_sequence_offset = reinterpret_cast<int32_t*>(data.scratch);
+    LaunchTrtSequenceOffset(q_sequence_offset, data.mask_index, batch_size, sequence_length, stream);
+ }
+
+ CUDA_RETURN_IF_ERROR(cudaGetLastError());
DUMP_TENSOR_INIT();
DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1);
- int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int));
- kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache,
- data.mask_index, batch_size, parameters.kv_sequence_length, stream,
- kv_sequence_offset);
- CUDA_RETURN_IF_ERROR(cudaGetLastError());
+  int32_t* kv_sequence_offset = const_cast<int32_t*>(data.cumulated_sequence_length_kv_cache);
+  if (kv_sequence_offset == nullptr) {
+    int* scratch = reinterpret_cast<int*>(data.scratch) + (GetSequenceOffsetSize(batch_size, false) / sizeof(int));
+    kv_sequence_offset = reinterpret_cast<int32_t*>(scratch);
+    LaunchTrtSequenceOffset(kv_sequence_offset, data.mask_index, batch_size, parameters.kv_sequence_length, stream);
+ }
+ CUDA_RETURN_IF_ERROR(cudaGetLastError());
DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1);
FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel =
reinterpret_cast(data.fused_cross_attention_kernel);
- // When there is no bias, we can directly use q and packed kv from inputs.
- void const* query = data.q;
- void const* packed_kv = data.k;
- if (data.value == nullptr && data.bias == nullptr) {
- query = data.query;
- packed_kv = data.key;
- }
-
run_fused_cross_attention(
- query, // Q
- packed_kv, // packed KV
+ data.q, // Q
+ data.k, // packed KV
q_sequence_offset, // cumulated sequence length of Q
kv_sequence_offset, // cumulated sequence length of KV
data.output, // output
@@ -206,8 +207,6 @@ Status FusedTrtCrossAttention(
parameters.kv_sequence_length, // sequence length of KV
stream);
- DUMP_TENSOR("trt cross output", data.output,
- batch_size, sequence_length, parameters.num_heads, parameters.v_head_size);
return Status::OK();
}
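For orientation, the cumulated sequence length buffer used above has batch_size + 1 entries; when mask_index is nullptr (no padding), it is presumably just offsets[i] = i * sequence_length. A CPU sketch of that shape (illustrative, not the TRT kernel):

#include <cstdint>
#include <vector>

// CPU sketch of the cumulated sequence lengths consumed by the fused TRT
// kernels when there is no padding mask: every sequence has the same length,
// so the prefix sums are simply i * sequence_length.
std::vector<int32_t> MakeSequenceOffsets(int batch_size, int sequence_length) {
  std::vector<int32_t> offsets(static_cast<size_t>(batch_size) + 1);
  for (int i = 0; i <= batch_size; ++i) {
    offsets[i] = i * sequence_length;
  }
  return offsets;
}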
@@ -225,24 +224,33 @@ Status FusedTrtSelfAttention(
cudaStream_t stream,
contrib::AttentionParameters& parameters,
AttentionData& data) {
+ assert(data.scratch != nullptr);
+#ifndef NDEBUG
+  char* scratch_end = reinterpret_cast<char*>(data.scratch) + GetSequenceOffsetSize(parameters.batch_size, false);
+  char* buffer_end = reinterpret_cast<char*>(data.workspace) + data.workspace_bytes;
+ assert(scratch_end <= buffer_end);
+#endif
+
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const bool causal = parameters.is_unidirectional;
- int* sequence_offset = reinterpret_cast(data.scratch);
-
- DUMP_TENSOR_INIT();
+ const int32_t* sequence_offset = data.cumulated_sequence_length_q_cache;
if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) {
- DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length);
- LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream);
+    LaunchTrtSequenceOffset2d(reinterpret_cast<int*>(data.scratch), data.mask_index, batch_size, sequence_length, stream);
+    sequence_offset = reinterpret_cast<const int32_t*>(data.scratch);
} else {
- sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache,
- data.mask_index, batch_size, sequence_length, stream,
- sequence_offset);
+ if (sequence_offset == nullptr) {
+    if (sequence_offset == nullptr) {
+      LaunchTrtSequenceOffset(reinterpret_cast<int32_t*>(data.scratch), data.mask_index, batch_size, sequence_length, stream);
+      sequence_offset = reinterpret_cast<const int32_t*>(data.scratch);
+ }
}
- DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1);
+
CUDA_RETURN_IF_ERROR(cudaGetLastError());
+ DUMP_TENSOR_INIT();
+ DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1);
+
FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(data.fused_runner);
const int s = causal ? sequence_length : fused_fp16_runner->NormalizeSequenceLength(sequence_length);
@@ -252,22 +260,12 @@ Status FusedTrtSelfAttention(
if (!causal) {
assert(data.qkv_format == AttentionQkvFormat::QKV_BSN3H);
-
- // When there is no bias, we can directly use packed qkv from inputs.
- void const* packed_qkv = data.q;
- if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) {
- packed_qkv = data.query;
- }
-
- fused_fp16_runner->Run(b, s, packed_qkv, sequence_offset, data.output, stream);
- DUMP_TENSOR("fused output", data.output,
- batch_size, sequence_length, parameters.num_heads, parameters.v_head_size);
+ fused_fp16_runner->Run(b, s, data.q, sequence_offset, data.output, stream);
} else {
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
fused_fp16_runner->Run(b, s, data.gemm_buffer, sequence_offset, data.output, stream);
- DUMP_TENSOR("fused causal output", data.output,
- batch_size, sequence_length, parameters.num_heads, parameters.v_head_size);
}
+
return Status::OK();
}
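The new NDEBUG-guarded checks in the two fused paths above verify that scratch usage stays inside the workspace now that its size is tracked in workspace_bytes. A generic sketch of that kind of bounds check (a hypothetical helper, not part of this diff):

#include <cassert>
#include <cstddef>

// Hypothetical helper: assert that a sub-buffer [sub, sub + sub_bytes) lies
// within the workspace [workspace, workspace + workspace_bytes).
inline void AssertWithinWorkspace(const void* workspace, size_t workspace_bytes,
                                  const void* sub, size_t sub_bytes) {
  const char* begin = static_cast<const char*>(workspace);
  const char* end = begin + workspace_bytes;
  const char* sub_begin = static_cast<const char*>(sub);
  const char* sub_end = sub_begin + sub_bytes;
  assert(sub_begin >= begin && sub_end <= end);
  (void)begin; (void)end; (void)sub_begin; (void)sub_end;  // avoid unused warnings when NDEBUG is defined
}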
@@ -289,38 +287,19 @@ Status FlashAttention(
contrib::AttentionParameters& parameters,
AttentionData& data,
float scale) {
- assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
+ assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH ||
+ data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH);
assert(nullptr == data.mask_index);
assert(nullptr == data.relative_position_bias);
assert(parameters.head_size == parameters.v_head_size);
- void* query = reinterpret_cast(data.q);
- void* key = reinterpret_cast(data.k);
- void* value = reinterpret_cast(data.v);
- // For packed KV, we can use query input directly.
- if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) {
- query = reinterpret_cast(const_cast(data.query));
- }
-
- DUMP_TENSOR_INIT();
- DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query),
- parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size);
- DUMP_TENSOR_D("k(BSNH)", data.k,
- parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size);
- DUMP_TENSOR_D("v(BSNH)", data.v,
- parameters.batch_size, parameters.total_sequence_length,
- parameters.num_heads, parameters.v_head_size);
-
- bool is_bf16 = false;
+ constexpr bool is_bf16 = false;
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
- device_prop, stream, query, key, value, data.output, reinterpret_cast(data.scratch),
+      device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast<void*>(data.scratch),
parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size,
parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, is_bf16,
parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum),
- true));
-
- DUMP_TENSOR("flash attention output", data.output,
- parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size);
+ data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH));
return Status::OK();
}
@@ -351,25 +330,8 @@ Status EfficientAttention(
float scale) {
// We only enable fused cross attention when there is no key padding mask.
// Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query.
- assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
-
- const void* query = data.q;
- const void* key = data.k;
- const void* value = data.v;
- // For packed KV, we can use query input directly.
- if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) {
- assert(data.bias == nullptr);
- query = data.query;
- }
-
- DUMP_TENSOR_INIT();
- DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query),
- parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size);
- DUMP_TENSOR_D("k(BSNH)", data.k,
- parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size);
- DUMP_TENSOR_D("v(BSNH)", data.v,
- parameters.batch_size, parameters.total_sequence_length,
- parameters.num_heads, parameters.v_head_size);
+ assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH ||
+ data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH);
MemoryEfficientAttentionParams p;
p.sm = device_prop.major * 10 + device_prop.minor;
@@ -394,21 +356,19 @@ Status EfficientAttention(
? nullptr
: const_cast(reinterpret_cast(
data.mask_index + 2 * parameters.batch_size + 1));
- p.query = query;
- p.key = key;
- p.value = value;
+ p.query = data.q;
+ p.key = data.k;
+ p.value = data.v;
p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias;
p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
p.output = data.output;
- p.is_kv_bsnh = true;
+ p.is_kv_bsnh = data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH;
p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float))
? data.scratch
: nullptr;
p.stream = stream;
p.has_custom_right_padding = false;
run_memory_efficient_attention(p);
- DUMP_TENSOR("efficient attention output", data.output,
- parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size);
return Status::OK();
}
@@ -449,10 +409,6 @@ Status UnfusedAttention(
cublasSetStream(cublas, stream);
- DUMP_TENSOR_INIT();
- DUMP_TENSOR_D("q[BNSH]", data.q, batch_size, num_heads, sequence_length, qk_head_size);
- DUMP_TENSOR_D("k[BNSH]", data.k, batch_size, num_heads, total_sequence_length, qk_head_size);
-
const int present_sequence_length = parameters.past_present_share_buffer
? parameters.max_sequence_length
: total_sequence_length;
@@ -467,8 +423,7 @@ Status UnfusedAttention(
&zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches,
device_prop, parameters.use_tf32));
- DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size);
- DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length);
+ DUMP_TENSOR_INIT();
DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length);
constexpr size_t element_size = sizeof(T);
@@ -523,7 +478,6 @@ Status UnfusedAttention(
// Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
device_prop.maxThreadsPerBlock, false, temp_output, data.output);
- DUMP_TENSOR("unfused output", data.output, batch_size, sequence_length, num_heads, v_head_size);
return result;
}
@@ -554,7 +508,7 @@ Status QkvToContext(
if (!parameters.past_present_share_buffer) {
ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size,
- sequence_length, total_sequence_length, parameters.pass_past_in_kv,
+ sequence_length, total_sequence_length,
stream, max_threads_per_block, data));
} else { // past_present_share_buffer
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
index 56836bdda197c..fad353dcfeb07 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -6,6 +6,8 @@
#include
#include
#include
+#include <iostream>
+#include <mutex>
#include "core/framework/allocator.h"
#include "contrib_ops/cpu/bert/attention_common.h"
@@ -15,13 +17,18 @@ namespace cuda {
constexpr int kCumulatedSequenceLengthCacheMaxBatchSize = 128;
+// A cache for cumulated sequence length. It is initialized in the first request, then becomes read-only after that.
struct CumulatedSequenceLengthCache {
onnxruntime::IAllocatorUniquePtr buffer;
int32_t max_batch_size;
int32_t sequence_length;
- CumulatedSequenceLengthCache() : max_batch_size(0), sequence_length(0) {}
- void Initialize(int32_t sequence_length, cudaStream_t stream);
+ CumulatedSequenceLengthCache() : max_batch_size(kCumulatedSequenceLengthCacheMaxBatchSize), sequence_length(0) {}
+
+ const int32_t* TryGet(int batch_size, int32_t sequence_length, cudaStream_t stream);
+
+  // Use this flag to guard one-time initialization in multi-threading.
+ mutable std::once_flag init_once_flag_;
};
size_t
@@ -46,7 +53,8 @@ size_t GetAttentionWorkspaceSize(
void* fused_runner,
bool use_flash_attention,
bool use_fused_cross_attention,
- bool use_memory_efficient_attention);
+ bool use_memory_efficient_attention,
+ bool no_qkv_workspace);
template
struct AttentionData {
@@ -65,8 +73,6 @@ struct AttentionData {
bool has_qkv_workspace = false;
T* workspace = nullptr;
- T* temp_k_workspace = nullptr;
- T* temp_v_workspace = nullptr;
T* output = nullptr;
T* present = nullptr;
@@ -79,22 +85,50 @@ struct AttentionData {
bool use_flash_attention = false;
bool use_memory_efficient_attention = false;
- mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr;
- mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr;
+ const int32_t* cumulated_sequence_length_q_cache = nullptr;
+ const int32_t* cumulated_sequence_length_kv_cache = nullptr;
// Intermediate data
T* q = nullptr;
T* k = nullptr;
T* v = nullptr;
T* scratch = nullptr;
- AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+ AttentionQkvFormat qkv_format = AttentionQkvFormat::UNKNOWN;
// Flash buffers
T* softmax_lse = nullptr;
T* softmax_lse_accum = nullptr;
T* out_accum = nullptr;
+
+ // For Debugging
+ size_t workspace_bytes = 0;
+ bool allow_debug_info = false;
+
+ bool IsUnfused() const {
+ return !use_flash_attention && !use_memory_efficient_attention &&
+ (fused_runner == nullptr) && (fused_cross_attention_kernel == nullptr);
+ }
+
+ void PrintDebugInfo() const {
+ std::cout << "flash=" << use_flash_attention
+ << ", efficient=" << use_memory_efficient_attention
+ << ", fused_runner=" << (fused_runner != nullptr)
+ << ", fused_cross=" << (fused_cross_attention_kernel != nullptr)
+ << ", bias=" << (bias != nullptr)
+ << ", attn_bias=" << (relative_position_bias != nullptr)
+ << ", mask_dims=" << mask_index_dims.size()
+ << ", has_qkv_workspace=" << has_qkv_workspace
+ << ", workspace=" << workspace_bytes
+ << ", past=" << (past != nullptr ? 1 : (past_key != nullptr ? 2 : 0))
+ << ", present=" << (present != nullptr ? 1 : (present_key != nullptr ? 2 : 0))
+ << std::endl;
+ }
};
+// Return true if no qkv workspace is needed, false otherwise.
+template <typename T>
+bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData<T>& data);
+
template
Status PrepareQkv(contrib::AttentionParameters& parameters,
AttentionData& data,
@@ -129,6 +163,9 @@ Status LaunchTransQkv(cudaStream_t stream, const int matrix_num,
const int max_threads_per_block, const bool reversed_bs, const half* input, half* output,
int total_matrix_count = -1);
+Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
+ const float* input, float* output, cudaStream_t stream, const int max_threads_per_block);
+
Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const half* input, half* output, cudaStream_t stream, const int max_threads_per_block);
@@ -158,7 +195,7 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream,
template
Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size,
- int sequence_length, int total_sequence_length, bool pass_past_in_kv,
+ int sequence_length, int total_sequence_length,
cudaStream_t stream,
int max_threads_per_block,
AttentionData& data);
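The no_qkv_workspace flag added to GetAttentionWorkspaceSize drops the Q/K/V staging bytes when the inputs can be consumed in place (as decided by NoQkvWorkspace above). A standalone sketch of just that part of the size computation, mirroring the formula in attention_impl.cu; the flash/fused branches of the real function are omitted:

#include <cstddef>

// Sketch of the QKV staging-buffer portion of the attention workspace size.
// When no_qkv_workspace is true, Q/K/V point directly at the inputs and the
// staging bytes are not allocated.
size_t QkvWorkspaceBytes(size_t element_size, size_t batch_size, size_t num_heads,
                         size_t sequence_length, size_t kv_sequence_length,
                         size_t qk_head_size, size_t v_head_size, bool no_qkv_workspace) {
  const size_t qkv_size = element_size * batch_size * num_heads *
                          ((sequence_length + kv_sequence_length) * qk_head_size +
                           kv_sequence_length * v_head_size);
  return no_qkv_workspace ? 0 : qkv_size;
}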
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h
index bd7df5f490c76..aba1e01bfd91b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h
@@ -50,6 +50,7 @@ class AttentionKernelOptions {
bool use_unfused_{true};
bool use_trt_flash_attention_{true};
+
bool use_trt_cross_attention_{true};
// Causal attention is disabled by default in #14732.
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu
index 89be0f1115f41..9f0f49348c225 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu
@@ -249,16 +249,15 @@ Status LaunchConcatPastToPresent(cudaStream_t stream,
template
Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size,
- int sequence_length, int total_sequence_length, bool pass_past_in_kv,
- cudaStream_t stream,
- int max_threads_per_block,
+ int sequence_length, int total_sequence_length,
+ cudaStream_t stream, int max_threads_per_block,
AttentionData& data) {
// Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length.
// past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH)
// past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH)
// When there is past state, the head size for Q/K/V shall be same: H == H_v.
- if (nullptr != data.present) {
+ if (nullptr != data.present) { // Attention op
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
@@ -270,58 +269,52 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int
// Update pointers to present_k and present_v.
data.k = data.present;
data.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size;
- } else if (nullptr != data.past_key || nullptr != data.present_key) {
- if (nullptr != data.past_key && nullptr == data.present_key) {
- data.k = const_cast(data.past_key);
- data.v = const_cast(data.past_value);
- } else if (nullptr == data.past_key && nullptr != data.present_key) {
- if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
+ } else { // MultiHeadAttention op
+ if (nullptr != data.present_key) {
+ ORT_ENFORCE(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH ||
+ data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH);
+ if (nullptr != data.past_key) {
+ assert(data.past_key != data.k);
+ assert(data.past_value != data.v);
+
+ ORT_RETURN_IF_ERROR(
+ LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length,
+ batch_size, qk_head_size, num_heads,
+ max_threads_per_block, 1, data.past_key, data.k, data.present_key));
+ ORT_RETURN_IF_ERROR(
+ LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length,
+ batch_size, v_head_size, num_heads,
+ max_threads_per_block, 1, data.past_value, data.v, data.present_value));
+ // Update pointers to present_k and present_v.
data.k = data.present_key;
data.v = data.present_value;
- } else {
- assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
- data.k = data.temp_k_workspace;
- data.v = data.temp_v_workspace;
+ } else { // nullptr == data.past_key && nullptr != data.present_key
+ if (data.k != data.present_key) {
+ int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size;
+ cudaMemcpyAsync(data.present_key, data.k, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
+ }
+
+ if (data.v != data.present_value) {
+ int64_t v_size = (int64_t)batch_size * num_heads * total_sequence_length * v_head_size;
+ cudaMemcpyAsync(data.present_value, data.v, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
+ }
}
- } else if (pass_past_in_kv) {
- // past_key and past_value are used directly as key and value in attention computations
- data.k = const_cast(data.past_key);
- data.v = const_cast(data.past_value);
-
- // This path has a memory copy from past_key and past_value to present_key and present_value
- // Avoid this path since the memory copy is unnecessary because past_key == present_key and
- // past_value == present_value
- int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size;
- int64_t v_size = (int64_t)batch_size * num_heads * total_sequence_length * v_head_size;
- cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
- cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
- } else {
- ORT_RETURN_IF_ERROR(
- LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length,
- batch_size, qk_head_size, num_heads,
- max_threads_per_block, 1, data.past_key, data.k, data.present_key));
- ORT_RETURN_IF_ERROR(
- LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length,
- batch_size, v_head_size, num_heads,
- max_threads_per_block, 1, data.past_value, data.v, data.present_value));
- // Update pointers to present_k and present_v.
- data.k = data.present_key;
- data.v = data.present_value;
}
}
+
return CUDA_CALL(cudaGetLastError());
}
// Template Instantiation
template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size,
- int sequence_length, int total_sequence_length, bool pass_past_in_kv,
+ int sequence_length, int total_sequence_length,
cudaStream_t stream,
int max_threads_per_block,
AttentionData& data);
template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size,
- int sequence_length, int total_sequence_length, bool pass_past_in_kv,
+ int sequence_length, int total_sequence_length,
cudaStream_t stream,
int max_threads_per_block,
AttentionData& data);
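As a reference for the concatenation handled in this file (past_k BxNxPxH plus k BxNxLxH producing present_k BxNxTxH, per the comments above), here is an illustrative CPU version under those layout assumptions; it is not the CUDA kernel used by LaunchConcatTensorToTensor:

#include <cstddef>
#include <vector>

// CPU reference: concatenate past (BxNxPxH) and the new chunk (BxNxLxH) along
// the sequence axis into present (BxNxTxH), where T = P + L.
template <typename T>
void ConcatPastToPresentReference(const std::vector<T>& past, const std::vector<T>& chunk,
                                  std::vector<T>& present, int batch_size, int num_heads,
                                  int past_len, int new_len, int head_size) {
  const int total_len = past_len + new_len;
  for (int b = 0; b < batch_size; ++b) {
    for (int n = 0; n < num_heads; ++n) {
      for (int t = 0; t < total_len; ++t) {
        const T* src = (t < past_len)
                           ? &past[(((static_cast<size_t>(b) * num_heads + n) * past_len) + t) * head_size]
                           : &chunk[(((static_cast<size_t>(b) * num_heads + n) * new_len) + (t - past_len)) * head_size];
        T* dst = &present[(((static_cast<size_t>(b) * num_heads + n) * total_len) + t) * head_size];
        for (int h = 0; h < head_size; ++h) dst[h] = src[h];
      }
    }
  }
}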
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
index 040d6124e7456..05c592ec61059 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -12,12 +12,101 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
+#if DEBUG_TENSOR_LEVEL > 1
+// Dump the workspace for Q, K, V after processing QKV data.
+template <typename T>
+void DumpQkv(contrib::AttentionParameters& parameters, AttentionData<T>& data) {
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int kv_sequence_length = parameters.kv_sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int qk_head_size = parameters.head_size;
+ const int v_head_size = parameters.v_head_size;
+
+ DUMP_TENSOR_INIT();
+ if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
+ DUMP_TENSOR_D("q(BNSH)", data.q, batch_size, num_heads, sequence_length, qk_head_size);
+ DUMP_TENSOR_D("k(BNSH)", data.k, batch_size, num_heads, kv_sequence_length, qk_head_size);
+ DUMP_TENSOR_D("v(BNSH)", data.v, batch_size, num_heads, kv_sequence_length, v_head_size);
+ } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) {
+ DUMP_TENSOR_D("q(BSNH)", data.q, batch_size, sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("k(BSNH)", data.k, batch_size, kv_sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("v(BSNH)", data.v, batch_size, kv_sequence_length, num_heads, v_head_size);
+ } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH) {
+ DUMP_TENSOR_D("q(BNSH)", data.q, batch_size, num_heads, sequence_length, qk_head_size);
+ DUMP_TENSOR_D("k(BNSH)", data.k, batch_size, num_heads, kv_sequence_length, qk_head_size);
+ DUMP_TENSOR_D("v(BNSH)", data.v, batch_size, num_heads, kv_sequence_length, v_head_size);
+ } else if (data.qkv_format == AttentionQkvFormat::QKV_BSN3H) {
+ DUMP_TENSOR_D("q(BSN3H)", data.q, batch_size, sequence_length, num_heads * 3, qk_head_size);
+ }
+}
+
+// Dump the inputs before processing QKV data.
+template <typename T>
+void DumpInputs(contrib::AttentionParameters& parameters, AttentionData<T>& data) {
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int kv_sequence_length = parameters.kv_sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int qk_head_size = parameters.head_size;
+ const int v_head_size = parameters.v_head_size;
+
+ DUMP_TENSOR_INIT();
+ if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) {
+ DUMP_TENSOR_D("Query(BNSH)", data.query, batch_size, num_heads, sequence_length, qk_head_size);
+ DUMP_TENSOR_D("Key(BNSH)", data.key, batch_size, num_heads, kv_sequence_length, qk_head_size);
+ DUMP_TENSOR_D("Value(BNSH)", data.value, batch_size, num_heads, kv_sequence_length, v_head_size);
+ } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) {
+ DUMP_TENSOR_D("Query(BSNH)", data.query, batch_size, sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("Key(BSNH)", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("Value(BSNH)", data.value, batch_size, kv_sequence_length, num_heads, v_head_size);
+ } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH) {
+ DUMP_TENSOR_D("Query(BNSH)", data.query, batch_size, num_heads, sequence_length, qk_head_size);
+ DUMP_TENSOR_D("Key(BNSH)", data.key, batch_size, num_heads, kv_sequence_length, qk_head_size);
+ DUMP_TENSOR_D("Value(BNSH)", data.value, batch_size, num_heads, kv_sequence_length, v_head_size);
+ } else if (data.qkv_format == AttentionQkvFormat::QKV_BSN3H) {
+ DUMP_TENSOR_D("Query(BSN3H)", data.query, batch_size, sequence_length, num_heads * 3, qk_head_size);
+ } else if (data.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H) {
+ DUMP_TENSOR_D("Query(BNSH)", data.query, batch_size, num_heads, sequence_length, qk_head_size);
+ DUMP_TENSOR_D("Value(BSN2H)", data.value, batch_size, sequence_length, num_heads * 2, qk_head_size);
+ }
+
+ if (data.bias != nullptr) {
+ DUMP_TENSOR_D("Q_bias", data.bias, num_heads, qk_head_size);
+ DUMP_TENSOR_D("K_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size);
+ DUMP_TENSOR_D("V_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size);
+ }
+
+ if (data.relative_position_bias != nullptr) {
+ DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias,
+ parameters.broadcast_res_pos_bias ? 1 : batch_size,
+ num_heads, sequence_length, kv_sequence_length);
+ }
+
+ if (data.mask_index != nullptr) {
+ if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) {
+ DUMP_TENSOR_D("mask", data.mask_index, batch_size, parameters.total_sequence_length);
+ }
+ if (parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) {
+ DUMP_TENSOR_D("mask", data.mask_index, 3 * batch_size + 2, 1);
+ }
+ }
+}
+
+// Dump the kernel outputs
+template <typename T>
+void DumpOutputs(contrib::AttentionParameters& parameters, AttentionData<T>& data) {
+ DUMP_TENSOR_INIT();
+ DUMP_TENSOR("output", data.output,
+ parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size);
+}
+#endif
+
template
Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
AttentionData& data,
cudaStream_t stream,
- int max_threads_per_block,
- AttentionQkvFormat& qkv_format) {
+ int max_threads_per_block) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int num_heads = parameters.num_heads;
@@ -40,7 +129,7 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
int matrix_to_trans = (past_present_share_buffer ? 1 : 3);
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.gemm_buffer, qkv, 3));
- qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
} else {
// For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
// For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3)
@@ -48,13 +137,13 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
// For fused causal kernel, use format 1 since we need have K and V to update present state,
// at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1));
- qkv_format = use_fused_kernel
- ? AttentionQkvFormat::QKV_BSN3H
- : (use_flash_or_efficient_attention
- ? AttentionQkvFormat::Q_K_V_BSNH
- : (use_fused_causal
- ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH
- : AttentionQkvFormat::Q_K_V_BNSH));
+ data.qkv_format = use_fused_kernel
+ ? AttentionQkvFormat::QKV_BSN3H
+ : (use_flash_or_efficient_attention
+ ? AttentionQkvFormat::Q_K_V_BSNH
+ : (use_fused_causal
+ ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH
+ : AttentionQkvFormat::Q_K_V_BNSH));
// For fused causal, we will update gemm_buffer with bias directly.
T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;
@@ -71,367 +160,526 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
return Status::OK();
}
-// For MultiHeadAttention with past state
+// Return true if the workspace is not needed for Q, K, V inputs, false otherwise.
+// This shall be in sync with the following function PrepareQkv_MHA_Cross.
template
-Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters,
- AttentionData& data,
- cudaStream_t stream,
- int max_threads_per_block,
- T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
+bool NoQkvWorkspace_MHA_Cross(AttentionData<T>& data) {
+ // query, key and value are passed as Q, K and V for the following conditions.
+ return (data.use_memory_efficient_attention || data.use_flash_attention) && (data.bias == nullptr);
+}
+
+// For MultiHeadAttention with cross attention (Q_K_V_BSNH_BNSH_BNSH format)
+template
+Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters,
+                            AttentionData<T>& data,
+ cudaStream_t stream,
+ int max_threads_per_block) {
+ assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH);
+ // past_key or past_value is not supported for cross attention
+  // present_key and present_value can be supported in theory, although we do not allow that scenario for now.
+ assert(data.past_key == nullptr);
+ assert(data.past_value == nullptr);
+ assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_Cross(data));
+
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
- const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
- const int v_head_size = parameters.v_head_size;
-
- DUMP_TENSOR_INIT();
- if (data.bias == nullptr) {
- // Below logic does not support fused attention with past without bias
- // When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past.
-
- // cross attention with past state
- if (data.past_key != nullptr && data.present_key == nullptr) {
- assert(data.past_value != nullptr);
- assert(data.query != nullptr);
- assert(data.key == nullptr);
- assert(data.value == nullptr);
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.query, q));
+#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
+ if (data.use_memory_efficient_attention || data.use_flash_attention) {
+ // Add bias for Q
+ if (data.bias != nullptr) {
+ LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size,
+ data.bias, data.query, data.q);
+ } else {
+      data.q = const_cast<T*>(data.query);
}
- // cross attention with present state or self attention with present state
- else if (data.past_key == nullptr && data.present_key != nullptr) {
- assert(data.past_value == nullptr);
- assert(data.present_value != nullptr);
- assert(data.query != nullptr);
- assert(data.key != nullptr);
- assert(data.value != nullptr);
-
- // TODO: supporting packed qkv for self attention may benefit performance
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.query, q));
- // TODO: supporting packed kv for cross attention may benefit performance
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.key, data.present_key));
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
- max_threads_per_block, false, data.value, data.present_value));
- }
- // self attention with past and present state
- else {
- assert(data.past_key != nullptr);
- assert(data.past_value != nullptr);
- assert(data.present_key != nullptr);
- assert(data.present_value != nullptr);
- assert(data.query != nullptr);
- assert(data.key != nullptr);
- assert(data.value != nullptr);
- // TODO: supporting packed qkv for self attention may benefit performance
+    // Here we assume that there is no bias for key and value when they are in BNSH format.
+    data.k = const_cast<T*>(data.key);
+    data.v = const_cast<T*>(data.value);
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
+ } else
+#endif
+ { // unfused kernel
+ assert(data.IsUnfused());
+ if (data.bias == nullptr) {
+ // Transpose query from BSNH to BNSH
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.query, q));
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.key, k));
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
- max_threads_per_block, false, data.value, v));
+ max_threads_per_block, false, data.query, data.q));
+ } else {
+ // Add bias to query, and transpose it: Query (BxSxNxH) => Q (BxNxSxH)
+ constexpr int format = 0;
+ LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
+ batch_size, sequence_length, num_heads, qk_head_size,
+ data.query, data.bias, data.q,
+ true, -1);
}
- qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+
+    // Here we assume that there is no bias for key and value when they are in BNSH format,
+    // so we do not need to add bias for key and value; just use the key and value directly.
+    data.k = const_cast<T*>(data.key);
+    data.v = const_cast<T*>(data.value);
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ }
+ return Status::OK();
+}
+
+template
+bool NoQkvWorkspace_MHA_NoPast(AttentionData<T>& data) {
+ // query, key and value are passed as Q, K and V for the following conditions.
+ return (data.use_memory_efficient_attention || data.use_flash_attention) && data.bias == nullptr;
+}
+
+// For MultiHeadAttention without past state, with Q, K and V inputs
+template
+Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
+                             AttentionData<T>& data,
+ cudaStream_t stream,
+ int max_threads_per_block) {
+ assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
+ assert(data.query != nullptr);
+ assert(data.key != nullptr);
+ assert(data.value != nullptr);
+ assert(data.past_key == nullptr);
+ assert(data.past_value == nullptr);
+ assert(data.present_key == nullptr);
+ assert(data.present_value == nullptr);
+ assert(!parameters.is_unidirectional);
+ assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_NoPast(data));
+
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int kv_sequence_length = parameters.kv_sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int qk_head_size = parameters.head_size;
+ const int v_head_size = parameters.v_head_size;
+
+ if (data.fused_cross_attention_kernel != nullptr) {
+ assert(qk_head_size == v_head_size);
+ assert(data.relative_position_bias == nullptr);
+ assert(data.mask_index == nullptr);
+ assert(parameters.hidden_size == parameters.v_hidden_size);
+
+    // For fused cross attention, besides adding bias, K and V need to be packed:
+ // Key (BxSxNxH), Value (BxSxNxH) => Q (BxSxNxH), K (BxSxNx2xH)
+ LaunchAddBiasTransposeTrt(
+ stream, max_threads_per_block,
+ batch_size, sequence_length,
+ num_heads, qk_head_size,
+ data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length);
+ data.v = nullptr;
+ data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
}
#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
- // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value
- else if ((data.use_memory_efficient_attention || data.use_flash_attention) &&
- data.past_key != nullptr &&
- data.past_value != nullptr &&
- parameters.pass_past_in_kv) {
- // Transpose past_key and past_value to use memory efficient attention
-
- // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH)
- ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.past_key, data.temp_k_workspace));
- // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
- ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.past_value, data.temp_v_workspace));
-
- // query => q, temp_k_workspace => k, temp_v_workspace => v
- LaunchAddBias(stream, max_threads_per_block,
- batch_size, sequence_length, kv_sequence_length,
- num_heads, qk_head_size, v_head_size,
- data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v);
-
- DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
- qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
-
- data.past_key = nullptr;
- data.past_value = nullptr;
+ else if (data.use_memory_efficient_attention || data.use_flash_attention) {
+ if (data.bias != nullptr) {
+ LaunchAddBias(stream, max_threads_per_block,
+ batch_size, sequence_length, kv_sequence_length,
+ num_heads, qk_head_size, v_head_size,
+ data.bias, data.query, data.key, data.value, data.q, data.k, data.v);
+ } else {
+      data.q = const_cast<T*>(data.query);
+      data.k = const_cast<T*>(data.key);
+      data.v = const_cast<T*>(data.value);
+ }
+
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
}
- // When there is no past_key/past_value and there is present_key/present_value
- // (e.g. get initial kv to use as past_kv in the next iteration)
- else if ((data.use_memory_efficient_attention || data.use_flash_attention) &&
- data.present_key != nullptr &&
- data.present_value != nullptr) {
- // Use memory efficient attention kernel
- LaunchAddBias(stream, max_threads_per_block,
- batch_size, sequence_length, kv_sequence_length,
- num_heads, qk_head_size, v_head_size,
- data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace);
-
- // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH)
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.temp_k_workspace, data.present_key));
+#endif
+ else if (data.fused_runner != nullptr) {
+ assert(qk_head_size == v_head_size);
+ assert(data.relative_position_bias == nullptr);
+
+ // Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H)
+ LaunchAddBiasTransposeTrt(
+ stream, max_threads_per_block,
+ batch_size, sequence_length,
+ num_heads, qk_head_size,
+ data.bias, data.query, data.key, data.value, data.q, false, kv_sequence_length);
+ data.k = nullptr;
+ data.v = nullptr;
+
+ data.qkv_format = AttentionQkvFormat::QKV_BSN3H;
+ } else { // unfused kernel
+ assert(data.IsUnfused());
+ // Query (BxSxNxH) => Q (BxNxSxH)
+ constexpr int format = 0;
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, sequence_length, num_heads, qk_head_size,
+ data.query, data.bias, data.q,
+ true, -1);
+
+ // Key (BxLxNxH) => K (BxNxLxH)
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, qk_head_size,
+ data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, data.k,
+ true, -1);
- // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v)
+    // Value (BxLxNxH_v) => V (BxNxLxH_v)
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, v_head_size,
+ data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, data.v,
+ true, -1);
+
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ }
+
+ return Status::OK();
+}
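The bias offsets used in the unfused branch above (and again in the paths below and in DumpInputs) follow the packed [Q bias | K bias | V bias] layout. A tiny sketch of that pointer arithmetic, using a hypothetical struct name for illustration:

#include <cstddef>

// Hypothetical view over the packed bias buffer used by these kernels:
// [0, N*H_qk) -> Q bias, [N*H_qk, 2*N*H_qk) -> K bias,
// [2*N*H_qk, 2*N*H_qk + N*H_v) -> V bias.
template <typename T>
struct PackedQkvBias {
  const T* q_bias;
  const T* k_bias;
  const T* v_bias;
};

template <typename T>
PackedQkvBias<T> SplitQkvBias(const T* bias, int num_heads, int qk_head_size) {
  PackedQkvBias<T> out;
  out.q_bias = bias;
  out.k_bias = bias + static_cast<size_t>(num_heads) * qk_head_size;
  out.v_bias = bias + 2 * static_cast<size_t>(num_heads) * qk_head_size;
  return out;
}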
+
+template
+bool NoQkvWorkspace_MHA_WithPast_NoBias(AttentionData<T>& data) {
+  if (data.use_memory_efficient_attention || data.use_flash_attention) {
+    // Q, K and V redirect to query, present_k and present_v, so we do not need extra workspace for QKV.
+ return data.past_key == nullptr && data.present_key != nullptr;
+ }
+ return false;
+}
+
+// For MultiHeadAttention with kv cache (past or present), but no bias
+template
+Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,
+                                      AttentionData<T>& data,
+ cudaStream_t stream,
+ int max_threads_per_block) {
+ assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
+ assert(data.query != nullptr);
+ assert(data.key != nullptr);
+ assert(data.value != nullptr);
+ assert(data.bias == nullptr);
+ assert(data.fused_runner == nullptr);
+ assert(data.fused_cross_attention_kernel == nullptr);
+ assert(data.present_key != nullptr);
+ assert(data.present_value != nullptr);
+ assert(data.past_key == nullptr && data.past_value == nullptr ||
+ data.past_key != nullptr && data.past_value != nullptr);
+ assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_WithPast_NoBias(data));
+
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int kv_sequence_length = parameters.kv_sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int qk_head_size = parameters.head_size;
+ const int v_head_size = parameters.v_head_size;
+
+ // When there is no past state and there is present state, we output K and V directly to present state.
+ if (data.past_key == nullptr && data.present_key != nullptr) {
+ data.k = data.present_key;
+ data.v = data.present_value;
+ }
+
+#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
+ if (data.use_memory_efficient_attention || data.use_flash_attention) {
+    // Use the original Query (BSNH) since there is no bias.
+    data.q = const_cast<T*>(data.query);
+
+ // Key (BxLxNxH) => K (BxNxLxH)
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, false, data.key, data.k));
+ // Value (BxLxNxH) => V (BxNxLxH)
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
- max_threads_per_block, false, data.temp_v_workspace, data.present_value));
+ max_threads_per_block, false, data.value, data.v));
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
+ } else
+#endif
+ { // unfused kernel
+ assert(data.IsUnfused());
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, false, data.query, data.q));
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, false, data.key, data.k));
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
+ max_threads_per_block, false, data.value, data.v));
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ }
+
+ return Status::OK();
+}
- DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size);
- qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+template
+constexpr bool NoQkvWorkspace_MHA_WithPast_Bias(AttentionData<T>& /*data*/) {
+ return false;
+}
+
+// For MultiHeadAttention with both kv cache (past or present) and bias
+template
+Status PrepareQkv_MHA_WithPast_Bias(contrib::AttentionParameters& parameters,
+                                    AttentionData<T>& data,
+ cudaStream_t stream,
+ int max_threads_per_block) {
+ assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
+ assert(data.bias != nullptr);
+ assert(!(data.past_key != nullptr && data.present_key == nullptr));
+ assert(data.fused_runner == nullptr);
+ assert(data.fused_cross_attention_kernel == nullptr);
+ assert(data.present_key != nullptr);
+ assert(data.present_value != nullptr);
+ assert(data.past_key == nullptr && data.past_value == nullptr ||
+ data.past_key != nullptr && data.past_value != nullptr);
+ assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_WithPast_Bias(data));
+
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int kv_sequence_length = parameters.kv_sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int qk_head_size = parameters.head_size;
+ const int v_head_size = parameters.v_head_size;
+
+ // When there is no past state and there is present state, we output K and V directly to present state.
+ if (data.past_key == nullptr && data.present_key != nullptr) {
+ data.k = data.present_key;
+ data.v = data.present_value;
}
+
+#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
+ if (data.use_memory_efficient_attention || data.use_flash_attention) {
+ // Query(BxSxNxH) + Bias_Q => Q (BxSxNxH)
+ LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size,
+ data.bias, data.query, data.q);
+
+ // Key (BxLxNxH) + Bias_K => K (BxNxLxH)
+ constexpr int format = 0;
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, qk_head_size,
+ data.key, data.bias + num_heads * qk_head_size, data.k, true, -1);
+
+    // Value (BxLxNxH_v) + Bias_V => V (BxNxLxH_v)
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, v_head_size,
+ data.value, data.bias + 2 * num_heads * qk_head_size, data.v, true, -1);
+
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
+ } else
#endif
- else {
- // Use unfused kernel for Q, use unfused kernel for K and V if needed
+ { // unfused kernel
+ assert(data.IsUnfused());
+
constexpr int format = 0;
// Query (BxSxNxH) => Q (BxNxSxH)
LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
- data.query, data.bias, q,
+ data.query, data.bias, data.q,
true, -1);
- if (!parameters.pass_past_in_kv) {
- T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k;
- T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? data.present_value : v;
-
- // Key (BxLxNxH) => K (BxNxLxH)
- LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
- batch_size, kv_sequence_length, num_heads, qk_head_size,
- data.key, data.bias + num_heads * qk_head_size, k_dest,
- true, -1);
+ // Key (BxLxNxH) => K (BxNxLxH)
+ LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, qk_head_size,
+ data.key, data.bias + num_heads * qk_head_size, data.k,
+ true, -1);
- // Value (BxLxNxH_v) => V (BxNxLxH_v)
- LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
- batch_size, kv_sequence_length, num_heads, v_head_size,
- data.value, data.bias + 2 * num_heads * qk_head_size, v_dest,
- true, -1);
+ // Value (BxLxNxH_v) => V (BxNxLxH_v)
+ LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, v_head_size,
+ data.value, data.bias + 2 * num_heads * qk_head_size, data.v,
+ true, -1);
- DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size);
- DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size);
- DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size);
- }
- qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
}
+
return Status::OK();
}
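Taken together, the helpers above split the old monolithic PrepareQkv_MHA_WithPast into per-scenario functions. Below is a rough, self-contained sketch of how the variant might be selected, inferred only from the asserts in each function; the enum, struct and function names are illustrative and the actual dispatch lives elsewhere in this file:

// Illustrative selector for the Q/K/V preparation variant, inferred from the
// preconditions asserted in the PrepareQkv_MHA_* helpers in this diff.
enum class MhaPrepareQkvVariant {
  Cross,           // Q_K_V_BSNH_BNSH_BNSH: K/V already in BNSH, no kv cache
  PackedQkv,       // packed QKV input, no past state
  WithPastNoBias,  // kv cache involved, bias == nullptr
  WithPastBias,    // kv cache involved, bias != nullptr
  NoPast           // separate Q/K/V inputs, no kv cache
};

struct MhaPrepareQkvInputs {  // hypothetical flags for illustration
  bool kv_in_bnsh_format;     // parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH
  bool packed_qkv;            // key/value folded into the query input
  bool has_past_or_present;   // past_key/present_key (kv cache) in use
  bool has_bias;
};

MhaPrepareQkvVariant ChoosePrepareQkvVariant(const MhaPrepareQkvInputs& in) {
  if (in.kv_in_bnsh_format) return MhaPrepareQkvVariant::Cross;
  if (in.packed_qkv) return MhaPrepareQkvVariant::PackedQkv;
  if (in.has_past_or_present) {
    return in.has_bias ? MhaPrepareQkvVariant::WithPastBias
                       : MhaPrepareQkvVariant::WithPastNoBias;
  }
  return MhaPrepareQkvVariant::NoPast;
}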
+template
+bool NoQkvWorkspace_MHA_PackedQKV(AttentionData<T>& data) {
+ // query, key and value are passed as Q, K and V for the following conditions.
+ return nullptr != data.fused_runner && data.bias == nullptr;
+}
+
// For MultiHeadAttention without past state, with packed QKV inputs
template <typename T>