From 70b8cda979e0ffbb0dcfb17e9a66c623281a4080 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 2 Nov 2023 17:16:35 -0700 Subject: [PATCH] Cherry pick LLaMA to rel-1.16.2 (round 2) (#18245) 2nd round of cherry pick LLaMA related changes to 1.16.2 release. --------- Co-authored-by: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Co-authored-by: Frank Dong <123416088+frank-dong-ms@users.noreply.github.com> --- cmake/onnxruntime_rocm_hipify.cmake | 5 + docs/ContribOperators.md | 6 +- .../contrib_ops/cuda/bert/attention_impl.cu | 2 + .../bert/cutlass_fmha/fmha_launch_template.h | 54 +- .../cutlass_fmha/memory_efficient_attention.h | 2 + .../cuda/bert/group_query_attention.cc | 73 ++- .../cuda/bert/group_query_attention.h | 1 + .../cuda/bert/group_query_attention_helper.h | 88 ++- .../cuda/bert/group_query_attention_impl.cu | 596 ++++++++++++++---- .../cuda/bert/group_query_attention_impl.h | 9 + .../cuda/bert/packed_attention_impl.cu | 2 + .../bert/packed_multihead_attention_impl.cu | 2 + .../core/graph/contrib_ops/bert_defs.cc | 6 +- .../python/tools/symbolic_shape_infer.py | 1 + .../tools/transformers/convert_generation.py | 10 +- .../python/tools/transformers/fusion_base.py | 5 + .../transformers/fusion_rotary_attention.py | 320 +++++++++- .../tools/transformers/models/llama/README.md | 36 +- .../transformers/models/llama/benchmark.py | 43 +- .../models/llama/benchmark_70b_model.sh | 12 + .../models/llama/benchmark_all.py | 11 +- .../models/llama/convert_70b_model.sh | 12 + .../models/llama/convert_to_onnx.py | 448 ++++++++----- .../models/llama/dist_settings.py | 45 ++ .../transformers/models/llama/llama_inputs.py | 10 +- .../transformers/models/llama/llama_parity.py | 53 +- .../transformers/models/llama/llama_torch.py | 38 ++ .../models/llama/requirements-70b-model.txt | 4 + .../python/tools/transformers/onnx_model.py | 12 + .../python/transformers/test_flash_attn.py | 309 ++++++--- .../transformers/test_rotary_mha_fusion.py | 451 ++++++++++++- 31 files changed, 2102 insertions(+), 564 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/llama/benchmark_70b_model.sh create mode 100644 onnxruntime/python/tools/transformers/models/llama/convert_70b_model.sh create mode 100644 onnxruntime/python/tools/transformers/models/llama/dist_settings.py create mode 100644 onnxruntime/python/tools/transformers/models/llama/llama_torch.py create mode 100644 onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 53078df2b13f3..64c4ca7d1543a 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -113,6 +113,11 @@ set(contrib_ops_excluded_files "cuda_contrib_kernels.h" "inverse.cc" "fused_conv.cc" + "bert/group_query_attention_helper.h" + "bert/group_query_attention.h" + "bert/group_query_attention.cc" + "bert/group_query_attention_impl.h" + "bert/group_query_attention_impl.cu" ) if (NOT onnxruntime_ENABLE_ATEN) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 66a215aa1e535..43816c3db8133 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2265,14 +2265,14 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>When buffered past_key and past_value are used (present_key uses same tensor as past_key), required to specify past_sequence_length (could be 0). Otherwise, past_sequence_length is inferred from past_key.</dd>
 </dl>

-#### Outputs (1 - 3)
+#### Outputs

 <dl>
 <dt><tt>output</tt> : T</dt>
 <dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
-<dt><tt>present_key</tt> (optional) : T</dt>
+<dt><tt>present_key</tt> : T</dt>
 <dd>present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key (k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length + kv_sequence_length.</dd>
-<dt><tt>present_value</tt> (optional) : T</dt>
+<dt><tt>present_value</tt> : T</dt>
 <dd>present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value (k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length + kv_sequence_length.</dd>
 </dl>
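To make the present_key / present_value contract above concrete, here is a minimal sketch (not part of the patch) that computes the expected present KV shape for the two cases the documentation describes: the shared k-v buffer case (length max_sequence_length) and the non-shared case (length past_sequence_length + kv_sequence_length), for either BSNH or BNSH layout. The helper name, arguments, and example values are illustrative only.

def present_kv_shape(batch_size: int,
                     kv_num_heads: int,
                     head_size: int,
                     past_sequence_length: int,
                     kv_sequence_length: int,
                     max_sequence_length: int,
                     shared_kv_buffer: bool,
                     is_past_bsnh: bool) -> tuple:
    # Shared buffer: present KV is the whole pre-allocated buffer of length
    # max_sequence_length. Otherwise it holds past + new tokens only.
    seq_len = max_sequence_length if shared_kv_buffer else past_sequence_length + kv_sequence_length
    if is_past_bsnh:
        return (batch_size, seq_len, kv_num_heads, head_size)   # BSNH layout
    return (batch_size, kv_num_heads, seq_len, head_size)       # BNSH layout

# Example: decoding one new token against a shared 2048-token KV buffer, BNSH cache.
print(present_kv_shape(1, 8, 128, 64, 1, 2048, shared_kv_buffer=True, is_past_bsnh=False))
# -> (1, 8, 2048, 128)
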
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index eb9e6d5c62467..16ce3a899fb5e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -374,6 +374,7 @@ Status EfficientAttention( p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.total_sequence_length; + p.max_sequence_length = parameters.total_sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = parameters.is_unidirectional; @@ -395,6 +396,7 @@ Status EfficientAttention( p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; + p.is_kv_bsnh = true; p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) ? data.scratch : nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index ed330b0fca332..51c3d3d3a458b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -51,25 +51,45 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { p.num_keys = params.kv_sequence_length; if (params.causal) { - p.custom_mask_type = Attention::CausalFromTopLeft; + p.custom_mask_type = Attention::CausalFromBottomRight; } - // Input format is BxSxNxH, output is BxSxNxH - p.q_strideH = params.qk_head_size; - p.k_strideH = params.qk_head_size; - p.v_strideH = params.v_head_size; - p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; - - p.q_strideM = params.num_heads * params.qk_head_size; - p.k_strideM = params.num_heads * params.qk_head_size; - p.v_strideM = params.num_heads * params.v_head_size; - p.o_strideM = params.num_heads * params.v_head_size; - p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; - - p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; - p.k_strideB = static_cast(p.k_strideM) * params.kv_sequence_length; - p.v_strideB = static_cast(p.v_strideM) * params.kv_sequence_length; - p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + // We use max_sequence_length to calculate KV stride + if (params.is_kv_bsnh) { + // Input Q, K, V format is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.qk_head_size; + p.v_strideH = params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.num_heads * params.qk_head_size; + p.v_strideM = params.num_heads * params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; + p.k_strideB = static_cast(p.k_strideM) * params.max_sequence_length; + p.v_strideB = static_cast(p.v_strideM) * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? 
static_cast(p.bias_strideH) * params.num_heads : 0; + } else { + // Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.max_sequence_length * params.qk_head_size; + p.v_strideH = params.max_sequence_length * params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.qk_head_size; + p.v_strideM = params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length; + p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length; + p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } } constexpr auto kernel_fn = attention_kernel_batched_impl; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index f725be8d7cf89..f16567bb6f2b7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -14,10 +14,12 @@ namespace cuda { struct MemoryEfficientAttentionParams { int32_t sm; bool is_half; + bool is_kv_bsnh = true; int32_t batch_size; int32_t num_heads; int32_t sequence_length; int32_t kv_sequence_length; + int32_t max_sequence_length; int32_t qk_head_size; int32_t v_head_size; bool causal; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 67d750aeac11a..8694dc998c7a8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -6,9 +6,8 @@ #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/group_query_attention.h" #include "contrib_ops/cuda/bert/group_query_attention_helper.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" -// #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" -// #include "contrib_ops/cpu/utils/console_dumper.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -55,6 +54,13 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) #else disable_flash_attention_ = true; #endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + disable_memory_efficient_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); +#else + disable_memory_efficient_attention_ = true; +#endif } template @@ -92,18 +98,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { output_shape[2] = static_cast(parameters.hidden_size); Tensor* output = context->Output(0, output_shape); - std::vector present_dims; - if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { - present_dims = { - parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size}; - } else { // BNSH - present_dims = { - parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size}; - } - TensorShape 
present_shape(present_dims); - Tensor* present_key = context->Output(1, present_shape); - Tensor* present_value = context->Output(2, present_shape); - #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && onnxruntime::flash::is_supported(device_prop, @@ -143,8 +137,47 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { auto seqlens_k_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif - // only kernel implemented for gqa right now - ORT_ENFORCE(use_flash_attention); +#if USE_MEMORY_EFFICIENT_ATTENTION + int sm = (device_prop.major * 10) + device_prop.minor; + bool use_memory_efficient_attention = + !use_flash_attention && + !disable_memory_efficient_attention_ && + (parameters.head_size & 7) == 0 && + parameters.sequence_length <= parameters.past_sequence_length + parameters.kv_sequence_length && + (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + has_memory_efficient_attention(sm, sizeof(T) == 2); + // allocate buffers + size_t kv_buffer_bytes = 0; + // need a buffer if we must ungroup kv + const bool needs_buff = (parameters.num_heads != parameters.kv_num_heads); + if (use_memory_efficient_attention && needs_buff) { + kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * (parameters.past_sequence_length + parameters.kv_sequence_length) * parameters.head_size); + } + size_t fmha_buffer_bytes = 0; + if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) { + fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float)); + } + auto k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + auto v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + auto fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream()); +#else + constexpr bool use_memory_efficient_attention = false; + auto k_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto v_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto fmha_buffer = GetScratchBuffer(0, context->GetComputeStream()); +#endif + + std::vector present_dims; + if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { + present_dims = { + parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size}; + } else { // BNSH + present_dims = { + parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size}; + } + TensorShape present_shape(present_dims); + Tensor* present_key = context->Output(1, present_shape); + Tensor* present_value = context->Output(2, present_shape); data.query = reinterpret_cast(query->Data()); data.key = reinterpret_cast(key->Data()); @@ -155,6 +188,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); data.present_value = (nullptr == present_value) ? 
nullptr : reinterpret_cast(present_value->MutableData()); data.use_flash_attention = use_flash_attention; + data.use_memory_efficient_attention = use_memory_efficient_attention; if (softmax_lse_buffer != nullptr) { data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); } @@ -167,6 +201,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { if (seqlens_k_buffer != nullptr) { data.seqlens_k = reinterpret_cast(seqlens_k_buffer.get()); } + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + } + if (fmha_buffer != nullptr) { + data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 72c9814fad670..a90418ec2243a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -27,6 +27,7 @@ class GroupQueryAttention final : public CudaKernel { bool is_past_bsnh_; float scale_; bool disable_flash_attention_; + bool disable_memory_efficient_attention_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index be8f5ca0ae3e9..8c21de9ced058 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -29,13 +29,13 @@ Status CheckInputs(const Tensor* query, // query (Q) : (B, S, D) // key (K) : (B, S+, D_kv) // value (V) : (B, S+, D_kv) + ORT_UNUSED_PARAMETER(value); AttentionQkvFormat qkv_format = Q_K_V_BSNH; AttentionQkvFormat past_kv_format = Q_K_V_BSNH; const auto& query_dims = query->Shape().GetDims(); const auto& key_dims = key->Shape().GetDims(); - const auto& value_dims = value->Shape().GetDims(); if (query_dims.size() != 3) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", @@ -47,10 +47,8 @@ Status CheckInputs(const Tensor* query, int q_hidden_size = static_cast(query_dims[2]); int head_size = static_cast(q_hidden_size) / num_heads; - int kv_sequence_length = sequence_length; - int kv_hidden_size = (key_dims.size() == 3) - ? static_cast(key_dims[2]) - : (kv_num_heads * static_cast(key_dims[3])); + int kv_sequence_length = static_cast(key_dims[1]); + int kv_hidden_size = static_cast(key_dims[2]); int max_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { @@ -134,63 +132,49 @@ Status CheckInputs(const Tensor* query, "Input 'past_key' and 'past_value' shall be both present or both absent"); } - if (key != nullptr) { - const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); - } - if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } - - if (num_heads % kv_num_heads != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_heads must be a multiple of kv_num_heads. 
Got num_heads % kv_num_heads == ", - num_heads % kv_num_heads); - } - if (key_dims[2] != value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have same dim 2 (kv_hidden_size)"); - } - - qkv_format = Q_K_V_BSNH; - kv_sequence_length = static_cast(key_dims[1]); - } else { + if (key_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + key_dims.size()); + } + if (query_dims[0] != key_dims[0]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Missing key tensor."); + "Input 'query' and 'key' shall have same dim 0 (batch size)"); } - if (value != nullptr) { - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } + if (num_heads % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", + num_heads % kv_num_heads); + } - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } - if (static_cast(kv_sequence_length) != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)"); - } + if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch_size)"); + } - if (value_dims[2] != kv_hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); - } - } else { + if (static_cast(kv_sequence_length) != value_dims[1]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Missing value tensor."); + "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)"); + } + + if (value_dims[2] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); } // When kv-cache, we take past_seq_len as an argument... otherwise we use sequence length of past kv directly. int32_t past_sequence_length = 0; - int present_sequence_length = 0; + int present_sequence_length = kv_sequence_length; if (past_seq_len != nullptr) { + if (past_key == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Past KV must be present as share-buffer when using past_seq_len pointer."); + } if (!onnxruntime::IsScalarOr1ElementVector(past_seq_len)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "past_sequence_length tensor must be of one element when using past kv."); @@ -200,6 +184,10 @@ Status CheckInputs(const Tensor* query, } else { past_sequence_length = static_cast(*((*past_seq_len).template Data())); } + if (past_sequence_length + kv_sequence_length > max_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "KV buffer too small... 
shall be that max_sequence_length >= past_sequence_length + kv_sequence_length"); + } present_sequence_length = max_sequence_length; } else if (past_key != nullptr) { past_sequence_length = max_sequence_length; // this is the length of past_key tensor diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index ab3029ca34886..0455825c364a2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -37,6 +37,7 @@ limitations under the License. #include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/attention_impl.h" @@ -47,6 +48,8 @@ namespace onnxruntime { namespace contrib { namespace cuda { +////////// Auxiliary Kernels for KV prep + // Kernel for seqlens_k __global__ void repeat_seqlen(int32_t* seqlens_k, int32_t seqlen, int batch_size) { int id = blockDim.x * blockIdx.x + threadIdx.x; @@ -75,7 +78,7 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, const int present_head_stride = is_bsnh ? H : present_seqlen * H; // past_kv: BPNH or BNPH - // new_kv: BLNH or BNLH + // new_kv: BLNH // present_kv: BTNH or BNTH, where T = P + L const int past_seqlen = present_seqlen - new_seqlen; @@ -95,33 +98,32 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, } } +// Use when (H*)*num_heads > 1024 template __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const int H, + const int num_heads, const T* past_kv, const T* new_kv, T* present_kv, const bool is_bsnh) { - // Use when (H*)*num_heads > 1024 - int h = threadIdx.x; - const int n = threadIdx.y; - const int s = blockIdx.x; - const int b = blockIdx.y; + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + const int present_seqlen = gridDim.y; + + const int present_batch_stride = present_seqlen * num_heads * H; + const int row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_seqlen * H; + + // past_kv: BPNH or BNPH + // new_kv: BLNH + // present_kv: BTNH or BNTH, where T = P + L + const int past_seqlen = present_seqlen - new_seqlen; - const int present_seqlen = gridDim.x; - const int num_heads = blockDim.y; - const int thread_stride = blockDim.x; - - const int present_batch_stride = present_seqlen * num_heads * H; - const int row_stride = is_bsnh ? num_heads * H : H; - const int present_head_stride = is_bsnh ? 
H : present_seqlen * H; - - // past_kv: BPNH or BNPH - // new_kv: BLNH or BNLH - // present_kv: BTNH or BNTH, where T = P + L - const int past_seqlen = present_seqlen - new_seqlen; - - while (h < H) { int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; if (s < past_seqlen) { const int past_batch_stride = past_seqlen * num_heads * H; @@ -135,133 +137,477 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; present_kv[out_offset] = new_kv[in_offset]; } - h += thread_stride; } } +// Concat new to past in present. Supports past BSNH or past BNSH template -Status QkvToContext( +Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.kv_sequence_length; + const int present_sequence_length = parameters.present_sequence_length; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + const int H = head_size / 4; // divide by 4 so kernel can operate on 4 float16 elements at a time. + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(present_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatNewToPastKV<<>>(kv_sequence_length, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKV<<>>(kv_sequence_length, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = (H * kv_num_heads + 255) / 256; + const dim3 grid(steps, present_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } + return CUDA_CALL(cudaGetLastError()); +} + +// Kernel to append new kv to kv buffer in place +template +__global__ void ConcatKVInPlace(const int past_seqlen, + const int present_seqlen, + T* kv_buff, + const T* new_kv, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + const int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int new_seqlen = gridDim.x; + const int num_heads = blockDim.y; + const int H = blockDim.x; + + const int present_batch_stride = present_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? 
H : present_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + int out_offset = b * present_batch_stride + (s + past_seqlen) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; +} + +template +__global__ void ConcatKVInPlaceLarge(const int past_seqlen, + const int present_seqlen, + const int H, + const int num_heads, + T* kv_buff, + const T* new_kv, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + const int new_seqlen = gridDim.y; + + const int present_batch_stride = present_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + int out_offset = b * present_batch_stride + (s + past_seqlen) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; + } +} + +// Concat new to kv buffer in place +template +Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.kv_sequence_length; + const int present_sequence_length = parameters.present_sequence_length; + const int past_sequence_length = parameters.past_sequence_length; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + const int H = head_size / 4; + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(kv_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatKVInPlace<<>>(past_sequence_length, + present_sequence_length, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlace<<>>(past_sequence_length, + present_sequence_length, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = int(ceil(float(H * kv_num_heads) / 256.0)); + const dim3 grid(steps, kv_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatKVInPlaceLarge<<>>(past_sequence_length, + present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlaceLarge<<>>(past_sequence_length, + present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + past_kv_format 
== AttentionQkvFormat::Q_K_V_BSNH); + } + return CUDA_CALL(cudaGetLastError()); +} + +// Kernel for use with memory efficient kernel... kv_in is grouped and of bnsh or bsnh... kv_out is ungrouped and bsnh +template +__global__ void Ungroup(const T* kv_in, + T* kv_out, + const int in_seqlen, + const int kv_num_heads, + const bool is_bsnh) { + const int h = threadIdx.x; + const int out_n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int out_seqlen = gridDim.x; + const int q_num_heads = blockDim.y; + const int H = blockDim.x; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? H : in_seqlen * H; + const int in_n = out_n / q_kv_head_ratio; + + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; +} + +template +__global__ void UngroupLarge(const T* kv_in, + T* kv_out, + const int H, + const int in_seqlen, + const int q_num_heads, + const int kv_num_heads, + const bool is_bsnh) { + int i = threadIdx.x + (blockDim.x * blockIdx.x); // index along H * q_num_heads elements + if (i < H * q_num_heads) { + const int out_seqlen = gridDim.y; + const int s = blockIdx.y; + const int b = blockIdx.z; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? H : in_seqlen * H; + + const int h = i % H; + const int out_n = i / H; + const int in_n = out_n / q_kv_head_ratio; + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; + } +} + +// Ungroup kv or present kv for use in Memory Efficient kernel. If present kv is not null and is BNSH, transposes it. 
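The LaunchUngroup helper below drives these Ungroup/UngroupLarge kernels when num_heads != kv_num_heads, since the memory efficient attention path here consumes ungrouped K/V. As a host-side reference of the same transformation (not part of the patch; array names are illustrative), each KV head is repeated num_heads / kv_num_heads times along the head dimension:

import numpy as np

def ungroup_reference(kv: np.ndarray, num_heads: int, is_bsnh: bool) -> np.ndarray:
    # kv: (batch, seq, kv_num_heads, head) if is_bsnh else (batch, kv_num_heads, seq, head)
    head_axis = 2 if is_bsnh else 1
    kv_num_heads = kv.shape[head_axis]
    assert num_heads % kv_num_heads == 0, "num_heads must be a multiple of kv_num_heads"
    # Repeat each KV head contiguously, matching in_n = out_n / (num_heads // kv_num_heads)
    # in the CUDA kernels above.
    return np.repeat(kv, num_heads // kv_num_heads, axis=head_axis)

# Example: 2 KV heads expanded to match 8 query heads, BNSH layout.
k = np.random.rand(1, 2, 16, 64).astype(np.float16)
print(ungroup_reference(k, num_heads=8, is_bsnh=False).shape)  # (1, 8, 16, 64)
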
+Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters, + float2* k_buff, float2* v_buff, + const float2* k_og, const float2* v_og, + const int buff_seqlen, const int og_seqlen, + const bool is_bsnh, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + + const int H = head_size / 4; + if (H * num_heads <= max_threads_per_block) { + const dim3 grid(buff_seqlen, batch_size, 1); + const dim3 block(H, num_heads, 1); + Ungroup<<>>(k_og, + k_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + Ungroup<<>>(v_og, + v_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + } else { + int steps = int(ceil(float(H * num_heads) / 256.0)); + const dim3 grid(steps, buff_seqlen, batch_size); + const dim3 block(256, 1, 1); + UngroupLarge<<>>(k_og, + k_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + UngroupLarge<<>>(v_og, + v_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + } + return CUDA_CALL(cudaGetLastError()); +} + +////////// Launch Kernels + +#if USE_FLASH_ATTENTION +template +Status FlashAttention( const cudaDeviceProp& device_prop, - cublasHandle_t& cublas, - Stream* ort_stream, + cudaStream_t stream, contrib::GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data) { - assert(data.use_flash_attention); + GroupQueryAttentionData& data, + float scale) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int present_sequence_length = parameters.present_sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; -#if USE_FLASH_ATTENTION - auto stream = static_cast(ort_stream->GetHandle()); + void* query = reinterpret_cast(const_cast(data.query)); + void* key = reinterpret_cast(const_cast(data.key)); + void* value = reinterpret_cast(const_cast(data.value)); + + bool is_causal = parameters.is_unidirectional; + + if (data.past_key != nullptr && data.past_key == data.present_key) { + // Share buffer case + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + + // Launch kernel to copy seqlen + int thr_per_blk = 256; + int blk_in_grid = ceil(float(batch_size) / thr_per_blk); + repeat_seqlen<<>>(data.seqlens_k, parameters.past_sequence_length, batch_size); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), + reinterpret_cast(data.seqlens_k), batch_size, num_heads, kv_num_heads, + head_size, sequence_length, present_sequence_length, kv_sequence_length, + scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum))); + + } else { + // Not share buffer or no past (prompt generation) + // Note that Flash Attention kv-caching operates in place on a buffer... 
therefore this path is inneficient + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( + device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast(data.softmax_lse), + batch_size, num_heads, kv_num_heads, head_size, + sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits, + reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), past_bsnh)); + } + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return Status::OK(); +} +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION +template +Status EfficientAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int kv_sequence_length = parameters.kv_sequence_length; + const int past_sequence_length = parameters.past_sequence_length; const int present_sequence_length = parameters.present_sequence_length; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; - const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(head_size)) : parameters.scale; - if (data.use_flash_attention) { - assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - assert(parameters.num_heads % parameters.kv_num_heads == 0); - - void* query = reinterpret_cast(const_cast(data.query)); - void* key = reinterpret_cast(const_cast(data.key)); - void* value = reinterpret_cast(const_cast(data.value)); - - bool is_causal = parameters.is_unidirectional; - - if (data.past_key == nullptr && data.present_key == nullptr) { - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, key, value, data.output, reinterpret_cast(data.softmax_lse), - parameters.batch_size, parameters.num_heads, parameters.kv_num_heads, head_size, - parameters.sequence_length, parameters.kv_sequence_length, scale, is_causal, parameters.num_splits, - reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum))); - - } else if (data.past_key == data.present_key) { - // Assume past and present kv share buffer. 
- assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); - assert(parameters.past_sequence_length >= 0); - assert(data.past_value != nullptr); - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - // Launch kernel to copy seqlen - int thr_per_blk = 256; - int blk_in_grid = ceil(float(batch_size) / thr_per_blk); - repeat_seqlen<<>>(data.seqlens_k, parameters.past_sequence_length, batch_size); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", data.seqlens_k, 1, batch_size); - - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( - device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), - reinterpret_cast(data.seqlens_k), batch_size, num_heads, kv_num_heads, - head_size, sequence_length, present_sequence_length, kv_sequence_length, - scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum))); - - } else if (data.present_key != nullptr && (data.past_key != nullptr || kv_sequence_length == present_sequence_length)) { - assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); - // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient - if (head_size % 4 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "requires head_size be divisible by 4"); - } - const int H = head_size / 4; - if (H * kv_num_heads <= max_threads_per_block) { - const dim3 grid(present_sequence_length, batch_size, 1); - const dim3 block(H, kv_num_heads, 1); - ConcatNewToPastKV<<>>(kv_sequence_length, - reinterpret_cast(data.past_key), - reinterpret_cast(data.key), - reinterpret_cast(data.present_key), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - ConcatNewToPastKV<<>>(kv_sequence_length, - reinterpret_cast(data.past_value), - reinterpret_cast(data.value), - reinterpret_cast(data.present_value), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - } else { - const dim3 grid(present_sequence_length, batch_size, 1); - const dim3 block(max_threads_per_block / kv_num_heads, kv_num_heads, 1); - ConcatNewToPastKVLarge<<>>(kv_sequence_length, - H, - reinterpret_cast(data.past_key), - reinterpret_cast(data.key), - reinterpret_cast(data.present_key), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - ConcatNewToPastKVLarge<<>>(kv_sequence_length, - H, - reinterpret_cast(data.past_value), - reinterpret_cast(data.value), - reinterpret_cast(data.present_value), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - } - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - // Launch kernel to copy seqlen - int thr_per_blk = 256; - int blk_in_grid = ceil(float(batch_size) / thr_per_blk); - repeat_seqlen<<>>(data.seqlens_k, parameters.past_sequence_length, batch_size); - - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast(data.softmax_lse), - batch_size, num_heads, kv_num_heads, head_size, - sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits, - reinterpret_cast(data.softmax_lse_accum), 
reinterpret_cast(data.out_accum), past_bsnh)); + const void* query = reinterpret_cast(data.query); + const void* key = reinterpret_cast(data.key); + const void* value = reinterpret_cast(data.value); + if (data.past_key != nullptr) { + // Past key case + // concatenate new kv to past kv + if (data.past_key == data.present_key) { + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); + } else { + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); } + const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + if (num_heads == kv_num_heads) { + // Use present kv directly if not grouped + key = reinterpret_cast(data.present_key); + value = reinterpret_cast(data.present_value); + } else { + // Otherwise we use intermediate buffers to run memory efficient attention... best avoid this path + float2* k_buff = reinterpret_cast(data.k); + float2* v_buff = reinterpret_cast(data.v); + const float2* k_og = reinterpret_cast(data.present_key); + const float2* v_og = reinterpret_cast(data.present_value); + ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, past_sequence_length + kv_sequence_length, + present_sequence_length, is_bsnh, stream, max_threads_per_block)); + key = reinterpret_cast(data.k); + value = reinterpret_cast(data.v); + } + } else if (num_heads == kv_num_heads) { + // no past or present and no need to ungroup... still copy kv into present buffer + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + key = reinterpret_cast(data.present_key); + value = reinterpret_cast(data.present_value); + } else { + // intermediate buffer so q and kv have same num heads... still copy kv into present buffer + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + float2* k_buff = reinterpret_cast(data.k); + float2* v_buff = reinterpret_cast(data.v); + const float2* k_og = reinterpret_cast(data.present_key); + const float2* v_og = reinterpret_cast(data.present_value); + ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, kv_sequence_length, + kv_sequence_length, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH, stream, + max_threads_per_block)); + key = reinterpret_cast(data.k); + value = reinterpret_cast(data.v); + } + + MemoryEfficientAttentionParams p; + p.sm = device_prop.major * 10 + device_prop.minor; + p.is_half = sizeof(T) == 2; + p.batch_size = batch_size; + p.num_heads = num_heads; + p.sequence_length = sequence_length; + p.kv_sequence_length = past_sequence_length + kv_sequence_length; + p.max_sequence_length = (num_heads == kv_num_heads) ? present_sequence_length : past_sequence_length + kv_sequence_length; + p.qk_head_size = head_size; + p.v_head_size = head_size; + p.causal = parameters.is_unidirectional; + p.scale = scale; + p.seqlen_k_ptr = nullptr; + p.seqstart_q_ptr = nullptr; + p.seqstart_k_ptr = nullptr; + p.query = query; + p.key = key; + p.value = value; + p.attn_bias = nullptr; + p.is_attn_bias_batched = false; + p.is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + p.output = data.output; + p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float)) + ? 
data.fmha_buffer + : nullptr; + p.stream = stream; + run_memory_efficient_attention(p); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return Status::OK(); +} +#endif + +////////// API Functions + +template +Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data) { + auto stream = static_cast(ort_stream->GetHandle()); + const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) : parameters.scale; - DUMP_TENSOR_INIT(); - DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); +#if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + return FlashAttention(device_prop, stream, parameters, data, scale); + } +#endif - return Status::OK(); +#if USE_MEMORY_EFFICIENT_ATTENTION + if (data.use_memory_efficient_attention) { + return EfficientAttention(device_prop, stream, parameters, data, scale); } #endif + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Group Query Attention not implemented yet."); } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index 0bad9eeb61231..8412631078e6a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -14,19 +14,28 @@ namespace cuda { template struct GroupQueryAttentionData { + // Input Tensors const T* query = nullptr; const T* key = nullptr; const T* value = nullptr; const T* past_key = nullptr; const T* past_value = nullptr; + // Flash buffers T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; T* out_accum = nullptr; int* seqlens_k = nullptr; + // Memory Efficient buffers + T* fmha_buffer = nullptr; + T* k = nullptr; + T* v = nullptr; + // Output Tensors T* output = nullptr; T* present_key = nullptr; T* present_value = nullptr; + // Kernel Flags bool use_flash_attention = false; + bool use_memory_efficient_attention = false; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index aba0efdbd7d5f..d7aeef1501cd6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -507,10 +507,12 @@ Status FusedScaledDotProductAttentionCutlass( MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; p.is_half = sizeof(T) == 2; + p.is_kv_bsnh = true; p.batch_size = parameters.batch_size; p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.sequence_length; + p.max_sequence_length = parameters.sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index e09fd9e6b36e5..3fe9dbf8ed34a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -688,6 +688,7 @@ Status FusedAttentionCutlass( p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = 
parameters.sequence_length; + p.max_sequence_length = parameters.sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; @@ -702,6 +703,7 @@ Status FusedAttentionCutlass( p.attn_bias = data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; + p.is_kv_bsnh = true; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? (data.workspace + (data.no_qkv_workspace ? 0 : (elements_qk + elements_qk + elements_v))) : nullptr; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 365416cdb75e6..db32cb3c05de1 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1041,15 +1041,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key" "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", - "T", - OpSchema::Optional) + "T") .Output(2, "present_value", "present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value" "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", - "T", - OpSchema::Optional) + "T") .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)", "tensor(int64)"}, "Constrain past sequence length to int tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 8ac059f7fc4d3..985608f02741e 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -147,6 +147,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "GatherElements": self._infer_GatherElements, "GatherND": self._infer_GatherND, "Identity": self._pass_on_shape_and_type, + "AllReduce": self._pass_on_shape_and_type, "If": self._infer_If, "Loop": self._infer_Loop, "MatMul": self._infer_MatMul, diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index b32ae64c5b0c0..7aca5e8526a23 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1272,7 +1272,7 @@ def find_past_seq_len_usage(subg: GraphProto): return tensor_names_to_rename, nodes_to_remove -def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0): +def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0, world_size: int = 1): past_seq_len = past_seq_len_input if past_seq_len not in model.get_graphs_input_names(): # Add model input for past sequence length @@ -1282,6 +1282,10 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads # Replace MultiHeadAttention with GroupQueryAttention for node in model.model.graph.node: if node.op_type == "MultiHeadAttention": + num_heads_mha = 0 + for att in node.attribute: + if att.name == "num_heads": + num_heads_mha = att.i gqa_node = onnx.helper.make_node( "GroupQueryAttention", inputs=[ @@ -1295,8 +1299,8 @@ def 
replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads outputs=node.output, name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"), domain="com.microsoft", - num_heads=node.attribute[0].i, - kv_num_heads=node.attribute[0].i if kv_num_heads == 0 else kv_num_heads, + num_heads=num_heads_mha // world_size, + kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size, is_past_bsnh=0, ) model.model.graph.node.remove(node) diff --git a/onnxruntime/python/tools/transformers/fusion_base.py b/onnxruntime/python/tools/transformers/fusion_base.py index c5d7bc16d64f7..67f4f0b55cff8 100644 --- a/onnxruntime/python/tools/transformers/fusion_base.py +++ b/onnxruntime/python/tools/transformers/fusion_base.py @@ -130,3 +130,8 @@ def add_nodes_to_remove(self, nodes: List[NodeProto]): for node in nodes: if node not in self.nodes_to_remove: self.nodes_to_remove.append(node) + + def add_nodes_to_remove_with_nodes_to_keep(self, nodes: List[NodeProto], nodes_to_keep: List[NodeProto]): + for node in nodes: + if node not in self.nodes_to_remove and node not in nodes_to_keep: + self.nodes_to_remove.append(node) diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py index 44d15b619ec7a..ceee836e33f77 100644 --- a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py @@ -323,6 +323,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # qkv_nodes_1 is for LLaMA-2 Microsoft # qkv_nodes_2 is for LLaMA-2 Hugging Face + # qkv_nodes_3 is for LLaMA-2 distribute Hugging Face model qkv_nodes = None qkv_nodes_1 = self.model.match_parent_path( normalize_node, @@ -334,18 +335,27 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["MatMul", "Reshape", "Transpose", "MatMul"], [1, 0, 0, 0], ) + qkv_nodes_3 = self.model.match_parent_path( + normalize_node, + ["AllReduce", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 0, 0, 0, 0], + ) if qkv_nodes_1 is not None: _, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1 qkv_nodes = qkv_nodes_1 elif qkv_nodes_2 is not None: _, reshape_qkv, _, matmul_qkv = qkv_nodes_2 qkv_nodes = qkv_nodes_2 + elif qkv_nodes_3 is not None: + _, _, reshape_qkv, _, matmul_qkv = qkv_nodes_3 + qkv_nodes = qkv_nodes_3 else: logger.debug("fuse_rotary_attention: failed to match qkv nodes") return # v_nodes_1 is for LLaMA-2 Microsoft # v_nodes_3 is for LLaMA-2 Hugging Face + # v_nodes_4 is for LLaMA-2 70B model past_v, present_v, past_seq_len = "", "", "" v_nodes = None v_nodes_1 = self.model.match_parent_path( @@ -363,6 +373,118 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Transpose", "Reshape", "MatMul"], [1, 0, 0], ) + _, v_nodes_4, _ = self.model.match_parent_paths_all( + matmul_qkv, + [ + ( + ["Reshape", "Expand", "Unsqueeze", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Expand", + "Where", + "Equal", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Expand", + "Where", + "Equal", + "Mul", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 
0, 0], + ), + ( + [ + "Reshape", + "Expand", + "Where", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Expand", + "Where", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 0, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Concat", + "Unsqueeze", + "Mul", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 1, 1, 0, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 2, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 3, 0, 0, 0, 1, 0, 0], + ), + ], + output_name_to_node=None, + ) if v_nodes_1 is not None: reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1 v_nodes = v_nodes_1 @@ -388,6 +510,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): transpose_v, reshape_v, matmul_v = v_nodes_3 v_nodes = v_nodes_3 present_v = transpose_v.output[0] + elif v_nodes_4 is not None and len(v_nodes_4) == 9: + concat_v, transpose_v, reshape_v, matmul_v = v_nodes_4[0][-4:] + v_nodes = v_nodes_4 + past_v = concat_v.input[0] + present_v = concat_v.output[0] else: logger.debug("fuse_rotary_attention: failed to match v path") return @@ -461,6 +588,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # k_nodes_1 is for LLaMA-2 Microsoft # k_nodes_2 is for LLaMA-2 Hugging Face + # k_nodes_4 is for LLaMA-2 70B Hugging Face past_k, present_k = "", "" k_nodes = None k_nodes_1 = self.model.match_parent_path( @@ -478,6 +606,174 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], [1, 0, 1, 0, 0, 0], ) + _, k_nodes_4, _ = self.model.match_parent_paths_all( + matmul_qk, + [ + ( + [ + "Transpose", + "Reshape", + "Expand", + "Unsqueeze", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "Equal", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "Equal", + "Mul", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0, 0], + ), + ( 
+ [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Mul", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 3, 0, 0, 0, 1, 0, 0, 0], + ), + ], + output_name_to_node=None, + ) if k_nodes_1 is not None: reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1 k_nodes = k_nodes_1 @@ -505,6 +801,12 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): k_nodes = k_nodes_3 past_k = concat_k.input[0] present_k = concat_k.output[0] + elif k_nodes_4 is not None and len(k_nodes_4) == 9: + reshape_k, matmul_k = k_nodes_4[0][-2:] + concat_k, rotary_k = k_nodes_4[0][-5:-3] + k_nodes = k_nodes_4 + past_k = concat_k.input[0] + present_k = concat_k.output[0] else: logger.debug("fuse_rotary_attention: failed to match k nodes") return @@ -552,7 +854,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): return root_output = reshape_qkv_2.output[0] - elif qkv_nodes == qkv_nodes_2: + elif qkv_nodes in (qkv_nodes_2, qkv_nodes_3): if not self.check_runtime_shape_paths_for_nodes( reshape_qkv, reshape_q, @@ -573,6 +875,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key) rotary_k.output[0] = rotary_k.name + "_output_0" + if qkv_nodes == qkv_nodes_3: + qkv_nodes = qkv_nodes[1:] + new_node = self.create_mha_node( matmul_q.input[0], root_output, @@ -594,7 +899,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.node_name_to_graph_name[new_node.name] = self.this_graph_name self.nodes_to_remove.extend(qkv_nodes[1:]) - self.nodes_to_remove.extend(v_nodes[:-1]) + + if v_nodes != v_nodes_4: + self.nodes_to_remove.extend(v_nodes[:-1]) + else: + nodes_to_keep = [v_nodes[0][-1]] + for temp_path in v_nodes: + self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep) + self.nodes_to_remove.extend(qk_nodes) if k_nodes == k_nodes_1: @@ -608,6 +920,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.nodes_to_remove.append(k_nodes[1]) self.nodes_to_remove.append(k_nodes[3]) self.nodes_to_remove.append(k_nodes[4]) + elif k_nodes == k_nodes_4: + nodes_to_keep = [k_nodes[0][-1], k_nodes[0][-4]] + for temp_path in k_nodes: + self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep) if q_nodes == q_nodes_1: self.nodes_to_remove.extend(q_nodes[:-2]) diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index 9619e6cb52a91..1bb6940d1cd74 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -10,6 +10,8 @@ Please note the package versions needed for using LLaMA-2 in the `requirements.t - Note that `torch` with CUDA enabled is 
not installed automatically. This is because `torch` should be installed with the CUDA version used on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to download the `torch` version that is used with the CUDA version installed on your machine and satisfies the requirement listed in the file. - `requirements-quant.txt` - For running the SmoothQuant algorithm using [Intel's Neural Compressor](https://github.com/intel/neural-compressor) +- `requirements-70b-model.txt` + - For running the LLaMA-2 70B model on multiple GPUs - `requirements.txt` - Package versions needed in each of the above files @@ -79,6 +81,15 @@ model.save_pretrained(name.split("/")[-1] + "-onnx") Here are some additional examples for exporting LLaMA-2. +Export Model with Different GPU Device Ids +``` +# From source using first GPU: +$ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b + +# From wheel using second GPU: +$ CUDA_VISIBLE_DEVICES=1 python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b +``` + Export Saved Model on Disk ``` # From source: @@ -153,6 +164,19 @@ $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output l $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu ``` +Export LLaMA-2 70B sharded model into 4 partitions +``` +# From source: +# 1. Install necessary packages from requirements-70b-model.txt + +# 2. Build ONNX Runtime from source with NCCL enabled. Here is a sample command: +$ ./build.sh --config RelWithDebInfo --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/ + +# 3. Shard and export the LLaMA-2 70B model. With FP16, you will need at least 140GB of GPU memory to load the model. Therefore, you will need at least 4 40GB A100 GPUs or 2 80GB A100 GPUs to shard the PyTorch model and export each shard to ONNX. Here is an example command: +$ CUDA_VISIBLE_DEVICES=0,1,2,3 bash convert_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-dis --precision fp16 --execution_provider cuda + +``` + ## Benchmark LLaMA-2 Here are some examples of how you can benchmark LLaMA-2. @@ -220,11 +244,11 @@ python3 -m models.llama.benchmark \ --device cuda ``` -6. ONNX Runtime, FP32, convert_to_onnx +6. ONNX Runtime, FP32, convert_to_onnx, use 2nd GPU ``` -python3 -m models.llama.benchmark \ +CUDA_VISIBLE_DEVICES=1 python3 -m models.llama.benchmark \ --benchmark-type ort-convert-to-onnx \ - --ort-model-path ./llama2-7b/Llama-2-7b-hf_decoder_merged_model_fp32.onnx \ + --ort-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp32 \ --batch-sizes "1 2" \ @@ -232,11 +256,11 @@ python3 -m models.llama.benchmark \ --device cpu ``` -7. ONNX Runtime, FP16, convert_to_onnx +7. 
ONNX Runtime, FP16, convert_to_onnx, use 5th GPU ``` -python3 -m models.llama.benchmark \ +CUDA_VISIBLE_DEVICES=4 python3 -m models.llama.benchmark \ --benchmark-type ort-convert-to-onnx \ - --ort-model-path ./llama2-7b/Llama-2-7b-hf_decoder_merged_model_fp16.onnx \ + --ort-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp16 \ --batch-sizes "1 2" \ diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py index 245ff3dfe7f9d..be678931de5d1 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -11,6 +11,7 @@ import onnx import psutil import torch +from dist_settings import get_rank, get_size from llama_inputs import ( add_io_bindings, get_merged_sample_with_past_kv_inputs, @@ -133,6 +134,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): use_fp16=args.use_fp16, engine="ort", return_dict=True, + world_size=args.world_size, ) iter_inputs = get_merged_sample_with_past_kv_inputs( args.config, @@ -144,6 +146,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): use_fp16=args.use_fp16, engine="ort", return_dict=True, + world_size=args.world_size, ) elif args.benchmark_type == "ort-msft": @@ -244,10 +247,10 @@ def get_model(args: argparse.Namespace): if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}: # Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx - logger.info(f"Loading model from {args.ort_model_path}") + logger.info(f"Loading model from {args.ort_model_path.format(args.rank)}") start_time = time.time() model = ort.InferenceSession( - args.ort_model_path, + args.ort_model_path.format(args.rank), sess_options, providers=[args.execution_provider], ) @@ -315,10 +318,11 @@ def time_fn(args, fn, inputs): latency = total_time / args.num_runs throughput = args.batch_size / latency - logger.info(f"Batch Size: {args.batch_size}") - logger.info(f"Sequence Length: {args.sequence_length}") - logger.info(f"Latency: {latency} s") - logger.info(f"Throughput: {throughput} tps") + if args.rank == 0: + logger.info(f"Batch Size: {args.batch_size}") + logger.info(f"Sequence Length: {args.sequence_length}") + logger.info(f"Latency: {latency} s") + logger.info(f"Throughput: {throughput} tps") return @@ -358,7 +362,8 @@ def measure_fn(args, fn, inputs): process.cpu_percent(interval=0.1) fn(inputs) - logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%") + if args.rank == 0: + logger.info(f"CPU usage: {process.cpu_percent(interval=None) / psutil.cpu_count(logical=False)}%") # Measure memory usage gc.collect() @@ -451,7 +456,7 @@ def prepare_ort_inputs(inputs, kv_cache_ortvalues): # Add IO bindings for non-CPU execution providers if args.device != "cpu": io_binding, kv_cache_ortvalues = add_io_bindings( - model, inputs, args.device, int(args.device_id), kv_cache_ortvalues + model, inputs, args.device, int(args.rank), kv_cache_ortvalues ) setattr(args, "io_binding", io_binding) # noqa: B010 return io_binding, kv_cache_ortvalues @@ -511,7 +516,7 @@ def run_inference(args, init_inputs, iter_inputs, model): raise Exception(f"Cannot recognize {args.benchmark_type}") -def get_args(): +def get_args(rank=0): parser = argparse.ArgumentParser() parser.add_argument( "-bt", @@ -569,7 +574,7 @@ def get_args(): parser.add_argument( "-s", "--sequence-lengths", - default="8 16 32 64 128 256 
512", + default="32 64 128 256 512", ) parser.add_argument( "-d", @@ -606,9 +611,9 @@ def get_args(): if "ort" in args.benchmark_type: setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010 if args.execution_provider == "CUDAExecutionProvider": - args.execution_provider = (args.execution_provider, {"device_id": args.device_id}) + args.execution_provider = (args.execution_provider, {"device_id": rank}) elif args.execution_provider == "ROCMExecutionProvider": - args.execution_provider = (args.execution_provider, {"device_id": args.device_id}) + args.execution_provider = (args.execution_provider, {"device_id": rank}) args.device = "cuda" # Check that paths have been specified for any benchmarking with ORT @@ -635,14 +640,19 @@ def get_args(): def main(): - args = get_args() + rank = get_rank() + world_size = get_size() + + args = get_args(rank) setup_logger(args.verbose) logger.info(args.__dict__) torch.backends.cudnn.benchmark = True + args.rank = rank + args.world_size = world_size tokenizer = LlamaTokenizer.from_pretrained(args.model_name) config = LlamaConfig.from_pretrained(args.model_name) - target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device + target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device use_fp16 = args.precision == "fp16" setattr(args, "tokenizer", tokenizer) # noqa: B010 @@ -656,7 +666,7 @@ def main(): # Check if past_present_share_buffer can be enabled (only for FP16 models with GQA) if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}: - onnx_model = onnx.load_model(args.ort_model_path, load_external_data=False) + onnx_model = onnx.load_model(args.ort_model_path.format(args.rank), load_external_data=False) gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node)) use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu" @@ -666,7 +676,8 @@ def main(): # Measure prompt cost (init_inputs) and generated token cost (iter_inputs) for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths): - logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...") + if args.rank == 0: + logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...") setattr(args, "batch_size", int(batch_size)) # noqa: B010 setattr(args, "sequence_length", int(sequence_length)) # noqa: B010 diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_70b_model.sh b/onnxruntime/python/tools/transformers/models/llama/benchmark_70b_model.sh new file mode 100644 index 0000000000000..38f1916456658 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_70b_model.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +NUM_GPUS=${1:-1} + +MPI="mpirun --allow-run-as-root + -mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 + --tag-output --npernode $NUM_GPUS --bind-to numa + -x MIOPEN_FIND_MODE=1" + +CMD="$MPI python benchmark.py ${@:2}" + +$CMD \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py index 951b2549368f7..b35a5e27f9ea3 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py @@ -247,6 +247,7 @@ def main(): torch.backends.cudnn.benchmark = True all_results = [] + 
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id) # Benchmark PyTorch without torch.compile if args.hf_pt_eager: @@ -266,8 +267,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", @@ -298,8 +297,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", @@ -332,8 +329,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", @@ -366,8 +361,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", @@ -399,8 +392,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_70b_model.sh b/onnxruntime/python/tools/transformers/models/llama/convert_70b_model.sh new file mode 100644 index 0000000000000..637d15c10e0c7 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/convert_70b_model.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +NUM_GPUS=${1:-1} + +MPI="mpirun --allow-run-as-root + -mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 + --tag-output --npernode $NUM_GPUS --bind-to numa + -x MIOPEN_FIND_MODE=1" + +CMD="$MPI python convert_to_onnx.py ${@:2}" + +$CMD \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 3f05be53c6729..b0e0b41e75d3d 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -1,16 +1,16 @@ import argparse import logging import os -import tempfile +import shutil from itertools import chain from typing import List import onnx import torch -from benchmark_helper import Precision, prepare_environment, setup_logger -from convert_generation import replace_mha_with_gqa +from dist_settings import barrier, get_rank, get_size, init_dist from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs from llama_parity import main as parity_check +from llama_torch import setup_torch_model from onnx_model import OnnxModel from optimizer import optimize_model from packaging import version @@ -18,8 +18,11 @@ from onnxruntime import quantization as ort_quantization from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer +from onnxruntime.transformers.benchmark_helper import Precision, prepare_environment, setup_logger +from onnxruntime.transformers.convert_generation import replace_mha_with_gqa logger = logging.getLogger("") +init_dist() def get_model_dynamic_axes(input_names: List[str], output_names: List[str]): @@ -129,7 +132,9 @@ def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: st # del onnx_model # temp_dir.cleanup() # -def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): +def run_dynamo_export( + args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1 +): from torch._dynamo import config 
config.capture_scalar_outputs = True @@ -150,9 +155,9 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) - save_onnx_model(onnx_model, output_path, f"{args.model_name}_decoder_model_fp32.onnx.data") + save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data") del onnx_model os.system( f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}" @@ -160,7 +165,7 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll # Export decoder_with_past_model.onnx input_ids, attn_mask, pos_ids, past_kv = get_sample_with_past_kv_inputs( - l_config, device, batch_size, sequence_length + l_config, device, batch_size, sequence_length, world_size=world_size ) temp_dir = args.output # tempfile.TemporaryDirectory() temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx") @@ -172,9 +177,9 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) - save_onnx_model(onnx_model, output_path, f"{args.model_name}_decoder_with_past_model_fp32.onnx.data") + save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data") del onnx_model os.system( f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}" @@ -183,10 +188,21 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!") -def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): +def _prepare_dir(dir_path): + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + +def run_torchscript_separate_export( + args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1 +): # Dummy values for export batch_size, sequence_length = 2, 8 - device = torch.device("cpu") + + # set device used to export model + # for llama-2-70b we will use current gpus to speed up export process + # for other models, we will use CPU to make sure we have enough memory to do export + device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu") # Export decoder_model.onnx decoder_inputs = get_sample_inputs(l_config, device, batch_size, sequence_length) @@ -199,8 +215,12 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon ), ] dynamic_axes = get_model_dynamic_axes(input_names, output_names) - temp_dir = tempfile.TemporaryDirectory() - temp_path = os.path.join(temp_dir.name, "temp.onnx") + + # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large. 
+ # Use temp folder per rank to avoid race condition here. + temp_dir = f"./temp_{rank}" + _prepare_dir(temp_dir) + temp_path = os.path.join(temp_dir, "temp.onnx") torch.onnx.export( llama, args=decoder_inputs, @@ -218,18 +238,25 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) save_onnx_model( onnx_model, output_path, - f"{args.model_name}_decoder_model_fp32.onnx.data", + f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data", ) del onnx_model - temp_dir.cleanup() + shutil.rmtree(temp_dir) # Export decoder_with_past_model.onnx - decoder_with_past_inputs = get_sample_with_past_kv_inputs(l_config, device, batch_size, sequence_length) + decoder_with_past_inputs = get_sample_with_past_kv_inputs( + l_config, + device, + batch_size, + sequence_length, + use_fp16=args.precision == Precision.FLOAT16, + world_size=world_size, + ) input_names = [ "input_ids", "attention_mask", @@ -247,8 +274,12 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon ), ] dynamic_axes = get_model_with_past_kv_dynamic_axes(input_names, output_names) - temp_dir = tempfile.TemporaryDirectory() - temp_path = os.path.join(temp_dir.name, "temp.onnx") + + # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large. + # Use temp folder per rank to avoid race condition here. + temp_dir = f"./temp_past_{rank}" + _prepare_dir(temp_dir) + temp_path = os.path.join(temp_dir, "temp.onnx") torch.onnx.export( llama, args=decoder_with_past_inputs, @@ -266,27 +297,45 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) save_onnx_model( onnx_model, output_path, - f"{args.model_name}_decoder_with_past_model_fp32.onnx.data", + f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data", ) del onnx_model - temp_dir.cleanup() + shutil.rmtree(temp_dir) - logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!") + logger.info( + f"The {args.model_name} separate ONNX model has been successfully created with the TorchScript exporter!" 
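A rough sketch of the per-rank file layout these export functions produce, with the rank-prefixed ONNX names and the per-rank scratch folder, is shown below. The `make_rank_paths` helper is hypothetical and only illustrates the naming convention; the patch builds these strings inline.
```
# Illustrative only: mirrors the rank_{rank}_... naming and ./temp_{rank}
# scratch folders used by the TorchScript export paths in this patch.
import os
import shutil


def make_rank_paths(output_dir, model_name, rank, suffix="decoder_model_fp32"):
    """Return (onnx_path, external_data_name, temp_dir) for one rank."""
    onnx_path = os.path.join(output_dir, f"rank_{rank}_{model_name}_{suffix}.onnx")
    data_name = f"rank_{rank}_{model_name}_{suffix}.onnx.data"
    temp_dir = f"./temp_{rank}"  # per-rank scratch dir avoids races between ranks
    os.makedirs(temp_dir, exist_ok=True)
    return onnx_path, data_name, temp_dir


onnx_path, data_name, temp_dir = make_rank_paths("./llama2-70b-dis", "Llama-2-70b-hf", rank=2)
print(onnx_path)  # ./llama2-70b-dis/rank_2_Llama-2-70b-hf_decoder_model_fp32.onnx
shutil.rmtree(temp_dir)  # the exporters remove the scratch folder after saving
```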
+ ) -def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): +def run_torchscript_merged_export( + args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1 +): # Dummy values for export batch_size, sequence_length, past_sequence_length = 2, 8, 0 - device = torch.device("cpu") + + # set device used to export model + # for llama-2-70b we will use current gpus to speed up export process + # for other models, we will use CPU to make sure we have enough memory to do export + device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu") + + temp_name = args.model_name.lower().replace("-", "").replace("_", "") + max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048 # Export decoder_merged_model.onnx decoder_merged_inputs = get_merged_sample_with_past_kv_inputs( - l_config, device, batch_size, sequence_length, past_sequence_length + l_config, + device, + batch_size, + sequence_length, + past_sequence_length, + max_seq_len=max_sequence_length, + use_fp16=args.precision == Precision.FLOAT16, + world_size=world_size, ) input_names = [ "input_ids", @@ -305,8 +354,12 @@ def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfi ), ] dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names) - temp_dir = tempfile.TemporaryDirectory() - temp_path = os.path.join(temp_dir.name, "temp.onnx") + + # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large. + # Use temp folder per rank to avoid race condition here. + temp_dir = f"./temp_{rank}" + _prepare_dir(temp_dir) + temp_path = os.path.join(temp_dir, "temp.onnx") torch.onnx.export( llama, args=decoder_merged_inputs, @@ -324,17 +377,17 @@ def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfi onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) save_onnx_model( onnx_model, output_path, - f"{args.model_name}_decoder_merged_model_fp32.onnx.data", + f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx.data", ) del onnx_model - temp_dir.cleanup() + shutil.rmtree(temp_dir) - logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!") + logger.info(f"The {args.model_name} merged ONNX model has been successfully created with the TorchScript exporter!") # Optimize the model as FP32 @@ -357,12 +410,16 @@ def optimize_export(config: LlamaConfig, input_path: str, output_path: str): remove_existing_model(input_path) -def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: List[str]): - decoder_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp16.onnx") +def convert_to_float16( + args: argparse.Namespace, config: LlamaConfig, old_paths: List[str], rank: int = 0, world_size: int = 1 +): + decoder_model_fp16_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp16.onnx") decoder_with_past_model_fp16_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_fp16.onnx" + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp16.onnx" + ) + 
decoder_merged_model_fp16_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp16.onnx" ) - decoder_merged_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp16.onnx") new_paths = [decoder_model_fp16_path, decoder_with_past_model_fp16_path, decoder_merged_model_fp16_path] logger.info("Converting to float16...") @@ -370,7 +427,7 @@ def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: if os.path.exists(fp32_path): model = OnnxModel(onnx.load_model(fp32_path, load_external_data=True)) model.convert_float_to_float16(keep_io_types=False) - model = use_group_query_attention(config, model) + model = use_group_query_attention(config, model, world_size) model.save_model_to_file(fp16_path, use_external_data_format=True) del model logger.info(f"The ONNX model at {fp32_path} has been converted to float16 and saved at {fp16_path}!") @@ -380,9 +437,11 @@ def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: return new_paths -def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel): +def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel, world_size: int = 1): # Replace MultiHeadAttention with GroupQueryAttention and remove attention mask nodes - fp16_model_opt = replace_mha_with_gqa(fp16_model_opt, "past_sequence_length", config.num_key_value_heads) + fp16_model_opt = replace_mha_with_gqa( + fp16_model_opt, "past_sequence_length", config.num_key_value_heads, world_size + ) fp16_model_opt.prune_graph() fp16_model_opt.update_graph(allow_remove_graph_inputs=True) return fp16_model_opt @@ -406,7 +465,7 @@ def smooth_quant( calibration_sampling_size=[args.calibration_sampling_size], recipes={ "optypes_to_exclude_output_quant": ["MatMul"], - "smooth_quant": args.smooth_quant, + "smooth_quant": True, "smooth_quant_args": {"alpha": args.smooth_quant_alpha}, }, op_type_dict={ @@ -526,15 +585,6 @@ def get_args(): help="Execution provider to verify parity with", ) - parser.add_argument( - "-id", - "--device-id", - required=False, - type=str, - default="0", - help="Device ID for GPUs", - ) - parser.add_argument( "-r", "--reexport", @@ -655,6 +705,14 @@ def get_args(): ) parser.set_defaults(use_dynamo_export=False) + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default="./model_cache", + help="model cache dir to override default HF cache dir to avoid overflood the /home dir", + ) + args = parser.parse_args() return args @@ -673,144 +731,182 @@ def main(): remove_existing_files(args.output) logger.info(f"Arguments: {args}") + world_size = get_size() + rank = get_rank() + # Load model and config use_auth_token = args.input == os.path.join(".") setattr(args, "use_auth_token", use_auth_token) # noqa: B010 - location = args.model_name if use_auth_token else args.input - l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token) - llama = LlamaForCausalLM.from_pretrained(location, use_auth_token=use_auth_token, use_cache=True) original_model_name = args.model_name setattr(args, "original_model_name", original_model_name) # noqa: B010 args.model_name = args.model_name.split("/")[-1] - # Set model paths for FP32 model - decoder_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx") - decoder_with_past_model_fp32_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx" - ) - decoder_merged_model_fp32_path = os.path.join(args.output, 
f"{args.model_name}_decoder_merged_model_fp32.onnx") - old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] - - missing_separate_exports = ( - args.no_merged - and not os.path.exists(decoder_model_fp32_path) - and not os.path.exists(decoder_with_past_model_fp32_path) - ) - missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path) - - # Export to ONNX - if missing_separate_exports or missing_merged_export: - if args.use_dynamo_export and missing_separate_exports: - logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.") - logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/") - logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/") - logger.warning( - "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script" - ) - logger.warning( - "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step." - ) - run_dynamo_export(args, l_config, llama) - elif args.no_merged: - run_torchscript_separate_export(args, l_config, llama) - else: - run_torchscript_merged_export(args, l_config, llama) - del llama # Delete LLaMA model from memory since it will be loaded again during parity check + setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010 + setattr(args, "device", torch.device(args.device_name)) # noqa: B010 - # Set model paths to store FP32 optimized model - decoder_model_fp32_opt_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32_opt.onnx") - decoder_with_past_model_fp32_opt_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_fp32_opt.onnx" - ) - decoder_merged_model_fp32_opt_path = os.path.join( - args.output, f"{args.model_name}_decoder_merged_model_fp32_opt.onnx" - ) - new_paths = [decoder_model_fp32_opt_path, decoder_with_past_model_fp32_opt_path, decoder_merged_model_fp32_opt_path] + location = args.original_model_name if use_auth_token else args.input - # Run the optimizer script - logger.info("Optimizing models...") - for orig_path, opt_path in zip(old_paths, new_paths): - if os.path.exists(orig_path): - optimize_export(l_config, input_path=orig_path, output_path=opt_path) + # use cuda for Llama-2-70b to speedup export, other models use CPU by default + l_config, llama = setup_torch_model( + args, location, use_auth_token, device=args.device if args.model_name == "Llama-2-70b-hf" else None + ) - # Re-assign default FP32 model paths as their optimized versions - decoder_model_fp32_path = decoder_model_fp32_opt_path - decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path - decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path - old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] + assert l_config.num_attention_heads % world_size == 0 and l_config.num_key_value_heads % world_size == 0 - logger.info( - f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!" 
- ) - - # Change precision of exported models from FP32 - if args.precision == Precision.FLOAT16: - new_paths = convert_to_float16(args, l_config, old_paths) - - elif args.precision == Precision.INT8: - decoder_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int8.onnx") - decoder_with_past_model_int8_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_int8.onnx" - ) - decoder_merged_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int8.onnx") - new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] - - if args.quantization_method == "smooth_quant": - if not args.no_merged: - logger.error("SmoothQuant must be used on separately exported models") - else: - logger.info(f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8") - smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) - - elif args.quantization_method == "quantize_dynamic": - logger.warning( - "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`." + barrier() + for i in range(world_size): + if i == rank: + # Set model paths for FP32 model + decoder_model_fp32_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx" + ) + decoder_with_past_model_fp32_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx" + ) + decoder_merged_model_fp32_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx" ) + old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] - logger.info("Quantizing to int8...") - for fp32_path, int8_path in zip(old_paths, new_paths): - if os.path.exists(fp32_path): - ort_quantization.quantize_dynamic( - fp32_path, - int8_path, - op_types_to_quantize=["MatMul", "Gemm", "Gather"] - if args.quantize_embedding_layer - else ["MatMul", "Gemm"], - per_channel=args.quantize_per_channel, - reduce_range=args.quantize_reduce_range, - use_external_data_format=True, - extra_options={"MatMulConstBOnly": True}, + missing_separate_exports = ( + args.no_merged + and not os.path.exists(decoder_model_fp32_path) + and not os.path.exists(decoder_with_past_model_fp32_path) + ) + missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path) + + # Export to ONNX + if missing_separate_exports or missing_merged_export: + if args.use_dynamo_export and missing_separate_exports: + logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.") + logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/") + logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/") + logger.warning( + "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script" + ) + logger.warning( + "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step." 
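The per-rank export work in main() runs inside a barrier-serialized loop, the same pattern llama_torch.setup_torch_model uses for loading, so only one rank performs the disk- and memory-heavy steps at a time. A condensed sketch of that pattern, assuming dist_settings.py from this patch is importable:
```
# Condensed sketch; do_rank_work stands in for the export/optimize/quantize body.
from dist_settings import barrier, get_rank, get_size


def run_serialized(do_rank_work):
    rank, world_size = get_rank(), get_size()
    for i in range(world_size):
        if i == rank:
            do_rank_work(rank)  # only this rank is active in this iteration
        barrier()               # all ranks sync before the next one starts


if __name__ == "__main__":
    run_serialized(lambda r: print(f"rank {r} exporting its shard"))
```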
) - logger.info(f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!") - remove_existing_model(decoder_model_fp32_path) + run_dynamo_export(args, l_config, llama) + elif args.no_merged: + run_torchscript_separate_export(args, l_config, llama, rank, world_size) + else: + run_torchscript_merged_export(args, l_config, llama, rank, world_size) + del llama # Delete LLaMA model from memory since it will be loaded again during parity check + + # Set model paths to store FP32 optimized model + decoder_model_fp32_opt_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32_opt.onnx" + ) + decoder_with_past_model_fp32_opt_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32_opt.onnx" + ) + decoder_merged_model_fp32_opt_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32_opt.onnx" + ) + new_paths = [ + decoder_model_fp32_opt_path, + decoder_with_past_model_fp32_opt_path, + decoder_merged_model_fp32_opt_path, + ] + + # Run the optimizer script + logger.info("Optimizing models...") + for orig_path, opt_path in zip(old_paths, new_paths): + if os.path.exists(orig_path): + optimize_export(l_config, input_path=orig_path, output_path=opt_path) + + # Re-assign default FP32 model paths as their optimized versions + decoder_model_fp32_path = decoder_model_fp32_opt_path + decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path + decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path + old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] + + logger.info( + f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!" + ) - logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") + # Change precision of exported models from FP32 + if args.precision == Precision.FLOAT16: + new_paths = convert_to_float16(args, l_config, old_paths, rank, world_size) + + elif args.precision == Precision.INT8: + decoder_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx" + ) + decoder_with_past_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx" + ) + decoder_merged_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx" + ) + new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] + + if args.quantization_method == "smooth_quant": + if not args.no_merged: + logger.error("SmoothQuant must be used on separately exported models") + else: + logger.info( + f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8" + ) + smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) + + elif args.quantization_method == "quantize_dynamic": + logger.warning( + "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`." 
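For reference, the INT4 branch below reduces to the blockwise MatMul quantizer calls already shown in this diff. A self-contained sketch follows; the file paths and block size are assumed values, not taken from the patch:
```
# Assumes a per-rank FP16 model already exists at fp_path.
import onnx
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer

fp_path = "./llama2-70b-dis/rank_0_Llama-2-70b-hf_decoder_merged_model_fp16.onnx"    # assumed path
int4_path = "./llama2-70b-dis/rank_0_Llama-2-70b-hf_decoder_merged_model_int4.onnx"  # assumed path
block_size = 32  # assumed; convert_to_onnx.py exposes this as --block_size

model = onnx.load_model(fp_path, load_external_data=True)
quant = MatMul4BitsQuantizer(model, block_size, is_symmetric=True, nodes_to_exclude=[])
quant.process()
quant.model.save_model_to_file(int4_path, use_external_data_format=True)
```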
+ ) - else: - raise Exception(f"Could not recognize {args.quantization_method} as a quantization method") - - elif args.precision == Precision.INT4: - if args.execution_provider != "cpu": - old_paths = convert_to_float16(args, l_config, old_paths) - - decoder_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int4.onnx") - decoder_with_past_model_int4_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_int4.onnx" - ) - decoder_merged_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int4.onnx") - new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] - - for fp_path, int4_path in zip(old_paths, new_paths): - if os.path.exists(fp_path): - model = onnx.load_model(fp_path, load_external_data=True) - quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[]) - quant.process() - quant.model.save_model_to_file(int4_path, use_external_data_format=True) - del model - del quant - logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") - remove_existing_model(fp_path) + logger.info("Quantizing to int8...") + for fp32_path, int8_path in zip(old_paths, new_paths): + if os.path.exists(fp32_path): + ort_quantization.quantize_dynamic( + fp32_path, + int8_path, + op_types_to_quantize=["MatMul", "Gemm", "Gather"] + if args.quantize_embedding_layer + else ["MatMul", "Gemm"], + per_channel=args.quantize_per_channel, + reduce_range=args.quantize_reduce_range, + use_external_data_format=True, + extra_options={"MatMulConstBOnly": True}, + ) + logger.info( + f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!" + ) + remove_existing_model(decoder_model_fp32_path) + + logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") + + else: + raise Exception(f"Could not recognize {args.quantization_method} as a quantization method") + + elif args.precision == Precision.INT4: + if args.execution_provider != "cpu": + old_paths = convert_to_float16(args, l_config, old_paths, rank, world_size) + + decoder_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx" + ) + decoder_with_past_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx" + ) + decoder_merged_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx" + ) + new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] + + for fp_path, int4_path in zip(old_paths, new_paths): + if os.path.exists(fp_path): + model = onnx.load_model(fp_path, load_external_data=True) + quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[]) + quant.process() + quant.model.save_model_to_file(int4_path, use_external_data_format=True) + del model + del quant + logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") + remove_existing_model(fp_path) + barrier() logger.info("Verifying parity on all ONNX models created") @@ -824,7 +920,12 @@ def main(): # Verify parity on all saved ONNX models for filename in os.listdir(args.output): - if ".data" in filename or ".onnx" not in filename: + if ( + ".data" in filename + or ".onnx" not in filename + or args.precision not in filename + or f"rank_{rank}" not in filename + ): continue parity_cmd 
= [ @@ -834,10 +935,10 @@ def main(): os.path.join(args.output, filename), "-ep", args.execution_provider, - "-id", - args.device_id, "-fp", args.precision, + "--cache_dir", + args.cache_dir, ] if "with_past" in filename: parity_cmd.append("--use_past_kv") @@ -845,6 +946,7 @@ def main(): parity_cmd.append("--merged") try: + logger.debug(f"check parity with cmd: {parity_cmd}") parity_check(parity_cmd) except Exception as e: logger.warning(f"An error occurred while verifying parity: {e}", exc_info=True) diff --git a/onnxruntime/python/tools/transformers/models/llama/dist_settings.py b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py new file mode 100644 index 0000000000000..50b0669d6d83a --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py @@ -0,0 +1,45 @@ +import os + +import torch.distributed as dist + +comm = None + + +def init_dist(): + if "LOCAL_RANK" in os.environ: + int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank) + elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: + from mpi4py import MPI + + comm = MPI.COMM_WORLD # noqa: F841 + + int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0)) + rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0)) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1)) + + dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", world_size=world_size, rank=rank) + else: + # don't need to do init for single process + pass + + +def get_rank(): + return comm.Get_rank() if comm is not None else 0 + + +def get_size(): + return comm.Get_size() if comm is not None else 1 + + +def barrier(): + if comm is not None: + comm.Barrier() + + +def print_out(*args): + if get_rank() == 0: + print(*args) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py index f7a1b05249abf..6530eead55f03 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -66,12 +66,13 @@ def get_sample_with_past_kv_inputs( use_fp16: bool = False, engine: str = "pt", return_dict: bool = False, + world_size: int = 1, ): input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64) attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64) # position_ids is of shape (batch_size, 1) position_ids = get_position_ids(attention_mask, use_past_kv=True) - past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16) + past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size) # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) @@ -123,12 +124,13 @@ def get_merged_sample_with_past_kv_inputs( use_fp16: bool = False, engine: str = "pt", return_dict: bool = False, + world_size: int = 1, ): input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64) attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64) # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0)) - past_kv = get_past_kv_inputs(config, 
batch_size, past_seq_len, use_fp16) + past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size) # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) @@ -220,8 +222,8 @@ def get_msft_sample_inputs( # Create past_key_values # Each is of shape (batch_size, num_heads, past_sequence_length, head_size) -def get_past_kv_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, use_fp16: bool): - num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_key_value_heads +def get_past_kv_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1): + num_heads, head_size = config.num_key_value_heads // world_size, config.hidden_size // config.num_attention_heads torch_dtype = torch.float16 if use_fp16 else torch.float32 past_kv = [ ( diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index c1c5d3c412f2a..42581caf3bb9e 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -6,7 +6,7 @@ import numpy as np import torch -from benchmark_helper import setup_logger +from dist_settings import get_rank, get_size from llama_inputs import ( add_io_bindings, convert_inputs_for_ort, @@ -14,9 +14,11 @@ get_sample_inputs, get_sample_with_past_kv_inputs, ) +from llama_torch import setup_torch_model from transformers import LlamaConfig, LlamaForCausalLM import onnxruntime as ort +from onnxruntime.transformers.benchmark_helper import setup_logger logger = logging.getLogger("") @@ -30,6 +32,7 @@ def get_sequence_lengths(args: argparse.Namespace): def get_inputs(args: argparse.Namespace, config: LlamaConfig): # Dummy values for parity + world_size = get_size() batch_size = 2 past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args) @@ -43,10 +46,17 @@ def get_inputs(args: argparse.Namespace, config: LlamaConfig): max_seq_len=max_sequence_length, use_fp16=args.use_fp16, return_dict=True, + world_size=world_size, ) elif args.use_past_kv: inputs = get_sample_with_past_kv_inputs( - config, args.device, batch_size, sequence_length, use_fp16=args.use_fp16, return_dict=True + config, + args.device, + batch_size, + sequence_length, + use_fp16=args.use_fp16, + return_dict=True, + world_size=world_size, ) else: inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True) @@ -66,6 +76,7 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama torch.cuda.synchronize() end_time = time.time() logger.info(f"PyTorch took {end_time - start_time} s") + del pt_model # Run inference with ORT past_sequence_length, _, max_sequence_length = get_sequence_lengths(args) @@ -76,12 +87,12 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama past_seq_len=past_sequence_length, max_seq_len=max_sequence_length, device=args.execution_provider, - device_id=int(args.device_id), + device_id=int(args.rank), ) ep = f"{args.execution_provider.upper()}ExecutionProvider" if ep == "CUDAExecutionProvider": - ep = (ep, {"device_id": args.device_id}) + ep = (ep, {"device_id": args.rank}) ort_model = ort.InferenceSession( args.onnx_model_path, sess_options=ort.SessionOptions(), @@ -91,7 +102,7 @@ def verify_parity(args: argparse.Namespace, config: 
LlamaConfig, pt_model: Llama # Add IO bindings for non-CPU execution providers if args.execution_provider != "cpu": io_binding, kv_cache_ortvalues = add_io_bindings( - ort_model, inputs, args.execution_provider, int(args.device_id), kv_cache_ortvalues + ort_model, inputs, args.execution_provider, int(args.rank), kv_cache_ortvalues ) io_binding.synchronize_inputs() @@ -101,6 +112,7 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama end_time = time.time() ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits + del ort_model else: start_time = time.time() @@ -155,15 +167,6 @@ def get_args(argv: List[str]): help="Execution provider to verify parity with", ) - parser.add_argument( - "-id", - "--device-id", - required=False, - type=str, - default="0", - help="Device ID for GPUs", - ) - parser.add_argument( "-v", "--verbose", @@ -195,6 +198,14 @@ def get_args(argv: List[str]): help="Precision of model", ) + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default="./model_cache", + help="model cache dir to override default HF cache dir to avoid overflood the /home dir", + ) + args = parser.parse_args() if argv == [] else parser.parse_args(argv) # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models @@ -210,21 +221,23 @@ def main(argv: List[str] = []): # noqa: B006 args = get_args(argv) setup_logger(args.verbose) logger.info(f"Arguments: {args}") + rank = get_rank() # Load model and config setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010 - setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{args.device_id}") # noqa: B010 + args.rank = rank + setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010 setattr(args, "device", torch.device(args.device_name)) # noqa: B010 use_auth_token = args.torch_model_directory == os.path.join(".") location = args.model_name if use_auth_token else args.torch_model_directory - config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token) - llama = LlamaForCausalLM.from_pretrained( + config, llama = setup_torch_model( + args, location, + use_auth_token, torch_dtype=(torch.float16 if args.use_fp16 else torch.float32), - use_auth_token=use_auth_token, - use_cache=True, - ).to(args.device) + device=args.device, + ) kv_cache_ortvalues = {} if not args.merged: diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py new file mode 100644 index 0000000000000..cf6406dde5be0 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py @@ -0,0 +1,38 @@ +import logging +import os + +import torch +from dist_settings import barrier, get_rank, get_size +from transformers import LlamaConfig, LlamaForCausalLM + +logger = logging.getLogger("") + + +def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32, device=None): + world_size = get_size() + logger.info(f"world_size: {world_size}") + rank = get_rank() + barrier() + + if not os.path.exists(args.cache_dir): + os.makedirs(args.cache_dir, exist_ok=True) + + for i in range(world_size): + if i == rank % (world_size): + l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir) + l_config.use_cache = True + llama = LlamaForCausalLM.from_pretrained( + location, + use_auth_token=use_auth_token, + config=l_config, + torch_dtype=torch_dtype, 
+ cache_dir=args.cache_dir, + ) + if world_size > 1: + llama.parallel_model() + if device: + llama.to(device) + llama.eval() + llama.requires_grad_(False) + barrier() + return l_config, llama diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt new file mode 100644 index 0000000000000..572cfdb71be4a --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt @@ -0,0 +1,4 @@ +-r requirements.txt +git+https://github.com/frankdongms/transformers.git@frdong/shard_llama +mpi4py +psutil \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index e9c24ed3eb09b..392f2f948968e 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -337,6 +337,18 @@ def match_parent_paths(self, node, paths, output_name_to_node): return i, matched, return_indice return -1, None, None + def match_parent_paths_all(self, node, paths, output_name_to_node): + match_i, matches, return_indices = [], [], [] + for i, path in enumerate(paths): + assert isinstance(path, (List, Tuple)) + return_indice = [] + matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice) + if matched: + match_i.append(i) + matches.append(matched) + return_indices.append(return_indice) + return match_i, matches, return_indices + def match_parent_path( self, node, diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py index 04351cd6e6782..319fed87dc9eb 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -10,7 +10,10 @@ # license information. 
# ------------------------------------------------------------------------- import math +import os +import platform import random +import unittest import numpy import torch @@ -22,6 +25,8 @@ torch.manual_seed(0) +pipeline_mode = True # Reduces number of tests so pipeline doesn't time out + class Formats: BSNH = 0 @@ -159,7 +164,7 @@ def create_multihead_attention_graph(config): return model.SerializeToString() -def create_group_query_attention_graph_no_past(config, causal=False): +def create_group_query_attention_graph_no_past(config, causal=False, present_kv_format=Formats.BSNH): nodes = [ helper.make_node( "GroupQueryAttention", @@ -168,11 +173,12 @@ def create_group_query_attention_graph_no_past(config, causal=False): "key", "value", ], - ["output"], + ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, unidirectional=1 if causal else 0, + is_past_bsnh=1 if present_kv_format == Formats.BSNH else 0, domain="com.microsoft", ), ] @@ -213,6 +219,26 @@ def create_group_query_attention_graph_no_past(config, causal=False): TensorProto.FLOAT16, [config.batch_size, config.sequence_length, config.num_heads * config.head_size], ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length if present_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if present_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length if present_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if present_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), ] graph = helper.make_graph( @@ -514,7 +540,6 @@ def generate_token_offset(cu_seqlens, max_seqlen): return numpy.asarray(token_offset + token_padset, dtype=numpy.int32) -# TODO(aciddelgado): rename def flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, token_offset, config, causal=False): onnx_model_str = create_packed_multihead_attention_graph(config) qkv_unpad = torch.swapdims(qkv_unpad, 1, 2) @@ -548,8 +573,8 @@ def mha_func(q, k, v, config): return output -def gqa_no_past_func(q, k, v, config, causal=True): - onnx_model_str = create_group_query_attention_graph_no_past(config, causal) +def gqa_no_past_func(q, k, v, config, causal=True, present_kv_format=Formats.BSNH): + onnx_model_str = create_group_query_attention_graph_no_past(config, causal, present_kv_format=present_kv_format) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) k = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1)) v = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1)) @@ -560,7 +585,7 @@ def gqa_no_past_func(q, k, v, config, causal=True): } sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) - ort_output = ort_session.run(None, ort_inputs) + ort_output, _, _ = ort_session.run(None, ort_inputs) ort_output = numpy.array(ort_output) output = torch.tensor(ort_output) return output @@ -689,17 +714,12 @@ def attention_ref( if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) if causal: - # causal_mask = torch.triu( - # torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1 - # ) causal_mask = 
construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device) scores.masked_fill_(causal_mask, float("-inf")) attention = torch.softmax(scores, dim=-1) if causal: # Some rows are completely masked out so we fill them with zero instead of NaN attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0) dropout_scaling = 1.0 / (1 - dropout_p) - # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling - # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) if dropout_mask is not None: attention_drop = attention.masked_fill(~dropout_mask, 0.0) else: @@ -1072,12 +1092,6 @@ def parity_check_gqa_past_no_buff( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() - # print(present_k[0, 0, config.past_sequence_length, :10]) - # print(k_cache_ref[0, 0, config.past_sequence_length, :10]) - # print(k_cache_ref.shape) - - # print(present_k - k_cache_ref.detach().cpu().numpy()) - # Make sure past-present buffer updating correctly if past_format == Formats.BSNH: assert numpy.allclose( @@ -1141,84 +1155,185 @@ def parity_check_gqa_past_no_buff( ) +class TestMHA(unittest.TestCase): + def test_packed_mha(self): + if not torch.cuda.is_available() or platform.system() != "Linux": + return + major, _ = torch.cuda.get_device_capability() + if major < 8: + return + print("-------- TEST PACKED MHA ---------") + batches = [2] if pipeline_mode else [1, 5] + seqs = [8, 97, 256, 1024] if pipeline_mode else [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048] + num_h = [1, 3] if pipeline_mode else [1, 6, 16] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + for b in batches: + for s in seqs: + for n in num_h: + for h in h_sizes: + config = Config(b, s, s, 0, n, n, h) + parity_check_mha(config, True) + + def test_mha(self): + if not torch.cuda.is_available() or platform.system() != "Linux": + return + major, _ = torch.cuda.get_device_capability() + if major < 8: + return + print("-------- TEST MHA ---------") + batches = [2] if pipeline_mode else [1, 5] + seqs = ( + [(1, 128), (113, 211), (2048, 2048)] + if pipeline_mode + else [ + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ] + ) + num_h = [1, 3] if pipeline_mode else [1, 6, 16] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + for b in batches: + for s, s2 in seqs: + for n in num_h: + for h in h_sizes: + config = Config(b, s, s2, 0, n, n, h) + parity_check_mha(config, False) + + +class TestGQA(unittest.TestCase): + def test_gqa_no_past(self): + if not torch.cuda.is_available(): + return + major, minor = torch.cuda.get_device_capability() + torch.manual_seed(69) + print("-------- TEST GQA ---------") + batches = [2] if pipeline_mode else [1, 5] + seqs = ( + [(1, 128), (113, 211), (2048, 2048)] + if pipeline_mode + else [ + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (1024, 1024), + (1023, 1024), + (2048, 2048), + ] + ) + num_h = [(9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + if major < 5 or (major == 5 and minor < 3): + return + print("------- MEMORY EFFICIENT ATTENTION ---------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + for b in batches: + for s, s2 in seqs: + 
for n, n2 in num_h: + for h in h_sizes: + for causal in [True, False]: + config = Config(b, s, s2, 0, n, n2, h) + parity_check_gqa_no_past(config, causal=causal) + if major < 8 or platform.system() != "Linux": + return + print("------- FLASH ATTENTION --------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for causal in [True, False]: + config = Config(b, s, s2, 0, n, n2, h) + parity_check_gqa_no_past(config, causal=causal) + + def test_gqa_past(self): + if not torch.cuda.is_available(): + return + major, minor = torch.cuda.get_device_capability() + if major < 5 or (major == 5 and minor < 3): + return + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + print("-------- TEST GQA PAST ---------") + print("-------- MEMORY EFFICEINT --------") + batches = [2] if pipeline_mode else [1, 2] + seqs = ( + [(1, 128), (3, 1024), (64, 2048)] + if pipeline_mode + else [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + (1, 128 * 512), + (16, 128 * 512), + (128, 128), + ] + ) + num_h = [(9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for causal in [True]: + for past_kv_format in [Formats.BNSH, Formats.BSNH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + parity_check_gqa_past_no_buff( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + if major < 8 or platform.system() != "Linux": + return + print("------- FLASH ATTENTION -------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for causal in [True]: + for past_kv_format in [Formats.BNSH, Formats.BSNH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + parity_check_gqa_past_no_buff( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + + if __name__ == "__main__": - print("-------- TEST PACKED MHA ---------") - for b in [5]: - for s in [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]: - for n in [6]: - for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]: - config = Config(b, s, s, 0, n, n, h) - parity_check_mha(config, True) - print("-------- TEST MHA ---------") - for b in [5]: - for s, s2 in [ - (113, 203), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (512, 256), - (1024, 1024), - (1023, 1024), - (1024, 1023), - (2048, 2048), - ]: - for n in [6]: - for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]: - config = Config(b, s, s2, 0, n, n, h) - parity_check_mha(config, False) - print("-------- TEST GQA ---------") - for b in [5]: - for s, s2 in [ - (113, 203), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (512, 256), - (1024, 1024), - (1023, 1024), - (1024, 1023), - (2048, 2048), - ]: - for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]: - for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]: - for causal in [True, False]: - config = Config(b, s, s2, 0, n, n2, h) - 
parity_check_gqa_no_past(config, causal=causal) - print("-------- TEST GQA PAST ---------") - random.seed(69) - for b in [2]: - for s, s2 in [ - (1, 128), - (1, 339), - (3, 1024), - (64, 800), - (64, 256), - (3, 799), - (64, 2048), - (16, 20000), - (1, 128 * 512), - (16, 128 * 512), - (128, 128), - ]: - for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]: - for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]: - for causal in [True]: - for past_kv_format in [Formats.BNSH, Formats.BSNH]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - parity_check_gqa_past( - config, - causal=causal, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) - parity_check_gqa_past_no_buff( - config, - causal=causal, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py index fedba2a25dfc2..373ad86ced1a7 100644 --- a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py +++ b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py @@ -96,7 +96,7 @@ def create_inputs_and_outputs(self, model_type: str): helper.make_tensor_value_info("position_ids", TensorProto.INT64, [self.batch_size, self.sequence_length]), helper.make_tensor_value_info("attn_mask", TensorProto.INT64, attn_mask_size), ] - if model_type in {"past", "merged", "llama2_msft"}: + if model_type in {"past", "merged", "llama2_msft", "70b_distributed_merged"}: inputs.extend( [ helper.make_tensor_value_info( @@ -164,14 +164,14 @@ def get_first_rope_input(node_type: str): if is_fused or model_type == "llama2_msft": # q_out/k_out return f"{node_type}_out" - if model_type in {"no_past", "past", "merged"}: + if model_type in {"no_past", "past", "merged", "70b_distributed_merged"}: if node_type == "k": return "k_before_rope" return "q_before_rope" return "" def get_first_rope_output(node_type: str): - if is_fused or model_type in {"llama2_msft", "past", "merged"}: + if is_fused or model_type in {"llama2_msft", "past", "merged", "70b_distributed_merged"}: if node_type == "q": return "q_rope" return "k_rope" @@ -295,23 +295,225 @@ def create_k_path_hf(self, model_type: str): ) k_nodes = [reshape_k_node, transpose_k_1_node] - if model_type in {"past", "merged"}: + if model_type == "70b_distributed_merged": concat_k_node = helper.make_node( "Concat", inputs=["past_key", "k_rope"], outputs=["present_key"], axis=2, ) - k_nodes.append(concat_k_node) + shape_k1 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k1_out"], name="Shape_k1") + shape_k2 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k2_out"], name="Shape_k2") + shape_k3 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k3_out"], name="Shape_k3") + shape_k4 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k4_out"], name="Shape_k4") + + gather_k_1 = helper.make_node( + "Gather", + inputs=["shape_k1_out", "one"], + outputs=["gather_k1_out"], + name="Gather_k_1", + axis=0, + ) + gather_k_2 = helper.make_node( + "Gather", + inputs=["shape_k2_out", "one"], + outputs=["gather_k2_out"], + name="Gather_k_2", + axis=0, + ) + gather_k_3 = helper.make_node( + "Gather", + inputs=["shape_k3_out", "one"], + outputs=["gather_k3_out"], + name="Gather_k_3", + axis=0, + ) + gather_k_4 = helper.make_node( + "Gather", + inputs=["shape_k4_out", "one"], + outputs=["gather_k4_out"], + name="Gather_k_4", + axis=0, + ) - 
transpose_k_2_node = helper.make_node( - "Transpose", - inputs=["present_key"], - outputs=["k"], - name="Transpose_k_2", - perm=[0, 1, 3, 2], - ) - return k_nodes + [transpose_k_2_node] # noqa: RUF005 + unsqueeze_k_1 = helper.make_node( + "Unsqueeze", + inputs=["present_value", "zero"], + outputs=["unsqueeze_k1_out"], + name="Unsqueeze_k1", + ) + unsqueeze_k_2 = helper.make_node( + "Unsqueeze", + inputs=["gather_k1_out", "zero"], + outputs=["unsqueeze_k2_out"], + name="Unsqueeze_k2", + ) + unsqueeze_k_3 = helper.make_node( + "Unsqueeze", + inputs=["gather_k2_out", "zero"], + outputs=["unsqueeze_k3_out"], + name="Unsqueeze_k3", + ) + unsqueeze_k_4 = helper.make_node( + "Unsqueeze", + inputs=["gather_k3_out", "zero"], + outputs=["unsqueeze_k4_out"], + name="Unsqueeze_k4", + ) + unsqueeze_k_5 = helper.make_node( + "Unsqueeze", + inputs=["gather_k4_out", "zero"], + outputs=["unsqueeze_k5_out"], + name="Unsqueeze_k5", + ) + + concat_k_2 = helper.make_node( + "Concat", + inputs=["unsqueeze_k2_out", "unsqueeze_k3_out", "One", "unsqueeze_k4_out", "unsqueeze_k5_out"], + outputs=["concat_k2_ouot"], + name="Concat_k2", + axis=0, + ) + reshape_k_2 = helper.make_node( + "Reshape", + inputs=["concat_k2_ouot", "One"], + outputs=["reshape_k2_out"], + name="Reshape_k_2", + ) + shape_k5 = helper.make_node("Shape", inputs=["reshape_k2_out"], outputs=["shape_k5_out"], name="Shape_k5") + constant_of_shape_k_1 = helper.make_node( + "ConstantOfShape", + inputs=["shape_k5_out"], + outputs=["constant_of_shape_k1_out"], + name="ConstantOfShape_k1", + ) + mul_k_1 = helper.make_node( + "Mul", + inputs=["constant_of_shape_k1_out", "One"], + outputs=["mul_k1_out"], + name="mul_k1", + ) + equal_k_1 = helper.make_node( + "Equal", + inputs=["reshape_k2_out", "mul_k1_out"], + outputs=["equal_k_1_out"], + name="equal_k1", + ) + where_k_1 = helper.make_node( + "Where", + inputs=["equal_k_1_out", "constant_of_shape_k1_out", "reshape_k2_out"], + outputs=["where_k_1_out"], + name="where_k1", + ) + unsqueeze_k_6 = helper.make_node( + "Unsqueeze", + inputs=["gather_k1_out", "zero"], + outputs=["unsqueeze_k6_out"], + name="Unsqueeze_k6", + ) + mul_k_2 = helper.make_node( + "Mul", + inputs=["gather_k2_out", "One"], + outputs=["mul_k2_out"], + name="mul_k2", + ) + unsqueeze_k_7 = helper.make_node( + "Unsqueeze", + inputs=["mul_k2_out", "zero"], + outputs=["unsqueeze_k7_out"], + name="Unsqueeze_k7", + ) + unsqueeze_k_8 = helper.make_node( + "Unsqueeze", + inputs=["gather_k3_out", "zero"], + outputs=["unsqueeze_k8_out"], + name="Unsqueeze_k8", + ) + unsqueeze_k_9 = helper.make_node( + "Unsqueeze", + inputs=["gather_k4_out", "zero"], + outputs=["unsqueeze_k9_out"], + name="Unsqueeze_k9", + ) + concat_k_3 = helper.make_node( + "Concat", + inputs=["unsqueeze_k6_out", "unsqueeze_k7_out", "unsqueeze_k8_out", "unsqueeze_k9_out"], + outputs=["concat_k3_out"], + name="Concat_k3", + axis=0, + ) + expand_k_1 = helper.make_node( + "Expand", + inputs=["unsqueeze_k1_out", "where_k_1_out"], + outputs=["expand_k1_out"], + name="expand_k1", + ) + reshape_k_3 = helper.make_node( + "Reshape", + inputs=["expand_k1_out", "concat_k3_out"], + outputs=["reshape_k3_out"], + name="Reshape_k_3", + ) + transpose_k_2_node = helper.make_node( + "Transpose", + inputs=["reshape_k3_out"], + outputs=["k"], + name="Transpose_k_2", + perm=[0, 1, 3, 2], + ) + + k_nodes_for_70b_model = [ + concat_k_node, + shape_k1, + shape_k2, + shape_k3, + shape_k4, + gather_k_1, + gather_k_2, + gather_k_3, + gather_k_4, + unsqueeze_k_1, + unsqueeze_k_2, + unsqueeze_k_3, + 
unsqueeze_k_4, + unsqueeze_k_5, + concat_k_2, + reshape_k_2, + shape_k5, + constant_of_shape_k_1, + mul_k_1, + equal_k_1, + where_k_1, + unsqueeze_k_6, + mul_k_2, + unsqueeze_k_7, + unsqueeze_k_8, + unsqueeze_k_9, + concat_k_3, + expand_k_1, + reshape_k_3, + transpose_k_2_node, + ] + k_nodes.extend(k_nodes_for_70b_model) + return k_nodes + else: + if model_type in {"past", "merged"}: + concat_k_node = helper.make_node( + "Concat", + inputs=["past_key", "k_rope"], + outputs=["present_key"], + axis=2, + ) + k_nodes.append(concat_k_node) + + transpose_k_2_node = helper.make_node( + "Transpose", + inputs=["present_key"], + outputs=["k"], + name="Transpose_k_2", + perm=[0, 1, 3, 2], + ) + return k_nodes + [transpose_k_2_node] # noqa: RUF005 def create_k_path(self, model_type: str): if model_type == "llama2_msft": @@ -505,7 +707,7 @@ def create_v_path(self, model_type: str): if model_type == "no_past": return v_nodes - if model_type in {"past", "merged"}: + if model_type in {"past", "merged", "70b_distributed_merged"}: concat_v_node = helper.make_node( "Concat", inputs=["past_value", "transpose_v_1_out"], @@ -513,7 +715,194 @@ def create_v_path(self, model_type: str): name="Concat_v", axis=2, ) - return v_nodes + [concat_v_node] # noqa: RUF005 + + if model_type != "70b_distributed_merged": + return v_nodes + [concat_v_node] # noqa: RUF005 + + shape_v1 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_1_out"], name="Shape_v1") + shape_v2 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_2_out"], name="Shape_v2") + shape_v3 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_3_out"], name="Shape_v3") + shape_v4 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_4_out"], name="Shape_v4") + gather_v_1 = helper.make_node( + "Gather", + inputs=["shape_1_out", "one"], + outputs=["gather_1_out"], + name="Gather_v1", + axis=0, + ) + gather_v_2 = helper.make_node( + "Gather", + inputs=["shape_2_out", "one"], + outputs=["gather_2_out"], + name="Gather_v2", + axis=0, + ) + gather_v_3 = helper.make_node( + "Gather", + inputs=["shape_3_out", "one"], + outputs=["gather_3_out"], + name="Gather_v3", + axis=0, + ) + gather_v_4 = helper.make_node( + "Gather", + inputs=["shape_4_out", "one"], + outputs=["gather_4_out"], + name="Gather_v4", + axis=0, + ) + unsqueeze_v_1 = helper.make_node( + "Unsqueeze", + inputs=["present_value", "zero"], + outputs=["unsqueeze_v1_out"], + name="Unsqueeze_v1", + ) + unsqueeze_v_2 = helper.make_node( + "Unsqueeze", + inputs=["gather_1_out", "zero"], + outputs=["unsqueeze_v2_out"], + name="Unsqueeze_v2", + ) + unsqueeze_v_3 = helper.make_node( + "Unsqueeze", + inputs=["gather_2_out", "zero"], + outputs=["unsqueeze_v3_out"], + name="Unsqueeze_v3", + ) + unsqueeze_v_4 = helper.make_node( + "Unsqueeze", + inputs=["gather_3_out", "zero"], + outputs=["unsqueeze_v4_out"], + name="Unsqueeze_v4", + ) + unsqueeze_v_5 = helper.make_node( + "Unsqueeze", + inputs=["gather_4_out", "zero"], + outputs=["unsqueeze_v5_out"], + name="Unsqueeze_v5", + ) + concat_v_2 = helper.make_node( + "Concat", + inputs=["unsqueeze_v2_out", "unsqueeze_v3_out", "One", "unsqueeze_v4_out", "unsqueeze_v5_out"], + outputs=["concat_v2_ouot"], + name="Concat_v2", + axis=0, + ) + reshape_v_2 = helper.make_node( + "Reshape", + inputs=["concat_v2_ouot", "One"], + outputs=["reshape_v2_out"], + name="Reshape_v2", + ) + shape_v5 = helper.make_node("Shape", inputs=["reshape_v2_out"], outputs=["shape_5_out"], name="Shape_v5") + constant_of_shape_v_1 = 
helper.make_node( + "ConstantOfShape", + inputs=["shape_5_out"], + outputs=["constant_of_shape_v1_out"], + name="ConstantOfShape_v1", + ) + mul_v_1 = helper.make_node( + "Mul", + inputs=["constant_of_shape_v1_out", "One"], + outputs=["mul_v1_out"], + name="mul_v1", + ) + equal_v_1 = helper.make_node( + "Equal", + inputs=["reshape_v2_out", "mul_v1_out"], + outputs=["equal_v_1_out"], + name="equal_v1", + ) + where_v_1 = helper.make_node( + "Where", + inputs=["equal_v_1_out", "constant_of_shape_v1_out", "reshape_v2_out"], + outputs=["where_v_1_out"], + name="where_v1", + ) + unsqueeze_v_6 = helper.make_node( + "Unsqueeze", + inputs=["gather_1_out", "zero"], + outputs=["unsqueeze_v6_out"], + name="Unsqueeze_v6", + ) + mul_v_2 = helper.make_node( + "Mul", + inputs=["gather_2_out", "One"], + outputs=["mul_v2_out"], + name="mul_v2", + ) + unsqueeze_v_7 = helper.make_node( + "Unsqueeze", + inputs=["mul_v2_out", "zero"], + outputs=["unsqueeze_v7_out"], + name="Unsqueeze_v7", + ) + unsqueeze_v_8 = helper.make_node( + "Unsqueeze", + inputs=["gather_3_out", "zero"], + outputs=["unsqueeze_v8_out"], + name="Unsqueeze_v8", + ) + unsqueeze_v_9 = helper.make_node( + "Unsqueeze", + inputs=["gather_4_out", "zero"], + outputs=["unsqueeze_v9_out"], + name="Unsqueeze_v9", + ) + concat_v_3 = helper.make_node( + "Concat", + inputs=["unsqueeze_v6_out", "unsqueeze_v7_out", "unsqueeze_v8_out", "unsqueeze_v9_out"], + outputs=["concat_v3_out"], + name="Concat_v3", + axis=0, + ) + expand_v_1 = helper.make_node( + "Expand", + inputs=["unsqueeze_v1_out", "where_v_1_out"], + outputs=["expand_v1_out"], + name="expand_v1", + ) + reshape_v_3 = helper.make_node( + "Reshape", + inputs=["expand_v1_out", "concat_v3_out"], + outputs=["reshape_v3_out"], + name="Reshape_v3", + ) + + v_nodes_for_70b_model = [ + concat_v_node, + shape_v1, + shape_v2, + shape_v3, + shape_v4, + gather_v_1, + gather_v_2, + gather_v_3, + gather_v_4, + unsqueeze_v_1, + unsqueeze_v_2, + unsqueeze_v_3, + unsqueeze_v_4, + unsqueeze_v_5, + concat_v_2, + reshape_v_2, + shape_v5, + constant_of_shape_v_1, + mul_v_1, + equal_v_1, + where_v_1, + unsqueeze_v_6, + mul_v_2, + unsqueeze_v_7, + unsqueeze_v_8, + unsqueeze_v_9, + concat_v_3, + expand_v_1, + reshape_v_3, + ] + v_nodes.extend(v_nodes_for_70b_model) + + return v_nodes # Create extra nodes for `position_ids` unsqueeze_v_node = helper.make_node( @@ -672,7 +1061,28 @@ def create_concat_unsqueeze_paths(self, model_type: str, reshape_nodes: List[Nod return extra_nodes - def create_end_nodes(self): + def create_end_nodes(self, model_type): + if model_type == "70b_distributed_merged": + matmul_o_node = helper.make_node( + "MatMul", + inputs=["attn_output", "o_weight"], + outputs=["output_proj"], + name="MatMul_o_proj", + ) + all_reduce = helper.make_node( + "AllReduce", + inputs=["output_proj"], + outputs=["allreduce_proj"], + name="allreduce_proj", + ) + end_node = helper.make_node( + "Add", + inputs=["zero", "allreduce_proj"], + outputs=["output_0"], + name="Add_normalize_node", + ) + return [matmul_o_node, all_reduce, end_node] + matmul_o_node = helper.make_node( "MatMul", inputs=["attn_output", "o_weight"], @@ -711,7 +1121,7 @@ def create_fused_model(self, model_type: str, interleaved: bool, initializers: L num_heads=self.num_heads, ) - end_nodes = self.create_end_nodes() + end_nodes = self.create_end_nodes(model_type) graph = helper.make_graph( nodes=matmul_nodes + rope_nodes + attn_mask_nodes + [mha_node] + end_nodes, @@ -740,7 +1150,7 @@ def create_test_model(self, model_type: str, interleaved: bool, 
initializers: Li reshape_nodes = list(filter(lambda node: node.op_type == "Reshape", q_nodes + k_nodes + v_nodes + qkv_nodes)) extra_nodes = self.create_concat_unsqueeze_paths(model_type, reshape_nodes) - end_nodes = self.create_end_nodes() + end_nodes = self.create_end_nodes(model_type) first_set_of_nodes = matmul_nodes + rope_nodes + q_nodes + k_nodes + attn_mask_nodes second_set_of_nodes = qk_nodes + v_nodes + qkv_nodes + extra_nodes + end_nodes @@ -790,6 +1200,11 @@ def test_hf_decoder_merged_model(self): interleaved = False self.check_models(model_type, interleaved) + def test_hf_70b_distributed_decoder_merged_model(self): + model_type = "70b_distributed_merged" + interleaved = False + self.check_models(model_type, interleaved) + if __name__ == "__main__": unittest.main()
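
The `match_parent_paths_all` helper added to `onnx_model.py` above returns every candidate parent path that matches, whereas the existing `match_parent_paths` stops at the first match. Below is a minimal usage sketch; the toy graph, node names, and the `onnxruntime.transformers.onnx_model` import path are illustrative assumptions for this sketch and are not part of the patch.

    import numpy as np
    from onnx import TensorProto, helper, numpy_helper

    # Import path assumed from the installed onnxruntime wheel; adjust if running
    # directly from the onnxruntime/python/tools/transformers source tree.
    from onnxruntime.transformers.onnx_model import OnnxModel

    # Toy graph: Add's input 0 is reached through Relu <- Mul, and its input 1 comes
    # directly from the same Mul, so two different parent paths end at the Add node.
    mul_node = helper.make_node("Mul", ["x", "w"], ["mul_out"], name="Mul_0")
    relu_node = helper.make_node("Relu", ["mul_out"], ["relu_out"], name="Relu_0")
    add_node = helper.make_node("Add", ["relu_out", "mul_out"], ["y"], name="Add_0")

    graph = helper.make_graph(
        [mul_node, relu_node, add_node],
        "toy_graph",
        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 2])],
        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 2])],
        initializer=[numpy_helper.from_array(np.ones((2, 2), dtype=np.float32), name="w")],
    )
    model = OnnxModel(helper.make_model(graph))

    # Each candidate path is (parent op types, parent input indices), walked upward from the node.
    paths = [
        (["Relu", "Mul"], [0, 0]),  # Add -> Relu (via input 0) -> Mul (via input 0)
        (["Mul"], [1]),             # Add -> Mul (via input 1)
    ]
    match_i, matches, _ = model.match_parent_paths_all(add_node, paths, model.output_name_to_node())
    print(match_i)  # both candidate paths exist in this toy graph, so this prints [0, 1]

A helper of this shape lets a fusion accept several equivalent parent-subgraph layouts in one call (for example, the expanded Shape/Gather/Unsqueeze/Expand K/V repeat pattern exercised by the new 70b_distributed_merged test case) instead of chaining repeated first-match lookups.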