diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index cb8d3dcad511f..15864a0198161 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -181,6 +181,9 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/dgemm.cpp ${mlas_platform_srcs_avx} ${mlas_platform_srcs_avx2} + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2_fp32.cpp ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp @@ -226,6 +229,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm ${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm ) + if(MSVC_VERSION GREATER_EQUAL 1933) target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm @@ -594,6 +598,9 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2_fp32.cpp ) if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE)) set(mlas_platform_srcs_avx2 diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 571d26bc1903b..274531faaf717 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -5253,7 +5253,7 @@ This version of the operator has been available since version 1 of the 'com.micr
interleaved : int
-
Rotate using interleaved pattern. Default value is 0 (False).
+
Indicates whether the input has real and imaginary parts interleaved. Default value is 0 (False), meaning the first half of the input consists of real values and the second half consists of imaginary values.
is_packed_batching : int
ragged batch inputs or not. Default value is 0
num_heads : int
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index f565c607c4b37..84b9c7c9fc174 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -117,9 +117,9 @@ Do not modify directly.* |Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |Elu|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(float)| |||[6, 21]|**T** = tensor(float)| -|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string)
**T1** = tensor(bool)| -|||[13, 18]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| +|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[13, 18]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||[7, 10]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| |Erf|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float)| |||[9, 12]|**T** = tensor(float)| @@ -157,11 +157,11 @@ Do not modify directly.* |GlobalLpPool|*in* X:**T**
*out* Y:**T**|2+|**T** = tensor(float)| |GlobalMaxPool|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(float)| |||[1, 21]|**T** = tensor(float)| -|Greater|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| +|Greater|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||[7, 8]|**T** = tensor(double), tensor(float)
**T1** = tensor(bool)| -|GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| +|GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|22+|**T1** = tensor(double), tensor(float)
**T2** = tensor(double), tensor(float)| |||[20, 21]|**T1** = tensor(double), tensor(float)
**T2** = tensor(double), tensor(float)| |||[16, 19]|**T1** = tensor(float)
**T2** = tensor(float)| @@ -201,11 +201,11 @@ Do not modify directly.* |||[1, 16]|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float), tensor(float16)
**V** = tensor(double), tensor(float), tensor(float16)| |LeakyRelu|*in* X:**T**
*out* Y:**T**|16+|**T** = tensor(float)| |||[6, 15]|**T** = tensor(float)| -|Less|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| +|Less|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||[7, 8]|**T** = tensor(double), tensor(float)
**T1** = tensor(bool)| -|LessOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| +|LessOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |Log|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| |||[6, 12]|**T** = tensor(double), tensor(float)| |LogSoftmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index d8f8d28a47621..ecc8cb091b1b6 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1344,7 +1344,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA( AttributeProto::FLOAT, OPTIONAL_VALUE) .Attr("interleaved", - "Rotate using interleaved pattern. Default value is 0 (False).", + "Indicates whether the input has real and imaginary parts interleaved. " + "Default value is 0 (False), meaning the first half of the input consists of real values " + "and the second half consists of imaginary values.", AttributeProto::INT, OPTIONAL_VALUE) .Attr("rotary_embedding_dim", diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index f1d33b3bdd66e..1d18354060abb 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1464,8 +1464,8 @@ void MLASCALL MlasRotaryEmbedOneRow( const T* input, - const T* sin, - const T* cos, + const T* sin_data, + const T* cos_data, size_t dim, bool interleaved, T* output diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index decf44708de32..0dae9f6ccaee1 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1058,6 +1058,7 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; // struct MLAS_ROPE_DISPATCH; extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon; +extern const MLAS_ROPE_DISPATCH MlasRopeDispatchAvx2; // // half gemm dispatch structure diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 9db3115c1ad20..582c1ab944b98 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -401,6 +401,7 @@ Return Value: this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2; this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; + this->RopeDispatch = &MlasRopeDispatchAvx2; // diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.cpp b/onnxruntime/core/mlas/lib/rotary_embedding.cpp index 1f8f7b240694c..63e0a7fd707f7 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding.cpp +++ b/onnxruntime/core/mlas/lib/rotary_embedding.cpp @@ -16,8 +16,6 @@ Module Name: #include "rotary_embedding.h" -namespace { - template void MLASCALL @@ -55,16 +53,13 @@ MlasRotaryEmbedOneRow_FallBack( } } -} // namespace - - template <> void MLASCALL MlasRotaryEmbedOneRow( const float* input, - const float* sin, - const float* cos, + const float* sin_data, + const float* cos_data, size_t dim, bool interleaved, float* output @@ -72,11 +67,11 @@ MlasRotaryEmbedOneRow( const auto* dispatch = GetMlasPlatform().RopeDispatch; if (dispatch == nullptr || dispatch->SRope == nullptr) { - MlasRotaryEmbedOneRow_FallBack(input, sin, cos, dim, interleaved, output); + MlasRotaryEmbedOneRow_FallBack(input, sin_data, cos_data, dim, interleaved, output); return; } - dispatch->SRope(input, sin, cos, dim, interleaved, output); + dispatch->SRope(input, sin_data, cos_data, dim, interleaved, output); } template <> @@ -84,8 +79,8 @@ void MLASCALL MlasRotaryEmbedOneRow( const MLAS_FP16* input, - const MLAS_FP16* sin, - const MLAS_FP16* cos, + const MLAS_FP16* sin_data, + const MLAS_FP16* cos_data, size_t dim, bool interleaved, MLAS_FP16* output @@ -93,9 +88,21 @@ MlasRotaryEmbedOneRow( const auto* dispatch = GetMlasPlatform().RopeDispatch; if (dispatch == nullptr || dispatch->HRope == nullptr) { - MlasRotaryEmbedOneRow_FallBack(input, sin, cos, dim, interleaved, output); + MlasRotaryEmbedOneRow_FallBack(input, sin_data, cos_data, dim, interleaved, output); return; } - dispatch->HRope(input, sin, cos, dim, interleaved, output); + dispatch->HRope(input, sin_data, cos_data, dim, interleaved, output); } + +template +void +MLASCALL +MlasRotaryEmbedOneRow_FallBack( + const float* input_data, + const float* sin_data, + const float* cos_data, + size_t rotary_emb_dim, + bool interleaved, + float* output_data +); diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.h b/onnxruntime/core/mlas/lib/rotary_embedding.h index 352dddccf1025..c017ece810d0f 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding.h +++ b/onnxruntime/core/mlas/lib/rotary_embedding.h @@ -23,8 +23,8 @@ struct MLAS_ROPE_DISPATCH { // rotary embedding kernel for fp32 typedef void(SRope_Fn)( const float* input, - const float* sin, - const float* cos, + const float* sin_data, + const float* cos_data, size_t dim, bool interleaved, float* output @@ -35,8 +35,8 @@ struct MLAS_ROPE_DISPATCH { // rotary embedding kernel for fp16 typedef void(HRope_Fn)( const MLAS_FP16* input, - const MLAS_FP16* sin, - const MLAS_FP16* cos, + const MLAS_FP16* sin_data, + const MLAS_FP16* cos_data, size_t dim, bool interleaved, MLAS_FP16* output @@ -44,3 +44,14 @@ struct MLAS_ROPE_DISPATCH { HRope_Fn* HRope = nullptr; }; + +template +void MLASCALL +MlasRotaryEmbedOneRow_FallBack( + const T* input_data, + const T* sin_data, + const T* cos_data, + size_t rotary_emb_dim, + bool interleaved, + T* output_data +); diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp new file mode 100644 index 0000000000000..7b6c49720853a --- /dev/null +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp @@ -0,0 +1,27 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_avx2.cpp + +Abstract: + + This module implements the rotary embedding kernels for AVX2 supported h/w. + +--*/ + +#include "rotary_embedding.h" +#include "rotary_embedding_kernel_avx2.h" + +// +// Kernel dispatch structure definition. +// +const MLAS_ROPE_DISPATCH MlasRopeDispatchAvx2 = []() { + MLAS_ROPE_DISPATCH d; + d.SRope = rope_avx2::RopeKernel_Avx2; + return d; +}(); diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.h b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.h new file mode 100644 index 0000000000000..18a2e11998644 --- /dev/null +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.h @@ -0,0 +1,37 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_avx2.h + +Abstract: + + This module includes function declarations and common helper functions for + rotary embedding on for AVX2 enabled h/w. + +--*/ + +#pragma once + + + +#include "mlasi.h" + +namespace rope_avx2 { + +// Rotary embedding kernel for FP32. Embed one hidden state vector. +void +RopeKernel_Avx2( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + bool interleaved, + float* output +); + +} // namespace rope_avx2 diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp new file mode 100644 index 0000000000000..7124b82606978 --- /dev/null +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp @@ -0,0 +1,166 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_avx2_fp32.cpp + +Abstract: + + This module implements the fp32 rotary embedding kernels using AVX2. + +--*/ + +#include + +#include "rotary_embedding.h" +#include "rotary_embedding_kernel_avx2.h" + +namespace rope_avx2 { + +namespace { + +typedef __m256 float32x8_t; + +template +void +RopeKernel_Avx2_Impl( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + float* output +); + +template <> +void +RopeKernel_Avx2_Impl( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + float* output +) { + const size_t half_dim = dim >> 1; + size_t i = 0, j = half_dim; + for (; i + 7 < half_dim; i += 8, j += 8) { + float32x8_t real = _mm256_loadu_ps(input + i); + float32x8_t imag = _mm256_loadu_ps(input + j); + float32x8_t sin_val = _mm256_loadu_ps(sin_data + i); + float32x8_t cos_val = _mm256_loadu_ps(cos_data + i); + //Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + //Store back into non interleaved format + _mm256_storeu_ps(output + i, real_out); + _mm256_storeu_ps(output + j, imag_out); + } + if (half_dim - i != 0) { + size_t rem = half_dim - i; + static constexpr int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0}; + const __m256i mask = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - rem)); + //Use a mask to load the remaining input values + float32x8_t real = _mm256_maskload_ps(input + i, mask); + float32x8_t imag = _mm256_maskload_ps(input + j, mask); + float32x8_t sin_val = _mm256_maskload_ps(sin_data + i, mask); + float32x8_t cos_val = _mm256_maskload_ps(cos_data + i, mask); + //Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + //Store back into non interleaved format + _mm256_maskstore_ps(output + i, mask, real_out); + _mm256_maskstore_ps(output + j, mask, imag_out); + } +} + +template <> +void +RopeKernel_Avx2_Impl( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + float* output +) { + size_t i = 0; + for (; i + 15 < dim; i += 16) { + float32x8_t x0 = _mm256_loadu_ps(input + i); + float32x8_t x1 = _mm256_loadu_ps(input + i + 8); + //Load imaginary and real values to separate non-interleaved vectors + float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000); + float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101); + __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec); + float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec); + float32x8_t sin_val = _mm256_loadu_ps(sin_data + i / 2); + float32x8_t cos_val = _mm256_loadu_ps(cos_data + i / 2); + //Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + //Store back into interleaved format + __m256i out_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real_out_s = _mm256_permutevar8x32_ps(real_out, out_mask_vec); + float32x8_t imag_out_s = _mm256_permutevar8x32_ps(imag_out, out_mask_vec); + float32x8_t y0 = _mm256_unpacklo_ps(real_out_s, imag_out_s); + float32x8_t y1 = _mm256_unpackhi_ps(real_out_s, imag_out_s); + _mm256_storeu_ps(output + i, y0); + _mm256_storeu_ps(output + i + 8, y1); + } + if (dim - i != 0) { + size_t rem = dim - i; + static constexpr int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0}; + const __m256i mask0 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?8:rem))); + const __m256i mask1 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?(rem-8):0))); + float32x8_t x0 = _mm256_maskload_ps(input + i, mask0); //Load the first set of data using mask + float32x8_t x1 = _mm256_maskload_ps(input + i + 8, mask1); //Load the reminder of data using a second mask + //Load imaginary and real values to separate non-interleaved vectors + float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000); + float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101); + __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec); + float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec); + float32x8_t sin_val = _mm256_loadu_ps(sin_data+ i / 2); + float32x8_t cos_val = _mm256_loadu_ps(cos_data + i / 2); + //Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + //Store back into interleaved format + __m256i out_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real_out_s = _mm256_permutevar8x32_ps(real_out, out_mask_vec); + float32x8_t imag_out_s = _mm256_permutevar8x32_ps(imag_out, out_mask_vec); + float32x8_t y0 = _mm256_unpacklo_ps(real_out_s, imag_out_s); + float32x8_t y1 = _mm256_unpackhi_ps(real_out_s, imag_out_s); + _mm256_maskstore_ps(output + i, mask0, y0); + _mm256_maskstore_ps(output + i + 8, mask1, y1); + } +} + +} // namespace + +void +RopeKernel_Avx2( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + bool interleaved, + float* output +) { + // real part and imaginary part must be paired + assert(dim % 2 == 0); + const auto* input_impl = reinterpret_cast(input); + const auto* sin_impl = reinterpret_cast(sin_data); + const auto* cos_impl = reinterpret_cast(cos_data); + auto* output_impl = reinterpret_cast(output); + + if (interleaved) { + RopeKernel_Avx2_Impl(input_impl, sin_impl, cos_impl, dim, output_impl); + } else { + RopeKernel_Avx2_Impl(input_impl, sin_impl, cos_impl, dim, output_impl); + } +} + +} diff --git a/onnxruntime/test/mlas/bench/bench_rope.cpp b/onnxruntime/test/mlas/bench/bench_rope.cpp new file mode 100644 index 0000000000000..9103ba6424f65 --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_rope.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "mlas.h" +#include "benchmark/benchmark.h" +#include "bench_util.h" + +void RunRoPEBenchmark(size_t rotary_emb_dim, bool interleaved, benchmark::State& state) { + const float Pi = 2 * std::acos(0.0f); + + std::vector input(rotary_emb_dim); + size_t table_len = interleaved ? rotary_emb_dim / 2 : rotary_emb_dim; + std::vector sin_data(table_len); + std::vector cos_data(table_len); + std::vector output_ref(rotary_emb_dim), output_impl(rotary_emb_dim); + + for (size_t i = 0; i < rotary_emb_dim; ++i) { + input[i] = static_cast(i + 1); + } + for (size_t i = 0; i < table_len; ++i) { + float theta = (float)i / 1000 * Pi; + sin_data[i] = std::sin(theta); + cos_data[i] = std::cos(theta); + } + + // warm up run + MlasRotaryEmbedOneRow(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_impl[0]); + + for (auto _ : state) { + MlasRotaryEmbedOneRow(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_impl[0]); + } +} + +void RoPE(benchmark::State& state) { + using onnxruntime::narrow; + + const auto rotary_emb_dim = narrow(state.range(0)); + const auto interleaved = narrow(state.range(1)); + + RunRoPEBenchmark(rotary_emb_dim, interleaved, state); +} + +static void RoPEArgs(benchmark::internal::Benchmark* b) { + b->ArgNames({"rotary_emb_dim", "interleaved"}); + + b->ArgsProduct({ + {128, 256, 512, 1024}, // rotary_emb_dim + {int64_t{false}, int64_t{true}}, // interleaved + }); +} + +BENCHMARK(RoPE)->Apply(RoPEArgs)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_rope.cpp b/onnxruntime/test/mlas/unittest/test_rope.cpp new file mode 100644 index 0000000000000..54087a933fd9e --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_rope.cpp @@ -0,0 +1,129 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_rope.h + +Abstract: + + Tests for MLAS RoPE. + +--*/ + +#include "test_util.h" +#include "mlas.h" +#include "core/mlas/lib/rotary_embedding.h" + +class MlasRoPETest : public MlasTestBase { + const float Pi = 2 * std::acos(0.0f); + + public: + void Test(size_t rotary_emb_dim, bool interleaved) { + std::vector input(rotary_emb_dim); + size_t table_len = interleaved ? rotary_emb_dim / 2 : rotary_emb_dim; + std::vector sin_data(table_len); + std::vector cos_data(table_len); + std::vector output_ref(rotary_emb_dim), output_impl(rotary_emb_dim); + + for (size_t i = 0; i < rotary_emb_dim; ++i) { + input[i] = static_cast(i + 1); + } + for (size_t i = 0; i < table_len; ++i) { + float theta = (float)i / 1000 * Pi; + sin_data[i] = std::sin(theta); + cos_data[i] = std::cos(theta); + } + + // Call the function + MlasRotaryEmbedOneRow_FallBack(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_ref[0]); + MlasRotaryEmbedOneRow(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_impl[0]); + + for (size_t i = 0; i < rotary_emb_dim; i++) { + ASSERT_TRUE(CloseEnough(output_impl[i], output_ref[i])) + << "Expected: " << output_ref[i] << " Actual: " << output_impl[i] << "@[" << i << "], " + << "rotary_emb_dim=" << rotary_emb_dim << ", interleaved=" << interleaved; + } + } + + public: +}; + +// +// Short Execute() test helper to register each test separately by all parameters. +// +class RoPEShortExecuteTest : public MlasTestFixture { + public: + explicit RoPEShortExecuteTest(size_t rotary_emb_dim, bool interleaved) + : rotary_emb_dim_(rotary_emb_dim), + interleaved_(interleaved) {} + + void TestBody() override { + MlasTestFixture::mlas_tester->Test(rotary_emb_dim_, interleaved_); + } + + static size_t RegisterSingleTest(size_t rotary_emb_dim, bool interleaved) { + size_t tests_registered = 0; + + std::stringstream ss; + ss << "/rotary_emb_dim" << rotary_emb_dim << "/interleaved" << interleaved; + auto test_name = ss.str(); + + testing::RegisterTest( + "RoPE", + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. + [=]() -> MlasTestFixture* { + return new RoPEShortExecuteTest(rotary_emb_dim, interleaved); + }); + + tests_registered += 1; + + return tests_registered; + } + + static size_t RegisterShortExecuteTests() { + size_t tests_registered = 0; + tests_registered += RegisterSingleTest(6, false); + tests_registered += RegisterSingleTest(6, true); + tests_registered += RegisterSingleTest(16, false); + tests_registered += RegisterSingleTest(16, true); + tests_registered += RegisterSingleTest(24, false); + tests_registered += RegisterSingleTest(24, true); + tests_registered += RegisterSingleTest(32, false); + tests_registered += RegisterSingleTest(32, true); + tests_registered += RegisterSingleTest(42, false); + tests_registered += RegisterSingleTest(42, true); + tests_registered += RegisterSingleTest(64, false); + tests_registered += RegisterSingleTest(64, true); + tests_registered += RegisterSingleTest(70, false); + tests_registered += RegisterSingleTest(70, true); + return tests_registered; + } + + private: + size_t rotary_emb_dim_; + bool interleaved_; +}; + +// only test float RoPE with avx2 where RopeDispatch is assigned at this moment. +#ifdef MLAS_TARGET_AMD64 +static size_t RoPERegisterAllShortExecuteTests() { + return RoPEShortExecuteTest::RegisterShortExecuteTests(); +} + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister( + [](bool is_short_execute) -> size_t { + if (is_short_execute) { + return RoPERegisterAllShortExecuteTests(); + } + return 0; + }); +#endif