From 4523b3d06c19018e72354f7c7837658bbae01966 Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Thu, 16 Jan 2025 18:54:44 -0800 Subject: [PATCH 01/17] profile init code Signed-off-by: Liqun Fu --- .../contrib_ops/cpu/bert/gqa_attention_base.h | 39 ++++++++++++---- .../cpu/bert/group_query_attention.cc | 46 ++++++++++++++++++- .../test/python/transformers/test_gqa_cpu.py | 6 +++ 3 files changed, 80 insertions(+), 11 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index ccaeb6654e286..a73a1a4d38b8b 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -87,18 +87,36 @@ class GQAAttentionBase { bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data; const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, - sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, - present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); - + { + std::chrono::high_resolution_clock::time_point time_point; + if (profiler_->IsEnabled()) { + time_point = profiler_->Start(); + } + ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, + sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, + present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); + if (profiler_->IsEnabled()) { + std::string eventName = context->GetNodeName() + "_" + "ComputeAttentionProbs"; + profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, time_point); + } + } // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; - ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, - seqlens_k->Data(), - batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, - hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, - is_prompt, tp, allocator); - + { + std::chrono::high_resolution_clock::time_point time_point; + if (profiler_->IsEnabled()) { + time_point = profiler_->Start(); + } + ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, + seqlens_k->Data(), + batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, + hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, + is_prompt, tp, allocator); + if (profiler_->IsEnabled()) { + std::string eventName = context->GetNodeName() + "_" + "ComputeVxAttentionScore"; + profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, time_point); + } + } return Status::OK(); } @@ -123,6 +141,7 @@ class GQAAttentionBase { const bool is_prompt, // whether it is prompt ThreadPool* tp, // thread pool AllocatorPtr allocator) const { // allocator for temporary buffer + const ptrdiff_t packed_batch_stride = packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 8f662cd388c6d..003df4680cf86 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -16,6 +16,32 @@ #include #include +// https://github.com/microsoft/onnxruntime/blob/b9493adbe88c4681fcae71774ec3685d1390bd46/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +#include +#include "core/common/profiler.h" +class ProfilerWrapper { + public: + ProfilerWrapper() { + profiler_ = std::make_unique(); + profiler_->StartProfiling("profile.json"); + } + + ~ProfilerWrapper() { + if (profiler_) { + profiler_->EndProfiling(); + } + } + + onnxruntime::profiling::Profiler* operator->() { + return profiler_.get(); + } + + private: + std::unique_ptr profiler_; +}; + +static ProfilerWrapper profiler_; + using onnxruntime::concurrency::ThreadPool; namespace onnxruntime { @@ -112,6 +138,11 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { T* q_rotary = Q.GetMutable()->MutableData(); T* k_rotary = packed_qkv ? nullptr : K.GetMutable()->MutableData(); if (do_rotary_) { + std::chrono::high_resolution_clock::time_point time_point; + if (profiler_->IsEnabled()) { + time_point = profiler_->Start(); + } + // Initialize rotary parameters rotary_embedding_helper::RotaryParameters rotary_params = {}; rotary_params.batch_size = batch_size; @@ -189,13 +220,26 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { v_input, v_rotary)); } + if (profiler_->IsEnabled()) { + std::string eventName = this->Node().Name() + "_" + "rotary"; + profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, time_point); + } } ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + std::chrono::high_resolution_clock::time_point time_point; + if (profiler_->IsEnabled()) { + time_point = profiler_->Start(); + } // Compute the attention score and apply the score to V - return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get().Data(), + auto ret = ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get().Data(), past_key, past_value, output, present_k, present_v, seqlens_k, parameters, allocator, context); + if (profiler_->IsEnabled()) { + std::string eventName = this->Node().Name() + "_" + "ApplyAttention"; + profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, time_point); + } + return ret; } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 77b4b326bf645..b9a81f0c5dc23 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -714,6 +714,7 @@ def gqa_prompt_func( "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } sess_options = SessionOptions() + sess_options.enable_profiling = True ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) io_binding = ort_session.io_binding() if new_k is not None: @@ -747,6 +748,11 @@ def gqa_prompt_func( ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() ort_output = numpy.array(ort_output) output = torch.tensor(ort_output) + + if sess_options.enable_profiling: + profile_file = ort_session.end_profiling() + print(f"Profiling data saved to: {profile_file}") + return output, present_k, present_v else: ort_inputs = { From 59e276091a7899cd76aa2f3d30694d73e1862bc4 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Thu, 6 Feb 2025 01:52:25 +0000 Subject: [PATCH 02/17] from patch Signed-off-by: liqunfu --- cmake/onnxruntime_mlas.cmake | 7 + onnxruntime/core/mlas/lib/mlasi.h | 1 + onnxruntime/core/mlas/lib/platform.cpp | 1 + .../mlas/lib/rotary_embedding_kernel_avx2.cpp | 30 ++++ .../mlas/lib/rotary_embedding_kernel_avx2.h | 37 ++++ .../lib/rotary_embedding_kernel_avx2_fp32.cpp | 166 ++++++++++++++++++ 6 files changed, 242 insertions(+) create mode 100644 onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp create mode 100644 onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.h create mode 100644 onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index ed3ad89247975..1f0d5d553ff8a 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -177,6 +177,9 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/dgemm.cpp ${mlas_platform_srcs_avx} ${mlas_platform_srcs_avx2} + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2_fp32.cpp ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp @@ -222,6 +225,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm ${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm ) + if(MSVC_VERSION GREATER_EQUAL 1933) target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm @@ -586,6 +590,9 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2_fp32.cpp ) if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE)) set(mlas_platform_srcs_avx2 diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 56fad6bb3412a..978d4f5db0fc4 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1058,6 +1058,7 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; // struct MLAS_ROPE_DISPATCH; extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon; +extern const MLAS_ROPE_DISPATCH MlasRopeDispatchAvx2; // // half gemm dispatch structure diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 026a954bbc6c2..501e2d5f8c8b8 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -401,6 +401,7 @@ Return Value: this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2; this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; + this->RopeDispatch = &MlasRopeDispatchAvx2; // diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp new file mode 100644 index 0000000000000..a1dec92fe384a --- /dev/null +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp @@ -0,0 +1,30 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_avx2.cpp + +Abstract: + + This module implements the rotary embedding kernels for AVX2 supported h/w. + +--*/ + +#include "rotary_embedding.h" +#include "rotary_embedding_kernel_avx2.h" + +// +// Kernel dispatch structure definition. +// +const MLAS_ROPE_DISPATCH MlasRopeDispatchAvx2 = []() { + MLAS_ROPE_DISPATCH d; + +#if defined(MLAS_TARGET_AMD64) + d.SRope = rope_avx2::RopeKernel_Avx2; +#endif + return d; +}(); diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.h b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.h new file mode 100644 index 0000000000000..69448a033f829 --- /dev/null +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.h @@ -0,0 +1,37 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_avx2.h + +Abstract: + + This module includes function declarations and common helper functions for + rotary embedding on for AVX2 enabled h/w. + +--*/ + +#pragma once + + + +#include "mlasi.h" + +namespace rope_avx2 { + +// Rotary embedding kernel for FP32. Embed one hidden state vector. +void +RopeKernel_Avx2( + const float* input, + const float* sin, + const float* cos, + size_t dim, + bool interleaved, + float* output +); + +} // namespace rope_avx2 diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp new file mode 100644 index 0000000000000..555831e6eaf17 --- /dev/null +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp @@ -0,0 +1,166 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_avx2_fp32.cpp + +Abstract: + + This module implements the fp32 rotary embedding kernels using AVX2. + +--*/ + +#include + +#include "rotary_embedding.h" +#include "rotary_embedding_kernel_avx2.h" + +namespace rope_avx2 { + +namespace { + +typedef __m256 float32x8_t; + +template +void +RopeKernel_Avx2_Impl( + const float* input, + const float* sin, + const float* cos, + size_t dim, + float* output +); + +template <> +void +RopeKernel_Avx2_Impl( + const float* input, + const float* sin, + const float* cos, + size_t dim, + float* output +) { + const size_t half_dim = dim >> 1; + size_t i = 0, j = half_dim; + for (; i + 7 < half_dim; i += 8, j += 8) { + float32x8_t real = _mm256_loadu_ps(input + i); + float32x8_t imag = _mm256_loadu_ps(input + j); + float32x8_t sin_val = _mm256_loadu_ps(sin + i); + float32x8_t cos_val = _mm256_loadu_ps(cos + i); + //Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + //Store back into non interleaved format + _mm256_store_ps(output + i, real_out); + _mm256_store_ps(output + j, imag_out); + } + if (half_dim - i != 0) { + size_t rem = half_dim - i; + static const int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0}; + const __m256i mask = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - rem)); + //Use a mask to load the remaining input values + float32x8_t real = _mm256_maskload_ps(input, mask); + float32x8_t imag = _mm256_maskload_ps(input + j, mask); + float32x8_t sin_val = _mm256_maskload_ps(sin + i, mask); + float32x8_t cos_val = _mm256_maskload_ps(cos + i, mask); + //Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + //Store back into non interleaved format + _mm256_maskstore_ps(output + i, mask, real_out); + _mm256_maskstore_ps(output + j, mask, imag_out); + } +} + +template <> +void +RopeKernel_Avx2_Impl( + const float* input, + const float* sin, + const float* cos, + size_t dim, + float* output +) { + size_t i = 0; + for (; i + 15 < dim; i += 16) { + float32x8_t x0 = _mm256_loadu_ps(input + i); + float32x8_t x1 = _mm256_loadu_ps(input + i + 8); + //Load imaginary and real values to seperate non-interleaved vectors + float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000); + float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101); + __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec); + float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec); + float32x8_t sin_val = _mm256_loadu_ps(sin + i); + float32x8_t cos_val = _mm256_loadu_ps(cos + i); + //Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + //Store back into interleaved format + __m256i out_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real_out_s = _mm256_permutevar8x32_ps(real_out, out_mask_vec); + float32x8_t imag_out_s = _mm256_permutevar8x32_ps(imag_out, out_mask_vec); + float32x8_t y0 = _mm256_unpacklo_ps(real_out_s, imag_out_s); + float32x8_t y1 = _mm256_unpackhi_ps(real_out_s, imag_out_s); + _mm256_store_ps(output + i, y0); + _mm256_store_ps(output + i + 8, y1); + } + if (dim - i != 0) { + size_t rem = dim - i; + static const int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0}; + const __m256i mask0 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?8:rem))); + const __m256i mask1 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?(rem-8):0))); + float32x8_t x0 = _mm256_maskload_ps(input + i, mask0); //Load the first set of data using mask + float32x8_t x1 = _mm256_maskload_ps(input + i + 8, mask1); //Load the reminder of data using a second mask + //Load imaginary and real values to seperate non-interleaved vectors + float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000); + float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101); + __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec); + float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec); + float32x8_t sin_val = _mm256_loadu_ps(sin + i); + float32x8_t cos_val = _mm256_loadu_ps(cos + i); + //Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + //Store back into interleaved format + __m256i out_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real_out_s = _mm256_permutevar8x32_ps(real_out, out_mask_vec); + float32x8_t imag_out_s = _mm256_permutevar8x32_ps(imag_out, out_mask_vec); + float32x8_t y0 = _mm256_unpacklo_ps(real_out_s, imag_out_s); + float32x8_t y1 = _mm256_unpackhi_ps(real_out_s, imag_out_s); + _mm256_maskstore_ps(output + i, mask0, y0); + _mm256_maskstore_ps(output + i + 8, mask1, y1); + } +} + +} // namespace + +void +RopeKernel_Avx2( + const float* input, + const float* sin, + const float* cos, + size_t dim, + bool interleaved, + float* output +) { + // real part and imaginary part must be paired + assert(dim % 2 == 0); + const auto* input_impl = reinterpret_cast(input); + const auto* sin_impl = reinterpret_cast(sin); + const auto* cos_impl = reinterpret_cast(cos); + auto* output_impl = reinterpret_cast(output); + + if (interleaved) { + RopeKernel_Avx2_Impl(input_impl, sin_impl, cos_impl, dim, output_impl); + } else { + RopeKernel_Avx2_Impl(input_impl, sin_impl, cos_impl, dim, output_impl); + } +} + +} From 35bf5176fe8bf2e227ffbf0116cca54564d83d6d Mon Sep 17 00:00:00 2001 From: liqunfu Date: Fri, 14 Feb 2025 00:24:17 +0000 Subject: [PATCH 03/17] node_name and remove profiler wrapper Signed-off-by: liqunfu --- .../cpu/bert/group_query_attention.cc | 31 ++++++------------- .../test/python/transformers/test_gqa_cpu.py | 18 ++++++++++- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 003df4680cf86..70143816d1620 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -19,28 +19,8 @@ // https://github.com/microsoft/onnxruntime/blob/b9493adbe88c4681fcae71774ec3685d1390bd46/onnxruntime/core/mlas/lib/sqnbitgemm.cpp #include #include "core/common/profiler.h" -class ProfilerWrapper { - public: - ProfilerWrapper() { - profiler_ = std::make_unique(); - profiler_->StartProfiling("profile.json"); - } - - ~ProfilerWrapper() { - if (profiler_) { - profiler_->EndProfiling(); - } - } - - onnxruntime::profiling::Profiler* operator->() { - return profiler_.get(); - } - - private: - std::unique_ptr profiler_; -}; -static ProfilerWrapper profiler_; +static onnxruntime::profiling::Profiler* profiler_ = nullptr; using onnxruntime::concurrency::ThreadPool; @@ -69,6 +49,13 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) template Status GroupQueryAttention::Compute(OpKernelContext* context) const { + + const std::string node_name = this->Node().Name(); + + // Initialize the profiler_ with a unique log file based on the node name + profiler_ = new onnxruntime::profiling::Profiler(); + profiler_->StartProfiling(node_name + "_log.txt"); + const Tensor* query = context->Input(0); const Tensor* key = context->Input(1); const Tensor* value = context->Input(2); @@ -239,6 +226,8 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { std::string eventName = this->Node().Name() + "_" + "ApplyAttention"; profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, time_point); } + profiler_->EndProfiling(); + delete profiler_; return ret; } } // namespace contrib diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index b9a81f0c5dc23..f42be8ead29ee 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -157,6 +157,7 @@ def create_group_query_attention_graph_prompt( packed=False, softcap=0.0, use_smooth_softmax=False, + node_name=None, ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length @@ -175,7 +176,7 @@ def create_group_query_attention_graph_prompt( "sin_cache" if rotary else "", ], ["output", "present_key", "present_value"], - "GroupQueryAttention_0", + name=node_name, num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, local_window_size=local_window_size, @@ -186,6 +187,7 @@ def create_group_query_attention_graph_prompt( # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", + ), ] @@ -687,6 +689,7 @@ def gqa_prompt_func( rotary_interleaved=False, softcap=0.0, use_smooth_softmax=False, + node_name=None, ): onnx_model_str = create_group_query_attention_graph_prompt( config, @@ -698,6 +701,7 @@ def gqa_prompt_func( packed=new_k is None, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + node_name=node_name, ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None @@ -1072,6 +1076,7 @@ def parity_check_gqa_prompt( packed=False, softcap=0.0, use_smooth_softmax=False, + node_name=None, rtol=RTOL, atol=ATOL, ): @@ -1191,6 +1196,7 @@ def parity_check_gqa_prompt( # Flash function if packed: + node_name = "packed_" packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) out, present_k, present_v = gqa_prompt_func( packed_qkv, @@ -1208,6 +1214,7 @@ def parity_check_gqa_prompt( rotary_interleaved, softcap, use_smooth_softmax=use_smooth_softmax, + node_name=node_name, ) else: out, present_k, present_v = gqa_prompt_func( @@ -1933,6 +1940,14 @@ def test_gqa_no_past(self): for packed in [False, True]: for softcap in [0.0, 50.0]: for use_smooth_softmax in [False, True]: + node_name = ( + ("packed_" if packed else "") + + ("rotary_" if rotary else "") + + ("rotary_interleaved_" if rotary_interleaved else "") + + "softcap_" + str(softcap) + "_" + + "smooth_softmax_" + str(use_smooth_softmax) + "_" + + "b_" + str(b) + "_sq_" + str(sq) + "_skv_" + str(skv) + "_n_" + str(n) + "_n2_" + str(n2) + "_h_" + str(h) + ) config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) past_kv_format = Formats.BNSH all_close = parity_check_gqa_prompt( @@ -1944,6 +1959,7 @@ def test_gqa_no_past(self): packed=packed, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + node_name=node_name, ) self.assertTrue(all_close) all_close = parity_check_gqa_prompt_no_buff( From 3964acc63473db0cb4afa116bbe9828b0c6f36af Mon Sep 17 00:00:00 2001 From: liqunfu Date: Sat, 15 Feb 2025 00:41:21 +0000 Subject: [PATCH 04/17] remove profiling code Signed-off-by: liqunfu --- .../contrib_ops/cpu/bert/gqa_attention_base.h | 16 ---------- .../cpu/bert/group_query_attention.cc | 32 ------------------- 2 files changed, 48 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index ba7a8f0e8b7f5..b6882320e6d6a 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -89,34 +89,18 @@ class GQAAttentionBase { const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; { - std::chrono::high_resolution_clock::time_point time_point; - if (profiler_->IsEnabled()) { - time_point = profiler_->Start(); - } ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); - if (profiler_->IsEnabled()) { - std::string eventName = context->GetNodeName() + "_" + "ComputeAttentionProbs"; - profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, time_point); - } } // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; { - std::chrono::high_resolution_clock::time_point time_point; - if (profiler_->IsEnabled()) { - time_point = profiler_->Start(); - } ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); - if (profiler_->IsEnabled()) { - std::string eventName = context->GetNodeName() + "_" + "ComputeVxAttentionScore"; - profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, time_point); - } } return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 70143816d1620..f09443a0b99bd 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -16,12 +16,6 @@ #include #include -// https://github.com/microsoft/onnxruntime/blob/b9493adbe88c4681fcae71774ec3685d1390bd46/onnxruntime/core/mlas/lib/sqnbitgemm.cpp -#include -#include "core/common/profiler.h" - -static onnxruntime::profiling::Profiler* profiler_ = nullptr; - using onnxruntime::concurrency::ThreadPool; namespace onnxruntime { @@ -49,13 +43,6 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) template Status GroupQueryAttention::Compute(OpKernelContext* context) const { - - const std::string node_name = this->Node().Name(); - - // Initialize the profiler_ with a unique log file based on the node name - profiler_ = new onnxruntime::profiling::Profiler(); - profiler_->StartProfiling(node_name + "_log.txt"); - const Tensor* query = context->Input(0); const Tensor* key = context->Input(1); const Tensor* value = context->Input(2); @@ -125,11 +112,6 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { T* q_rotary = Q.GetMutable()->MutableData(); T* k_rotary = packed_qkv ? nullptr : K.GetMutable()->MutableData(); if (do_rotary_) { - std::chrono::high_resolution_clock::time_point time_point; - if (profiler_->IsEnabled()) { - time_point = profiler_->Start(); - } - // Initialize rotary parameters rotary_embedding_helper::RotaryParameters rotary_params = {}; rotary_params.batch_size = batch_size; @@ -207,27 +189,13 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { v_input, v_rotary)); } - if (profiler_->IsEnabled()) { - std::string eventName = this->Node().Name() + "_" + "rotary"; - profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, time_point); - } } ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - std::chrono::high_resolution_clock::time_point time_point; - if (profiler_->IsEnabled()) { - time_point = profiler_->Start(); - } // Compute the attention score and apply the score to V auto ret = ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get().Data(), past_key, past_value, output, present_k, present_v, seqlens_k, parameters, allocator, context); - if (profiler_->IsEnabled()) { - std::string eventName = this->Node().Name() + "_" + "ApplyAttention"; - profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, time_point); - } - profiler_->EndProfiling(); - delete profiler_; return ret; } } // namespace contrib From 40a6854221582e6e217eafb8b560237390aaa476 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Sat, 15 Feb 2025 01:21:49 +0000 Subject: [PATCH 05/17] undo test_gqa_cpu.py Signed-off-by: liqunfu --- .../test/python/transformers/test_gqa_cpu.py | 24 +------------------ 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index f42be8ead29ee..77b4b326bf645 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -157,7 +157,6 @@ def create_group_query_attention_graph_prompt( packed=False, softcap=0.0, use_smooth_softmax=False, - node_name=None, ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length @@ -176,7 +175,7 @@ def create_group_query_attention_graph_prompt( "sin_cache" if rotary else "", ], ["output", "present_key", "present_value"], - name=node_name, + "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, local_window_size=local_window_size, @@ -187,7 +186,6 @@ def create_group_query_attention_graph_prompt( # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", - ), ] @@ -689,7 +687,6 @@ def gqa_prompt_func( rotary_interleaved=False, softcap=0.0, use_smooth_softmax=False, - node_name=None, ): onnx_model_str = create_group_query_attention_graph_prompt( config, @@ -701,7 +698,6 @@ def gqa_prompt_func( packed=new_k is None, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - node_name=node_name, ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None @@ -718,7 +714,6 @@ def gqa_prompt_func( "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } sess_options = SessionOptions() - sess_options.enable_profiling = True ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) io_binding = ort_session.io_binding() if new_k is not None: @@ -752,11 +747,6 @@ def gqa_prompt_func( ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() ort_output = numpy.array(ort_output) output = torch.tensor(ort_output) - - if sess_options.enable_profiling: - profile_file = ort_session.end_profiling() - print(f"Profiling data saved to: {profile_file}") - return output, present_k, present_v else: ort_inputs = { @@ -1076,7 +1066,6 @@ def parity_check_gqa_prompt( packed=False, softcap=0.0, use_smooth_softmax=False, - node_name=None, rtol=RTOL, atol=ATOL, ): @@ -1196,7 +1185,6 @@ def parity_check_gqa_prompt( # Flash function if packed: - node_name = "packed_" packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) out, present_k, present_v = gqa_prompt_func( packed_qkv, @@ -1214,7 +1202,6 @@ def parity_check_gqa_prompt( rotary_interleaved, softcap, use_smooth_softmax=use_smooth_softmax, - node_name=node_name, ) else: out, present_k, present_v = gqa_prompt_func( @@ -1940,14 +1927,6 @@ def test_gqa_no_past(self): for packed in [False, True]: for softcap in [0.0, 50.0]: for use_smooth_softmax in [False, True]: - node_name = ( - ("packed_" if packed else "") + - ("rotary_" if rotary else "") + - ("rotary_interleaved_" if rotary_interleaved else "") + - "softcap_" + str(softcap) + "_" + - "smooth_softmax_" + str(use_smooth_softmax) + "_" + - "b_" + str(b) + "_sq_" + str(sq) + "_skv_" + str(skv) + "_n_" + str(n) + "_n2_" + str(n2) + "_h_" + str(h) - ) config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) past_kv_format = Formats.BNSH all_close = parity_check_gqa_prompt( @@ -1959,7 +1938,6 @@ def test_gqa_no_past(self): packed=packed, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - node_name=node_name, ) self.assertTrue(all_close) all_close = parity_check_gqa_prompt_no_buff( From 43bdb444687ea00ce4464f9c0005247f91dd2ff0 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Sat, 15 Feb 2025 01:25:26 +0000 Subject: [PATCH 06/17] lint Signed-off-by: liqunfu --- onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index f09443a0b99bd..c85048089cd8e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -194,8 +194,8 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); // Compute the attention score and apply the score to V auto ret = ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get().Data(), - past_key, past_value, output, present_k, present_v, - seqlens_k, parameters, allocator, context); + past_key, past_value, output, present_k, present_v, + seqlens_k, parameters, allocator, context); return ret; } } // namespace contrib From 46353d84d7c480f6249519213da88a992fb13fce Mon Sep 17 00:00:00 2001 From: liqunfu Date: Sat, 15 Feb 2025 02:04:55 +0000 Subject: [PATCH 07/17] some edit Signed-off-by: liqunfu --- .../contrib_ops/cpu/bert/gqa_attention_base.h | 21 +++++++------------ .../cpu/bert/group_query_attention.cc | 7 +++---- .../mlas/lib/rotary_embedding_kernel_avx2.cpp | 5 +---- .../lib/rotary_embedding_kernel_avx2_fp32.cpp | 4 ++-- 4 files changed, 14 insertions(+), 23 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index b6882320e6d6a..5c6ff4c211b28 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -88,20 +88,16 @@ class GQAAttentionBase { bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data; const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; - { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, - sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, - present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); - } + ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, + sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, + present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; - { - ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, - seqlens_k->Data(), - batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, - hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, - is_prompt, tp, allocator); - } + ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, + seqlens_k->Data(), + batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, + hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, + is_prompt, tp, allocator); return Status::OK(); } @@ -126,7 +122,6 @@ class GQAAttentionBase { const bool is_prompt, // whether it is prompt ThreadPool* tp, // thread pool AllocatorPtr allocator) const { // allocator for temporary buffer - const ptrdiff_t packed_batch_stride = packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index c85048089cd8e..8f662cd388c6d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -193,10 +193,9 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); // Compute the attention score and apply the score to V - auto ret = ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get().Data(), - past_key, past_value, output, present_k, present_v, - seqlens_k, parameters, allocator, context); - return ret; + return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get().Data(), + past_key, past_value, output, present_k, present_v, + seqlens_k, parameters, allocator, context); } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp index a1dec92fe384a..7b6c49720853a 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp @@ -22,9 +22,6 @@ Module Name: // const MLAS_ROPE_DISPATCH MlasRopeDispatchAvx2 = []() { MLAS_ROPE_DISPATCH d; - -#if defined(MLAS_TARGET_AMD64) - d.SRope = rope_avx2::RopeKernel_Avx2; -#endif + d.SRope = rope_avx2::RopeKernel_Avx2; return d; }(); diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp index 555831e6eaf17..5dddb3133735b 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp @@ -89,7 +89,7 @@ RopeKernel_Avx2_Impl( for (; i + 15 < dim; i += 16) { float32x8_t x0 = _mm256_loadu_ps(input + i); float32x8_t x1 = _mm256_loadu_ps(input + i + 8); - //Load imaginary and real values to seperate non-interleaved vectors + //Load imaginary and real values to separate non-interleaved vectors float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000); float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101); __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); @@ -116,7 +116,7 @@ RopeKernel_Avx2_Impl( const __m256i mask1 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?(rem-8):0))); float32x8_t x0 = _mm256_maskload_ps(input + i, mask0); //Load the first set of data using mask float32x8_t x1 = _mm256_maskload_ps(input + i + 8, mask1); //Load the reminder of data using a second mask - //Load imaginary and real values to seperate non-interleaved vectors + //Load imaginary and real values to separate non-interleaved vectors float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000); float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101); __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); From e13ac56d8c8ca1d57643e3a9e02153355925bb52 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Sat, 15 Feb 2025 21:44:24 +0000 Subject: [PATCH 08/17] fix data correctness in interleaved cases Signed-off-by: liqunfu --- onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h | 4 +++- .../core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp | 8 ++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 5c6ff4c211b28..abb24e20a6178 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -90,7 +90,8 @@ class GQAAttentionBase { const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, - present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); + present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); + // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, @@ -98,6 +99,7 @@ class GQAAttentionBase { batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); + return Status::OK(); } diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp index 5dddb3133735b..7b32179be2e3d 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp @@ -95,8 +95,8 @@ RopeKernel_Avx2_Impl( __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec); float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec); - float32x8_t sin_val = _mm256_loadu_ps(sin + i); - float32x8_t cos_val = _mm256_loadu_ps(cos + i); + float32x8_t sin_val = _mm256_loadu_ps(sin + i / 2); + float32x8_t cos_val = _mm256_loadu_ps(cos + i / 2); //Compute Real and Imaginary output values float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); @@ -122,8 +122,8 @@ RopeKernel_Avx2_Impl( __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec); float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec); - float32x8_t sin_val = _mm256_loadu_ps(sin + i); - float32x8_t cos_val = _mm256_loadu_ps(cos + i); + float32x8_t sin_val = _mm256_loadu_ps(sin + i / 2); + float32x8_t cos_val = _mm256_loadu_ps(cos + i / 2); //Compute Real and Imaginary output values float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); From 9867ee78a8f231b892c4c6e5686869ee798edd7f Mon Sep 17 00:00:00 2001 From: liqunfu Date: Wed, 19 Feb 2025 17:59:40 +0000 Subject: [PATCH 09/17] one more data correctness fix, add MLAS RoPE test to covert all scenarios, clear interleaved attribute Signed-off-by: liqunfu --- cmake/onnxruntime_mlas.cmake | 1 + .../core/graph/contrib_ops/bert_defs.cc | 4 +- onnxruntime/core/mlas/inc/mlas_attention.h | 45 +++++++ .../core/mlas/lib/rotary_embedding.cpp | 16 ++- .../lib/rotary_embedding_kernel_avx2_fp32.cpp | 2 +- onnxruntime/test/mlas/unittest/test_rope.cpp | 127 ++++++++++++++++++ 6 files changed, 188 insertions(+), 7 deletions(-) create mode 100644 onnxruntime/core/mlas/inc/mlas_attention.h create mode 100644 onnxruntime/test/mlas/unittest/test_rope.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 15864a0198161..4fd317b72b2ee 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -51,6 +51,7 @@ target_sources(onnxruntime_mlas PRIVATE ${MLAS_INC_DIR}/mlas_gemm_postprocessor.h ${MLAS_INC_DIR}/mlas_q4.h ${MLAS_INC_DIR}/mlas_qnbit.h + ${MLAS_INC_DIR}/mlas_attention.h ${MLAS_INC_DIR}/mlas.h ) diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index d8f8d28a47621..ecc8cb091b1b6 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1344,7 +1344,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA( AttributeProto::FLOAT, OPTIONAL_VALUE) .Attr("interleaved", - "Rotate using interleaved pattern. Default value is 0 (False).", + "Indicates whether the input has real and imaginary parts interleaved. " + "Default value is 0 (False), meaning the first half of the input consists of real values " + "and the second half consists of imaginary values.", AttributeProto::INT, OPTIONAL_VALUE) .Attr("rotary_embedding_dim", diff --git a/onnxruntime/core/mlas/inc/mlas_attention.h b/onnxruntime/core/mlas/inc/mlas_attention.h new file mode 100644 index 0000000000000..b1cbd246d24fc --- /dev/null +++ b/onnxruntime/core/mlas/inc/mlas_attention.h @@ -0,0 +1,45 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + mlas_attention.h + +Abstract: + + This module contains the public data structures and procedure prototypes + for attention related ops + + +--*/ + +#pragma once + + +#include "mlas.h" +#include "mlas_gemm_postprocessor.h" + +template +void MLASCALL +MlasRotaryEmbedOneRow_FallBack( + const T* input_data, + const T* sin_data, + const T* cos_data, + size_t rotary_emb_dim, + bool interleaved, + T* output_data +); + +template +void MLASCALL +MlasRotaryEmbedOneRow( + const float* input, + const float* sin, + const float* cos, + size_t dim, + bool interleaved, + float* output +); diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.cpp b/onnxruntime/core/mlas/lib/rotary_embedding.cpp index 1f8f7b240694c..0f14d6b645115 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding.cpp +++ b/onnxruntime/core/mlas/lib/rotary_embedding.cpp @@ -16,8 +16,6 @@ Module Name: #include "rotary_embedding.h" -namespace { - template void MLASCALL @@ -55,9 +53,6 @@ MlasRotaryEmbedOneRow_FallBack( } } -} // namespace - - template <> void MLASCALL @@ -99,3 +94,14 @@ MlasRotaryEmbedOneRow( dispatch->HRope(input, sin, cos, dim, interleaved, output); } + +template <> +void MLASCALL +MlasRotaryEmbedOneRow( + const float* input, + const float* sin, + const float* cos, + size_t dim, + bool interleaved, + float* output +); diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp index 7b32179be2e3d..76be2a23c64a6 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp @@ -63,7 +63,7 @@ RopeKernel_Avx2_Impl( static const int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0}; const __m256i mask = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - rem)); //Use a mask to load the remaining input values - float32x8_t real = _mm256_maskload_ps(input, mask); + float32x8_t real = _mm256_maskload_ps(input + i, mask); float32x8_t imag = _mm256_maskload_ps(input + j, mask); float32x8_t sin_val = _mm256_maskload_ps(sin + i, mask); float32x8_t cos_val = _mm256_maskload_ps(cos + i, mask); diff --git a/onnxruntime/test/mlas/unittest/test_rope.cpp b/onnxruntime/test/mlas/unittest/test_rope.cpp new file mode 100644 index 0000000000000..a365ebb5562a0 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_rope.cpp @@ -0,0 +1,127 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_rope.h + +Abstract: + + Tests for MLAS RoPE. + +--*/ + +#include "test_util.h" +#include "mlas_attention.h" + +class MlasRoPETest : public MlasTestBase { + const float Pi = 2 * std::acos(0.0f); + + public: + void Test(size_t rotary_emb_dim, bool interleaved) { + std::vector input(rotary_emb_dim); + size_t table_len = interleaved ? rotary_emb_dim / 2 : rotary_emb_dim; + std::vector sin(table_len); + std::vector cos(table_len); + std::vector output_ref(rotary_emb_dim), output_impl(rotary_emb_dim); + + for (size_t i = 0; i < rotary_emb_dim; ++i) { + input[i] = static_cast(i + 1); + } + for (size_t i = 0; i < table_len; ++i) { + float theta = (float)i / 1000 * Pi; + sin[i] = std::sin(theta); + cos[i] = std::cos(theta); + } + + // Call the function + MlasRotaryEmbedOneRow_FallBack(&input[0], &sin[0], &cos[0], rotary_emb_dim, interleaved, &output_ref[0]); + MlasRotaryEmbedOneRow(&input[0], &sin[0], &cos[0], rotary_emb_dim, interleaved, &output_impl[0]); + + for (size_t i = 0; i < rotary_emb_dim; i++) { + ASSERT_TRUE(CloseEnough(output_impl[i], output_ref[i])) + << "Expected: " << output_ref[i] << " Actual: " << output_impl[i] << "@[" << i << "], " + << "rotary_emb_dim=" << rotary_emb_dim << ", interleaved=" << interleaved; + } + } + + public: +}; + +// +// Short Execute() test helper to register each test separately by all parameters. +// +class RoPEShortExecuteTest : public MlasTestFixture { + public: + explicit RoPEShortExecuteTest(size_t rotary_emb_dim, bool interleaved) + : rotary_emb_dim_(rotary_emb_dim), + interleaved_(interleaved) {} + + void TestBody() override { + MlasTestFixture::mlas_tester->Test(rotary_emb_dim_, interleaved_); + } + + static size_t RegisterSingleTest(size_t rotary_emb_dim, bool interleaved) { + size_t tests_registered = 0; + + std::stringstream ss; + ss << "/rotary_emb_dim" << rotary_emb_dim << "/interleaved" << interleaved; + auto test_name = ss.str(); + + testing::RegisterTest( + "RoPE", + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. + [=]() -> MlasTestFixture* { + return new RoPEShortExecuteTest(rotary_emb_dim, interleaved); + }); + + tests_registered += 1; + + return tests_registered; + } + + static size_t RegisterShortExecuteTests() { + size_t tests_registered = 0; + tests_registered += RegisterSingleTest(6, false); + tests_registered += RegisterSingleTest(6, true); + tests_registered += RegisterSingleTest(16, false); + tests_registered += RegisterSingleTest(16, true); + tests_registered += RegisterSingleTest(24, false); + tests_registered += RegisterSingleTest(24, true); + tests_registered += RegisterSingleTest(32, false); + tests_registered += RegisterSingleTest(32, true); + tests_registered += RegisterSingleTest(42, false); + tests_registered += RegisterSingleTest(42, true); + tests_registered += RegisterSingleTest(64, false); + tests_registered += RegisterSingleTest(64, true); + tests_registered += RegisterSingleTest(70, false); + tests_registered += RegisterSingleTest(70, true); + return tests_registered; + } + + private: + size_t rotary_emb_dim_; + bool interleaved_; +}; + +static size_t RoPERegisterAllShortExecuteTests() { + return RoPEShortExecuteTest::RegisterShortExecuteTests(); +} + +#ifdef MLAS_TARGET_AMD64 // only test float RoPE with avx2 where RopeDispatch is assigned at this moment. +static UNUSED_VARIABLE bool added_to_main = AddTestRegister( + [](bool is_short_execute) -> size_t { + if (is_short_execute) { + return RoPERegisterAllShortExecuteTests(); + } + return 0; + }); +#endif From 754be929dadaa5453f09d1f278ea4231d2e83dd1 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Wed, 19 Feb 2025 18:12:39 +0000 Subject: [PATCH 10/17] lint Signed-off-by: liqunfu --- onnxruntime/test/mlas/unittest/test_rope.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/mlas/unittest/test_rope.cpp b/onnxruntime/test/mlas/unittest/test_rope.cpp index a365ebb5562a0..ff405148c0070 100644 --- a/onnxruntime/test/mlas/unittest/test_rope.cpp +++ b/onnxruntime/test/mlas/unittest/test_rope.cpp @@ -116,7 +116,8 @@ static size_t RoPERegisterAllShortExecuteTests() { return RoPEShortExecuteTest::RegisterShortExecuteTests(); } -#ifdef MLAS_TARGET_AMD64 // only test float RoPE with avx2 where RopeDispatch is assigned at this moment. +// only test float RoPE with avx2 where RopeDispatch is assigned at this moment. +#ifdef MLAS_TARGET_AMD64 static UNUSED_VARIABLE bool added_to_main = AddTestRegister( [](bool is_short_execute) -> size_t { if (is_short_execute) { From 4dc471b4495cdcb04046fd87fc99997325b41c4d Mon Sep 17 00:00:00 2001 From: liqunfu Date: Wed, 19 Feb 2025 18:43:37 +0000 Subject: [PATCH 11/17] fix build - declaration error Signed-off-by: liqunfu --- onnxruntime/core/mlas/inc/mlas_attention.h | 9 ++++----- onnxruntime/core/mlas/lib/rotary_embedding.cpp | 11 ----------- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/mlas/inc/mlas_attention.h b/onnxruntime/core/mlas/inc/mlas_attention.h index b1cbd246d24fc..54f77e830c700 100644 --- a/onnxruntime/core/mlas/inc/mlas_attention.h +++ b/onnxruntime/core/mlas/inc/mlas_attention.h @@ -20,7 +20,6 @@ Module Name: #include "mlas.h" -#include "mlas_gemm_postprocessor.h" template void MLASCALL @@ -36,10 +35,10 @@ MlasRotaryEmbedOneRow_FallBack( template void MLASCALL MlasRotaryEmbedOneRow( - const float* input, - const float* sin, - const float* cos, + const T* input, + const T* sin, + const T* cos, size_t dim, bool interleaved, - float* output + T* output ); diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.cpp b/onnxruntime/core/mlas/lib/rotary_embedding.cpp index 0f14d6b645115..5cd162597f027 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding.cpp +++ b/onnxruntime/core/mlas/lib/rotary_embedding.cpp @@ -94,14 +94,3 @@ MlasRotaryEmbedOneRow( dispatch->HRope(input, sin, cos, dim, interleaved, output); } - -template <> -void MLASCALL -MlasRotaryEmbedOneRow( - const float* input, - const float* sin, - const float* cos, - size_t dim, - bool interleaved, - float* output -); From 5ecee3c6e19664bee4748bf26e4b2bab5f810608 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Wed, 19 Feb 2025 19:03:31 +0000 Subject: [PATCH 12/17] unused RoPERegisterAllShortExecuteTests ci failure Signed-off-by: liqunfu --- onnxruntime/test/mlas/unittest/test_rope.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/mlas/unittest/test_rope.cpp b/onnxruntime/test/mlas/unittest/test_rope.cpp index ff405148c0070..7c2225ff99087 100644 --- a/onnxruntime/test/mlas/unittest/test_rope.cpp +++ b/onnxruntime/test/mlas/unittest/test_rope.cpp @@ -112,12 +112,12 @@ class RoPEShortExecuteTest : public MlasTestFixture { bool interleaved_; }; +// only test float RoPE with avx2 where RopeDispatch is assigned at this moment. +#ifdef MLAS_TARGET_AMD64 static size_t RoPERegisterAllShortExecuteTests() { return RoPEShortExecuteTest::RegisterShortExecuteTests(); } -// only test float RoPE with avx2 where RopeDispatch is assigned at this moment. -#ifdef MLAS_TARGET_AMD64 static UNUSED_VARIABLE bool added_to_main = AddTestRegister( [](bool is_short_execute) -> size_t { if (is_short_execute) { From 62d5eefb3791146a6d049bdf048b139516629659 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Wed, 19 Feb 2025 20:08:21 +0000 Subject: [PATCH 13/17] missing implementation Signed-off-by: liqunfu --- onnxruntime/core/mlas/lib/rotary_embedding.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.cpp b/onnxruntime/core/mlas/lib/rotary_embedding.cpp index 5cd162597f027..091d5166af2c8 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding.cpp +++ b/onnxruntime/core/mlas/lib/rotary_embedding.cpp @@ -94,3 +94,15 @@ MlasRotaryEmbedOneRow( dispatch->HRope(input, sin, cos, dim, interleaved, output); } + +template +void +MLASCALL +MlasRotaryEmbedOneRow_FallBack( + const float* input_data, + const float* sin_data, + const float* cos_data, + size_t rotary_emb_dim, + bool interleaved, + float* output_data +); From d6de70f96c69692111794ec8cac9e20c6da8537b Mon Sep 17 00:00:00 2001 From: liqunfu Date: Wed, 19 Feb 2025 23:20:09 +0000 Subject: [PATCH 14/17] add benchmark, etc. Signed-off-by: liqunfu --- cmake/onnxruntime_mlas.cmake | 1 - docs/ContribOperators.md | 2 +- docs/OperatorKernels.md | 22 ++++----- onnxruntime/core/mlas/inc/mlas_attention.h | 44 ----------------- onnxruntime/core/mlas/lib/rotary_embedding.h | 11 +++++ onnxruntime/test/mlas/bench/bench_rope.cpp | 52 ++++++++++++++++++++ onnxruntime/test/mlas/unittest/test_rope.cpp | 3 +- 7 files changed, 77 insertions(+), 58 deletions(-) delete mode 100644 onnxruntime/core/mlas/inc/mlas_attention.h create mode 100644 onnxruntime/test/mlas/bench/bench_rope.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 4fd317b72b2ee..15864a0198161 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -51,7 +51,6 @@ target_sources(onnxruntime_mlas PRIVATE ${MLAS_INC_DIR}/mlas_gemm_postprocessor.h ${MLAS_INC_DIR}/mlas_q4.h ${MLAS_INC_DIR}/mlas_qnbit.h - ${MLAS_INC_DIR}/mlas_attention.h ${MLAS_INC_DIR}/mlas.h ) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 571d26bc1903b..274531faaf717 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -5253,7 +5253,7 @@ This version of the operator has been available since version 1 of the 'com.micr
interleaved : int
-
Rotate using interleaved pattern. Default value is 0 (False).
+
Indicates whether the input has real and imaginary parts interleaved. Default value is 0 (False), meaning the first half of the input consists of real values and the second half consists of imaginary values.
is_packed_batching : int
ragged batch inputs or not. Default value is 0
num_heads : int
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index f565c607c4b37..84b9c7c9fc174 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -117,9 +117,9 @@ Do not modify directly.* |Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |Elu|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(float)| |||[6, 21]|**T** = tensor(float)| -|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string)
**T1** = tensor(bool)| -|||[13, 18]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| +|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[13, 18]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||[7, 10]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| |Erf|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float)| |||[9, 12]|**T** = tensor(float)| @@ -157,11 +157,11 @@ Do not modify directly.* |GlobalLpPool|*in* X:**T**
*out* Y:**T**|2+|**T** = tensor(float)| |GlobalMaxPool|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(float)| |||[1, 21]|**T** = tensor(float)| -|Greater|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| +|Greater|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||[7, 8]|**T** = tensor(double), tensor(float)
**T1** = tensor(bool)| -|GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| +|GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|22+|**T1** = tensor(double), tensor(float)
**T2** = tensor(double), tensor(float)| |||[20, 21]|**T1** = tensor(double), tensor(float)
**T2** = tensor(double), tensor(float)| |||[16, 19]|**T1** = tensor(float)
**T2** = tensor(float)| @@ -201,11 +201,11 @@ Do not modify directly.* |||[1, 16]|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float), tensor(float16)
**V** = tensor(double), tensor(float), tensor(float16)| |LeakyRelu|*in* X:**T**
*out* Y:**T**|16+|**T** = tensor(float)| |||[6, 15]|**T** = tensor(float)| -|Less|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| +|Less|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||[7, 8]|**T** = tensor(double), tensor(float)
**T1** = tensor(bool)| -|LessOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| +|LessOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |Log|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| |||[6, 12]|**T** = tensor(double), tensor(float)| |LogSoftmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| diff --git a/onnxruntime/core/mlas/inc/mlas_attention.h b/onnxruntime/core/mlas/inc/mlas_attention.h deleted file mode 100644 index 54f77e830c700..0000000000000 --- a/onnxruntime/core/mlas/inc/mlas_attention.h +++ /dev/null @@ -1,44 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - mlas_attention.h - -Abstract: - - This module contains the public data structures and procedure prototypes - for attention related ops - - ---*/ - -#pragma once - - -#include "mlas.h" - -template -void MLASCALL -MlasRotaryEmbedOneRow_FallBack( - const T* input_data, - const T* sin_data, - const T* cos_data, - size_t rotary_emb_dim, - bool interleaved, - T* output_data -); - -template -void MLASCALL -MlasRotaryEmbedOneRow( - const T* input, - const T* sin, - const T* cos, - size_t dim, - bool interleaved, - T* output -); diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.h b/onnxruntime/core/mlas/lib/rotary_embedding.h index 352dddccf1025..1fcc3dc6af450 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding.h +++ b/onnxruntime/core/mlas/lib/rotary_embedding.h @@ -44,3 +44,14 @@ struct MLAS_ROPE_DISPATCH { HRope_Fn* HRope = nullptr; }; + +template +void MLASCALL +MlasRotaryEmbedOneRow_FallBack( + const T* input_data, + const T* sin_data, + const T* cos_data, + size_t rotary_emb_dim, + bool interleaved, + T* output_data +); diff --git a/onnxruntime/test/mlas/bench/bench_rope.cpp b/onnxruntime/test/mlas/bench/bench_rope.cpp new file mode 100644 index 0000000000000..1958deb50ae0f --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_rope.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "mlas.h" +#include "benchmark/benchmark.h" +#include "bench_util.h" + +void RunRoPEBenchmark(size_t rotary_emb_dim, bool interleaved, benchmark::State& state) { + const float Pi = 2 * std::acos(0.0f); + + std::vector input(rotary_emb_dim); + size_t table_len = interleaved ? rotary_emb_dim / 2 : rotary_emb_dim; + std::vector sin(table_len); + std::vector cos(table_len); + std::vector output_ref(rotary_emb_dim), output_impl(rotary_emb_dim); + + for (size_t i = 0; i < rotary_emb_dim; ++i) { + input[i] = static_cast(i + 1); + } + for (size_t i = 0; i < table_len; ++i) { + float theta = (float)i / 1000 * Pi; + sin[i] = std::sin(theta); + cos[i] = std::cos(theta); + } + + // warm up run + MlasRotaryEmbedOneRow(&input[0], &sin[0], &cos[0], rotary_emb_dim, interleaved, &output_impl[0]); + + for (auto _ : state) { + MlasRotaryEmbedOneRow(&input[0], &sin[0], &cos[0], rotary_emb_dim, interleaved, &output_impl[0]); + } +} + +void RoPE(benchmark::State& state) { + using onnxruntime::narrow; + + const auto rotary_emb_dim = narrow(state.range(0)); + const auto interleaved = narrow(state.range(1)); + + RunRoPEBenchmark(rotary_emb_dim, interleaved, state); +} + +static void RoPEArgs(benchmark::internal::Benchmark* b) { + b->ArgNames({"rotary_emb_dim", "interleaved"}); + + b->ArgsProduct({ + {128, 256, 512, 1024}, // rotary_emb_dim + {int64_t{false}, int64_t{true}}, // interleaved + }); +} + +BENCHMARK(RoPE)->Apply(RoPEArgs)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_rope.cpp b/onnxruntime/test/mlas/unittest/test_rope.cpp index 7c2225ff99087..f9315f1757169 100644 --- a/onnxruntime/test/mlas/unittest/test_rope.cpp +++ b/onnxruntime/test/mlas/unittest/test_rope.cpp @@ -15,7 +15,8 @@ Module Name: --*/ #include "test_util.h" -#include "mlas_attention.h" +#include "mlas.h" +#include "core/mlas/lib/rotary_embedding.h" class MlasRoPETest : public MlasTestBase { const float Pi = 2 * std::acos(0.0f); From 89e1411f5b52f3f096d107ad9bcdd645166087ea Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Wed, 19 Feb 2025 15:26:41 -0800 Subject: [PATCH 15/17] Update onnxruntime/test/mlas/bench/bench_rope.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- onnxruntime/test/mlas/bench/bench_rope.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/mlas/bench/bench_rope.cpp b/onnxruntime/test/mlas/bench/bench_rope.cpp index 1958deb50ae0f..2eea7d455b1a6 100644 --- a/onnxruntime/test/mlas/bench/bench_rope.cpp +++ b/onnxruntime/test/mlas/bench/bench_rope.cpp @@ -44,7 +44,7 @@ static void RoPEArgs(benchmark::internal::Benchmark* b) { b->ArgNames({"rotary_emb_dim", "interleaved"}); b->ArgsProduct({ - {128, 256, 512, 1024}, // rotary_emb_dim + {128, 256, 512, 1024}, // rotary_emb_dim {int64_t{false}, int64_t{true}}, // interleaved }); } From b46927efdab416b29d58a45ed52d25cb48780e6d Mon Sep 17 00:00:00 2001 From: liqunfu Date: Thu, 20 Feb 2025 02:03:46 +0000 Subject: [PATCH 16/17] unaligned store Signed-off-by: liqunfu --- .../core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp index 76be2a23c64a6..dc5a80653d4df 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp @@ -55,8 +55,8 @@ RopeKernel_Avx2_Impl( float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); //Store back into non interleaved format - _mm256_store_ps(output + i, real_out); - _mm256_store_ps(output + j, imag_out); + _mm256_storeu_ps(output + i, real_out); + _mm256_storeu_ps(output + j, imag_out); } if (half_dim - i != 0) { size_t rem = half_dim - i; @@ -106,8 +106,8 @@ RopeKernel_Avx2_Impl( float32x8_t imag_out_s = _mm256_permutevar8x32_ps(imag_out, out_mask_vec); float32x8_t y0 = _mm256_unpacklo_ps(real_out_s, imag_out_s); float32x8_t y1 = _mm256_unpackhi_ps(real_out_s, imag_out_s); - _mm256_store_ps(output + i, y0); - _mm256_store_ps(output + i + 8, y1); + _mm256_storeu_ps(output + i, y0); + _mm256_storeu_ps(output + i + 8, y1); } if (dim - i != 0) { size_t rem = dim - i; From 484d1285768994aba7bff865f82e9c3f7dd964ab Mon Sep 17 00:00:00 2001 From: liqunfu Date: Thu, 20 Feb 2025 17:11:33 +0000 Subject: [PATCH 17/17] sin ->sin_data, constexpr Signed-off-by: liqunfu --- onnxruntime/core/mlas/inc/mlas.h | 4 +- .../core/mlas/lib/rotary_embedding.cpp | 16 ++++---- onnxruntime/core/mlas/lib/rotary_embedding.h | 8 ++-- .../mlas/lib/rotary_embedding_kernel_avx2.h | 4 +- .../lib/rotary_embedding_kernel_avx2_fp32.cpp | 40 +++++++++---------- onnxruntime/test/mlas/bench/bench_rope.cpp | 12 +++--- onnxruntime/test/mlas/unittest/test_rope.cpp | 12 +++--- 7 files changed, 48 insertions(+), 48 deletions(-) diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index f1d33b3bdd66e..1d18354060abb 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1464,8 +1464,8 @@ void MLASCALL MlasRotaryEmbedOneRow( const T* input, - const T* sin, - const T* cos, + const T* sin_data, + const T* cos_data, size_t dim, bool interleaved, T* output diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.cpp b/onnxruntime/core/mlas/lib/rotary_embedding.cpp index 091d5166af2c8..63e0a7fd707f7 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding.cpp +++ b/onnxruntime/core/mlas/lib/rotary_embedding.cpp @@ -58,8 +58,8 @@ void MLASCALL MlasRotaryEmbedOneRow( const float* input, - const float* sin, - const float* cos, + const float* sin_data, + const float* cos_data, size_t dim, bool interleaved, float* output @@ -67,11 +67,11 @@ MlasRotaryEmbedOneRow( const auto* dispatch = GetMlasPlatform().RopeDispatch; if (dispatch == nullptr || dispatch->SRope == nullptr) { - MlasRotaryEmbedOneRow_FallBack(input, sin, cos, dim, interleaved, output); + MlasRotaryEmbedOneRow_FallBack(input, sin_data, cos_data, dim, interleaved, output); return; } - dispatch->SRope(input, sin, cos, dim, interleaved, output); + dispatch->SRope(input, sin_data, cos_data, dim, interleaved, output); } template <> @@ -79,8 +79,8 @@ void MLASCALL MlasRotaryEmbedOneRow( const MLAS_FP16* input, - const MLAS_FP16* sin, - const MLAS_FP16* cos, + const MLAS_FP16* sin_data, + const MLAS_FP16* cos_data, size_t dim, bool interleaved, MLAS_FP16* output @@ -88,11 +88,11 @@ MlasRotaryEmbedOneRow( const auto* dispatch = GetMlasPlatform().RopeDispatch; if (dispatch == nullptr || dispatch->HRope == nullptr) { - MlasRotaryEmbedOneRow_FallBack(input, sin, cos, dim, interleaved, output); + MlasRotaryEmbedOneRow_FallBack(input, sin_data, cos_data, dim, interleaved, output); return; } - dispatch->HRope(input, sin, cos, dim, interleaved, output); + dispatch->HRope(input, sin_data, cos_data, dim, interleaved, output); } template diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.h b/onnxruntime/core/mlas/lib/rotary_embedding.h index 1fcc3dc6af450..c017ece810d0f 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding.h +++ b/onnxruntime/core/mlas/lib/rotary_embedding.h @@ -23,8 +23,8 @@ struct MLAS_ROPE_DISPATCH { // rotary embedding kernel for fp32 typedef void(SRope_Fn)( const float* input, - const float* sin, - const float* cos, + const float* sin_data, + const float* cos_data, size_t dim, bool interleaved, float* output @@ -35,8 +35,8 @@ struct MLAS_ROPE_DISPATCH { // rotary embedding kernel for fp16 typedef void(HRope_Fn)( const MLAS_FP16* input, - const MLAS_FP16* sin, - const MLAS_FP16* cos, + const MLAS_FP16* sin_data, + const MLAS_FP16* cos_data, size_t dim, bool interleaved, MLAS_FP16* output diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.h b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.h index 69448a033f829..18a2e11998644 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.h +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.h @@ -27,8 +27,8 @@ namespace rope_avx2 { void RopeKernel_Avx2( const float* input, - const float* sin, - const float* cos, + const float* sin_data, + const float* cos_data, size_t dim, bool interleaved, float* output diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp index dc5a80653d4df..7124b82606978 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp @@ -29,8 +29,8 @@ template void RopeKernel_Avx2_Impl( const float* input, - const float* sin, - const float* cos, + const float* sin_data, + const float* cos_data, size_t dim, float* output ); @@ -39,8 +39,8 @@ template <> void RopeKernel_Avx2_Impl( const float* input, - const float* sin, - const float* cos, + const float* sin_data, + const float* cos_data, size_t dim, float* output ) { @@ -49,8 +49,8 @@ RopeKernel_Avx2_Impl( for (; i + 7 < half_dim; i += 8, j += 8) { float32x8_t real = _mm256_loadu_ps(input + i); float32x8_t imag = _mm256_loadu_ps(input + j); - float32x8_t sin_val = _mm256_loadu_ps(sin + i); - float32x8_t cos_val = _mm256_loadu_ps(cos + i); + float32x8_t sin_val = _mm256_loadu_ps(sin_data + i); + float32x8_t cos_val = _mm256_loadu_ps(cos_data + i); //Compute Real and Imaginary output values float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); @@ -60,13 +60,13 @@ RopeKernel_Avx2_Impl( } if (half_dim - i != 0) { size_t rem = half_dim - i; - static const int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0}; + static constexpr int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0}; const __m256i mask = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - rem)); //Use a mask to load the remaining input values float32x8_t real = _mm256_maskload_ps(input + i, mask); float32x8_t imag = _mm256_maskload_ps(input + j, mask); - float32x8_t sin_val = _mm256_maskload_ps(sin + i, mask); - float32x8_t cos_val = _mm256_maskload_ps(cos + i, mask); + float32x8_t sin_val = _mm256_maskload_ps(sin_data + i, mask); + float32x8_t cos_val = _mm256_maskload_ps(cos_data + i, mask); //Compute Real and Imaginary output values float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); @@ -80,8 +80,8 @@ template <> void RopeKernel_Avx2_Impl( const float* input, - const float* sin, - const float* cos, + const float* sin_data, + const float* cos_data, size_t dim, float* output ) { @@ -95,8 +95,8 @@ RopeKernel_Avx2_Impl( __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec); float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec); - float32x8_t sin_val = _mm256_loadu_ps(sin + i / 2); - float32x8_t cos_val = _mm256_loadu_ps(cos + i / 2); + float32x8_t sin_val = _mm256_loadu_ps(sin_data + i / 2); + float32x8_t cos_val = _mm256_loadu_ps(cos_data + i / 2); //Compute Real and Imaginary output values float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); @@ -111,7 +111,7 @@ RopeKernel_Avx2_Impl( } if (dim - i != 0) { size_t rem = dim - i; - static const int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0}; + static constexpr int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0}; const __m256i mask0 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?8:rem))); const __m256i mask1 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?(rem-8):0))); float32x8_t x0 = _mm256_maskload_ps(input + i, mask0); //Load the first set of data using mask @@ -122,8 +122,8 @@ RopeKernel_Avx2_Impl( __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec); float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec); - float32x8_t sin_val = _mm256_loadu_ps(sin + i / 2); - float32x8_t cos_val = _mm256_loadu_ps(cos + i / 2); + float32x8_t sin_val = _mm256_loadu_ps(sin_data+ i / 2); + float32x8_t cos_val = _mm256_loadu_ps(cos_data + i / 2); //Compute Real and Imaginary output values float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); @@ -143,8 +143,8 @@ RopeKernel_Avx2_Impl( void RopeKernel_Avx2( const float* input, - const float* sin, - const float* cos, + const float* sin_data, + const float* cos_data, size_t dim, bool interleaved, float* output @@ -152,8 +152,8 @@ RopeKernel_Avx2( // real part and imaginary part must be paired assert(dim % 2 == 0); const auto* input_impl = reinterpret_cast(input); - const auto* sin_impl = reinterpret_cast(sin); - const auto* cos_impl = reinterpret_cast(cos); + const auto* sin_impl = reinterpret_cast(sin_data); + const auto* cos_impl = reinterpret_cast(cos_data); auto* output_impl = reinterpret_cast(output); if (interleaved) { diff --git a/onnxruntime/test/mlas/bench/bench_rope.cpp b/onnxruntime/test/mlas/bench/bench_rope.cpp index 2eea7d455b1a6..9103ba6424f65 100644 --- a/onnxruntime/test/mlas/bench/bench_rope.cpp +++ b/onnxruntime/test/mlas/bench/bench_rope.cpp @@ -10,8 +10,8 @@ void RunRoPEBenchmark(size_t rotary_emb_dim, bool interleaved, benchmark::State& std::vector input(rotary_emb_dim); size_t table_len = interleaved ? rotary_emb_dim / 2 : rotary_emb_dim; - std::vector sin(table_len); - std::vector cos(table_len); + std::vector sin_data(table_len); + std::vector cos_data(table_len); std::vector output_ref(rotary_emb_dim), output_impl(rotary_emb_dim); for (size_t i = 0; i < rotary_emb_dim; ++i) { @@ -19,15 +19,15 @@ void RunRoPEBenchmark(size_t rotary_emb_dim, bool interleaved, benchmark::State& } for (size_t i = 0; i < table_len; ++i) { float theta = (float)i / 1000 * Pi; - sin[i] = std::sin(theta); - cos[i] = std::cos(theta); + sin_data[i] = std::sin(theta); + cos_data[i] = std::cos(theta); } // warm up run - MlasRotaryEmbedOneRow(&input[0], &sin[0], &cos[0], rotary_emb_dim, interleaved, &output_impl[0]); + MlasRotaryEmbedOneRow(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_impl[0]); for (auto _ : state) { - MlasRotaryEmbedOneRow(&input[0], &sin[0], &cos[0], rotary_emb_dim, interleaved, &output_impl[0]); + MlasRotaryEmbedOneRow(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_impl[0]); } } diff --git a/onnxruntime/test/mlas/unittest/test_rope.cpp b/onnxruntime/test/mlas/unittest/test_rope.cpp index f9315f1757169..54087a933fd9e 100644 --- a/onnxruntime/test/mlas/unittest/test_rope.cpp +++ b/onnxruntime/test/mlas/unittest/test_rope.cpp @@ -25,8 +25,8 @@ class MlasRoPETest : public MlasTestBase { void Test(size_t rotary_emb_dim, bool interleaved) { std::vector input(rotary_emb_dim); size_t table_len = interleaved ? rotary_emb_dim / 2 : rotary_emb_dim; - std::vector sin(table_len); - std::vector cos(table_len); + std::vector sin_data(table_len); + std::vector cos_data(table_len); std::vector output_ref(rotary_emb_dim), output_impl(rotary_emb_dim); for (size_t i = 0; i < rotary_emb_dim; ++i) { @@ -34,13 +34,13 @@ class MlasRoPETest : public MlasTestBase { } for (size_t i = 0; i < table_len; ++i) { float theta = (float)i / 1000 * Pi; - sin[i] = std::sin(theta); - cos[i] = std::cos(theta); + sin_data[i] = std::sin(theta); + cos_data[i] = std::cos(theta); } // Call the function - MlasRotaryEmbedOneRow_FallBack(&input[0], &sin[0], &cos[0], rotary_emb_dim, interleaved, &output_ref[0]); - MlasRotaryEmbedOneRow(&input[0], &sin[0], &cos[0], rotary_emb_dim, interleaved, &output_impl[0]); + MlasRotaryEmbedOneRow_FallBack(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_ref[0]); + MlasRotaryEmbedOneRow(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_impl[0]); for (size_t i = 0; i < rotary_emb_dim; i++) { ASSERT_TRUE(CloseEnough(output_impl[i], output_ref[i]))