diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index cb8d3dcad511f..15864a0198161 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -181,6 +181,9 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/dgemm.cpp
${mlas_platform_srcs_avx}
${mlas_platform_srcs_avx2}
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2_fp32.cpp
${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp
${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
@@ -226,6 +229,7 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm
${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm
)
+
if(MSVC_VERSION GREATER_EQUAL 1933)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm
@@ -594,6 +598,9 @@ else()
${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2_fp32.cpp
)
if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE))
set(mlas_platform_srcs_avx2
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 571d26bc1903b..274531faaf717 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -5253,7 +5253,7 @@ This version of the operator has been available since version 1 of the 'com.micr
- interleaved : int
-- Rotate using interleaved pattern. Default value is 0 (False).
+- Indicates whether the input has real and imaginary parts interleaved. Default value is 0 (False), meaning the first half of the input consists of real values and the second half consists of imaginary values.
- is_packed_batching : int
- ragged batch inputs or not. Default value is 0
- num_heads : int
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index f565c607c4b37..84b9c7c9fc174 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -117,9 +117,9 @@ Do not modify directly.*
|Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|Elu|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(float)|
|||[6, 21]|**T** = tensor(float)|
-|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string)
**T1** = tensor(bool)|
-|||[13, 18]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)|
-|||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)|
+|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
+|||[13, 18]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
+|||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
|||[7, 10]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)|
|Erf|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float)|
|||[9, 12]|**T** = tensor(float)|
@@ -157,11 +157,11 @@ Do not modify directly.*
|GlobalLpPool|*in* X:**T**
*out* Y:**T**|2+|**T** = tensor(float)|
|GlobalMaxPool|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(float)|
|||[1, 21]|**T** = tensor(float)|
-|Greater|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)|
-|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)|
+|Greater|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
+|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
|||[7, 8]|**T** = tensor(double), tensor(float)
**T1** = tensor(bool)|
-|GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)|
-|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)|
+|GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
+|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|22+|**T1** = tensor(double), tensor(float)
**T2** = tensor(double), tensor(float)|
|||[20, 21]|**T1** = tensor(double), tensor(float)
**T2** = tensor(double), tensor(float)|
|||[16, 19]|**T1** = tensor(float)
**T2** = tensor(float)|
@@ -201,11 +201,11 @@ Do not modify directly.*
|||[1, 16]|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float), tensor(float16)
**V** = tensor(double), tensor(float), tensor(float16)|
|LeakyRelu|*in* X:**T**
*out* Y:**T**|16+|**T** = tensor(float)|
|||[6, 15]|**T** = tensor(float)|
-|Less|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)|
-|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)|
+|Less|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
+|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
|||[7, 8]|**T** = tensor(double), tensor(float)
**T1** = tensor(bool)|
-|LessOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)|
-|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)|
+|LessOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
+|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
|Log|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)|
|||[6, 12]|**T** = tensor(double), tensor(float)|
|LogSoftmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)|
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index d8f8d28a47621..ecc8cb091b1b6 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1344,7 +1344,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr("interleaved",
- "Rotate using interleaved pattern. Default value is 0 (False).",
+ "Indicates whether the input has real and imaginary parts interleaved. "
+ "Default value is 0 (False), meaning the first half of the input consists of real values "
+ "and the second half consists of imaginary values.",
AttributeProto::INT,
OPTIONAL_VALUE)
.Attr("rotary_embedding_dim",
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index f1d33b3bdd66e..1d18354060abb 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1464,8 +1464,8 @@ void
MLASCALL
MlasRotaryEmbedOneRow(
const T* input,
- const T* sin,
- const T* cos,
+ const T* sin_data,
+ const T* cos_data,
size_t dim,
bool interleaved,
T* output
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index decf44708de32..0dae9f6ccaee1 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -1058,6 +1058,7 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni;
//
struct MLAS_ROPE_DISPATCH;
extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon;
+extern const MLAS_ROPE_DISPATCH MlasRopeDispatchAvx2;
//
// half gemm dispatch structure
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index 9db3115c1ad20..582c1ab944b98 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -401,6 +401,7 @@ Return Value:
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2;
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2;
this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2;
+ this->RopeDispatch = &MlasRopeDispatchAvx2;
//
diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.cpp b/onnxruntime/core/mlas/lib/rotary_embedding.cpp
index 1f8f7b240694c..63e0a7fd707f7 100644
--- a/onnxruntime/core/mlas/lib/rotary_embedding.cpp
+++ b/onnxruntime/core/mlas/lib/rotary_embedding.cpp
@@ -16,8 +16,6 @@ Module Name:
#include "rotary_embedding.h"
-namespace {
-
template
void
MLASCALL
@@ -55,16 +53,13 @@ MlasRotaryEmbedOneRow_FallBack(
}
}
-} // namespace
-
-
template <>
void
MLASCALL
MlasRotaryEmbedOneRow(
const float* input,
- const float* sin,
- const float* cos,
+ const float* sin_data,
+ const float* cos_data,
size_t dim,
bool interleaved,
float* output
@@ -72,11 +67,11 @@ MlasRotaryEmbedOneRow(
const auto* dispatch = GetMlasPlatform().RopeDispatch;
if (dispatch == nullptr || dispatch->SRope == nullptr) {
- MlasRotaryEmbedOneRow_FallBack(input, sin, cos, dim, interleaved, output);
+ MlasRotaryEmbedOneRow_FallBack(input, sin_data, cos_data, dim, interleaved, output);
return;
}
- dispatch->SRope(input, sin, cos, dim, interleaved, output);
+ dispatch->SRope(input, sin_data, cos_data, dim, interleaved, output);
}
template <>
@@ -84,8 +79,8 @@ void
MLASCALL
MlasRotaryEmbedOneRow(
const MLAS_FP16* input,
- const MLAS_FP16* sin,
- const MLAS_FP16* cos,
+ const MLAS_FP16* sin_data,
+ const MLAS_FP16* cos_data,
size_t dim,
bool interleaved,
MLAS_FP16* output
@@ -93,9 +88,21 @@ MlasRotaryEmbedOneRow(
const auto* dispatch = GetMlasPlatform().RopeDispatch;
if (dispatch == nullptr || dispatch->HRope == nullptr) {
- MlasRotaryEmbedOneRow_FallBack(input, sin, cos, dim, interleaved, output);
+ MlasRotaryEmbedOneRow_FallBack(input, sin_data, cos_data, dim, interleaved, output);
return;
}
- dispatch->HRope(input, sin, cos, dim, interleaved, output);
+ dispatch->HRope(input, sin_data, cos_data, dim, interleaved, output);
}
+
+template
+void
+MLASCALL
+MlasRotaryEmbedOneRow_FallBack(
+ const float* input_data,
+ const float* sin_data,
+ const float* cos_data,
+ size_t rotary_emb_dim,
+ bool interleaved,
+ float* output_data
+);
diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.h b/onnxruntime/core/mlas/lib/rotary_embedding.h
index 352dddccf1025..c017ece810d0f 100644
--- a/onnxruntime/core/mlas/lib/rotary_embedding.h
+++ b/onnxruntime/core/mlas/lib/rotary_embedding.h
@@ -23,8 +23,8 @@ struct MLAS_ROPE_DISPATCH {
// rotary embedding kernel for fp32
typedef void(SRope_Fn)(
const float* input,
- const float* sin,
- const float* cos,
+ const float* sin_data,
+ const float* cos_data,
size_t dim,
bool interleaved,
float* output
@@ -35,8 +35,8 @@ struct MLAS_ROPE_DISPATCH {
// rotary embedding kernel for fp16
typedef void(HRope_Fn)(
const MLAS_FP16* input,
- const MLAS_FP16* sin,
- const MLAS_FP16* cos,
+ const MLAS_FP16* sin_data,
+ const MLAS_FP16* cos_data,
size_t dim,
bool interleaved,
MLAS_FP16* output
@@ -44,3 +44,14 @@ struct MLAS_ROPE_DISPATCH {
HRope_Fn* HRope = nullptr;
};
+
+template
+void MLASCALL
+MlasRotaryEmbedOneRow_FallBack(
+ const T* input_data,
+ const T* sin_data,
+ const T* cos_data,
+ size_t rotary_emb_dim,
+ bool interleaved,
+ T* output_data
+);
diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp
new file mode 100644
index 0000000000000..7b6c49720853a
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp
@@ -0,0 +1,27 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ rotary_embedding_kernel_avx2.cpp
+
+Abstract:
+
+ This module implements the rotary embedding kernels for AVX2 supported h/w.
+
+--*/
+
+#include "rotary_embedding.h"
+#include "rotary_embedding_kernel_avx2.h"
+
+//
+// Kernel dispatch structure definition.
+//
+const MLAS_ROPE_DISPATCH MlasRopeDispatchAvx2 = []() {
+ MLAS_ROPE_DISPATCH d;
+ d.SRope = rope_avx2::RopeKernel_Avx2;
+ return d;
+}();
diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.h b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.h
new file mode 100644
index 0000000000000..18a2e11998644
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.h
@@ -0,0 +1,37 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ rotary_embedding_kernel_avx2.h
+
+Abstract:
+
+ This module includes function declarations and common helper functions for
+ rotary embedding on for AVX2 enabled h/w.
+
+--*/
+
+#pragma once
+
+
+
+#include "mlasi.h"
+
+namespace rope_avx2 {
+
+// Rotary embedding kernel for FP32. Embed one hidden state vector.
+void
+RopeKernel_Avx2(
+ const float* input,
+ const float* sin_data,
+ const float* cos_data,
+ size_t dim,
+ bool interleaved,
+ float* output
+);
+
+} // namespace rope_avx2
diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp
new file mode 100644
index 0000000000000..7124b82606978
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp
@@ -0,0 +1,166 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ rotary_embedding_kernel_avx2_fp32.cpp
+
+Abstract:
+
+ This module implements the fp32 rotary embedding kernels using AVX2.
+
+--*/
+
+#include
+
+#include "rotary_embedding.h"
+#include "rotary_embedding_kernel_avx2.h"
+
+namespace rope_avx2 {
+
+namespace {
+
+typedef __m256 float32x8_t;
+
+template
+void
+RopeKernel_Avx2_Impl(
+ const float* input,
+ const float* sin_data,
+ const float* cos_data,
+ size_t dim,
+ float* output
+);
+
+template <>
+void
+RopeKernel_Avx2_Impl(
+ const float* input,
+ const float* sin_data,
+ const float* cos_data,
+ size_t dim,
+ float* output
+) {
+ const size_t half_dim = dim >> 1;
+ size_t i = 0, j = half_dim;
+ for (; i + 7 < half_dim; i += 8, j += 8) {
+ float32x8_t real = _mm256_loadu_ps(input + i);
+ float32x8_t imag = _mm256_loadu_ps(input + j);
+ float32x8_t sin_val = _mm256_loadu_ps(sin_data + i);
+ float32x8_t cos_val = _mm256_loadu_ps(cos_data + i);
+ //Compute Real and Imaginary output values
+ float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val));
+ float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val));
+ //Store back into non interleaved format
+ _mm256_storeu_ps(output + i, real_out);
+ _mm256_storeu_ps(output + j, imag_out);
+ }
+ if (half_dim - i != 0) {
+ size_t rem = half_dim - i;
+ static constexpr int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0};
+ const __m256i mask = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - rem));
+ //Use a mask to load the remaining input values
+ float32x8_t real = _mm256_maskload_ps(input + i, mask);
+ float32x8_t imag = _mm256_maskload_ps(input + j, mask);
+ float32x8_t sin_val = _mm256_maskload_ps(sin_data + i, mask);
+ float32x8_t cos_val = _mm256_maskload_ps(cos_data + i, mask);
+ //Compute Real and Imaginary output values
+ float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val));
+ float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val));
+ //Store back into non interleaved format
+ _mm256_maskstore_ps(output + i, mask, real_out);
+ _mm256_maskstore_ps(output + j, mask, imag_out);
+ }
+}
+
+template <>
+void
+RopeKernel_Avx2_Impl(
+ const float* input,
+ const float* sin_data,
+ const float* cos_data,
+ size_t dim,
+ float* output
+) {
+ size_t i = 0;
+ for (; i + 15 < dim; i += 16) {
+ float32x8_t x0 = _mm256_loadu_ps(input + i);
+ float32x8_t x1 = _mm256_loadu_ps(input + i + 8);
+ //Load imaginary and real values to separate non-interleaved vectors
+ float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000);
+ float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101);
+ __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
+ float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec);
+ float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec);
+ float32x8_t sin_val = _mm256_loadu_ps(sin_data + i / 2);
+ float32x8_t cos_val = _mm256_loadu_ps(cos_data + i / 2);
+ //Compute Real and Imaginary output values
+ float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val));
+ float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val));
+ //Store back into interleaved format
+ __m256i out_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
+ float32x8_t real_out_s = _mm256_permutevar8x32_ps(real_out, out_mask_vec);
+ float32x8_t imag_out_s = _mm256_permutevar8x32_ps(imag_out, out_mask_vec);
+ float32x8_t y0 = _mm256_unpacklo_ps(real_out_s, imag_out_s);
+ float32x8_t y1 = _mm256_unpackhi_ps(real_out_s, imag_out_s);
+ _mm256_storeu_ps(output + i, y0);
+ _mm256_storeu_ps(output + i + 8, y1);
+ }
+ if (dim - i != 0) {
+ size_t rem = dim - i;
+ static constexpr int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0};
+ const __m256i mask0 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?8:rem)));
+ const __m256i mask1 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?(rem-8):0)));
+ float32x8_t x0 = _mm256_maskload_ps(input + i, mask0); //Load the first set of data using mask
+ float32x8_t x1 = _mm256_maskload_ps(input + i + 8, mask1); //Load the reminder of data using a second mask
+ //Load imaginary and real values to separate non-interleaved vectors
+ float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000);
+ float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101);
+ __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
+ float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec);
+ float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec);
+ float32x8_t sin_val = _mm256_loadu_ps(sin_data+ i / 2);
+ float32x8_t cos_val = _mm256_loadu_ps(cos_data + i / 2);
+ //Compute Real and Imaginary output values
+ float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val));
+ float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val));
+ //Store back into interleaved format
+ __m256i out_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
+ float32x8_t real_out_s = _mm256_permutevar8x32_ps(real_out, out_mask_vec);
+ float32x8_t imag_out_s = _mm256_permutevar8x32_ps(imag_out, out_mask_vec);
+ float32x8_t y0 = _mm256_unpacklo_ps(real_out_s, imag_out_s);
+ float32x8_t y1 = _mm256_unpackhi_ps(real_out_s, imag_out_s);
+ _mm256_maskstore_ps(output + i, mask0, y0);
+ _mm256_maskstore_ps(output + i + 8, mask1, y1);
+ }
+}
+
+} // namespace
+
+void
+RopeKernel_Avx2(
+ const float* input,
+ const float* sin_data,
+ const float* cos_data,
+ size_t dim,
+ bool interleaved,
+ float* output
+) {
+ // real part and imaginary part must be paired
+ assert(dim % 2 == 0);
+ const auto* input_impl = reinterpret_cast(input);
+ const auto* sin_impl = reinterpret_cast(sin_data);
+ const auto* cos_impl = reinterpret_cast(cos_data);
+ auto* output_impl = reinterpret_cast(output);
+
+ if (interleaved) {
+ RopeKernel_Avx2_Impl(input_impl, sin_impl, cos_impl, dim, output_impl);
+ } else {
+ RopeKernel_Avx2_Impl(input_impl, sin_impl, cos_impl, dim, output_impl);
+ }
+}
+
+}
diff --git a/onnxruntime/test/mlas/bench/bench_rope.cpp b/onnxruntime/test/mlas/bench/bench_rope.cpp
new file mode 100644
index 0000000000000..9103ba6424f65
--- /dev/null
+++ b/onnxruntime/test/mlas/bench/bench_rope.cpp
@@ -0,0 +1,52 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "mlas.h"
+#include "benchmark/benchmark.h"
+#include "bench_util.h"
+
+void RunRoPEBenchmark(size_t rotary_emb_dim, bool interleaved, benchmark::State& state) {
+ const float Pi = 2 * std::acos(0.0f);
+
+ std::vector input(rotary_emb_dim);
+ size_t table_len = interleaved ? rotary_emb_dim / 2 : rotary_emb_dim;
+ std::vector sin_data(table_len);
+ std::vector cos_data(table_len);
+ std::vector output_ref(rotary_emb_dim), output_impl(rotary_emb_dim);
+
+ for (size_t i = 0; i < rotary_emb_dim; ++i) {
+ input[i] = static_cast(i + 1);
+ }
+ for (size_t i = 0; i < table_len; ++i) {
+ float theta = (float)i / 1000 * Pi;
+ sin_data[i] = std::sin(theta);
+ cos_data[i] = std::cos(theta);
+ }
+
+ // warm up run
+ MlasRotaryEmbedOneRow(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_impl[0]);
+
+ for (auto _ : state) {
+ MlasRotaryEmbedOneRow(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_impl[0]);
+ }
+}
+
+void RoPE(benchmark::State& state) {
+ using onnxruntime::narrow;
+
+ const auto rotary_emb_dim = narrow(state.range(0));
+ const auto interleaved = narrow(state.range(1));
+
+ RunRoPEBenchmark(rotary_emb_dim, interleaved, state);
+}
+
+static void RoPEArgs(benchmark::internal::Benchmark* b) {
+ b->ArgNames({"rotary_emb_dim", "interleaved"});
+
+ b->ArgsProduct({
+ {128, 256, 512, 1024}, // rotary_emb_dim
+ {int64_t{false}, int64_t{true}}, // interleaved
+ });
+}
+
+BENCHMARK(RoPE)->Apply(RoPEArgs)->UseRealTime();
diff --git a/onnxruntime/test/mlas/unittest/test_rope.cpp b/onnxruntime/test/mlas/unittest/test_rope.cpp
new file mode 100644
index 0000000000000..54087a933fd9e
--- /dev/null
+++ b/onnxruntime/test/mlas/unittest/test_rope.cpp
@@ -0,0 +1,129 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ test_rope.h
+
+Abstract:
+
+ Tests for MLAS RoPE.
+
+--*/
+
+#include "test_util.h"
+#include "mlas.h"
+#include "core/mlas/lib/rotary_embedding.h"
+
+class MlasRoPETest : public MlasTestBase {
+ const float Pi = 2 * std::acos(0.0f);
+
+ public:
+ void Test(size_t rotary_emb_dim, bool interleaved) {
+ std::vector input(rotary_emb_dim);
+ size_t table_len = interleaved ? rotary_emb_dim / 2 : rotary_emb_dim;
+ std::vector sin_data(table_len);
+ std::vector cos_data(table_len);
+ std::vector output_ref(rotary_emb_dim), output_impl(rotary_emb_dim);
+
+ for (size_t i = 0; i < rotary_emb_dim; ++i) {
+ input[i] = static_cast(i + 1);
+ }
+ for (size_t i = 0; i < table_len; ++i) {
+ float theta = (float)i / 1000 * Pi;
+ sin_data[i] = std::sin(theta);
+ cos_data[i] = std::cos(theta);
+ }
+
+ // Call the function
+ MlasRotaryEmbedOneRow_FallBack(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_ref[0]);
+ MlasRotaryEmbedOneRow(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_impl[0]);
+
+ for (size_t i = 0; i < rotary_emb_dim; i++) {
+ ASSERT_TRUE(CloseEnough(output_impl[i], output_ref[i]))
+ << "Expected: " << output_ref[i] << " Actual: " << output_impl[i] << "@[" << i << "], "
+ << "rotary_emb_dim=" << rotary_emb_dim << ", interleaved=" << interleaved;
+ }
+ }
+
+ public:
+};
+
+//
+// Short Execute() test helper to register each test separately by all parameters.
+//
+class RoPEShortExecuteTest : public MlasTestFixture {
+ public:
+ explicit RoPEShortExecuteTest(size_t rotary_emb_dim, bool interleaved)
+ : rotary_emb_dim_(rotary_emb_dim),
+ interleaved_(interleaved) {}
+
+ void TestBody() override {
+ MlasTestFixture::mlas_tester->Test(rotary_emb_dim_, interleaved_);
+ }
+
+ static size_t RegisterSingleTest(size_t rotary_emb_dim, bool interleaved) {
+ size_t tests_registered = 0;
+
+ std::stringstream ss;
+ ss << "/rotary_emb_dim" << rotary_emb_dim << "/interleaved" << interleaved;
+ auto test_name = ss.str();
+
+ testing::RegisterTest(
+ "RoPE",
+ test_name.c_str(),
+ nullptr,
+ test_name.c_str(),
+ __FILE__,
+ __LINE__,
+ // Important to use the fixture type as the return type here.
+ [=]() -> MlasTestFixture* {
+ return new RoPEShortExecuteTest(rotary_emb_dim, interleaved);
+ });
+
+ tests_registered += 1;
+
+ return tests_registered;
+ }
+
+ static size_t RegisterShortExecuteTests() {
+ size_t tests_registered = 0;
+ tests_registered += RegisterSingleTest(6, false);
+ tests_registered += RegisterSingleTest(6, true);
+ tests_registered += RegisterSingleTest(16, false);
+ tests_registered += RegisterSingleTest(16, true);
+ tests_registered += RegisterSingleTest(24, false);
+ tests_registered += RegisterSingleTest(24, true);
+ tests_registered += RegisterSingleTest(32, false);
+ tests_registered += RegisterSingleTest(32, true);
+ tests_registered += RegisterSingleTest(42, false);
+ tests_registered += RegisterSingleTest(42, true);
+ tests_registered += RegisterSingleTest(64, false);
+ tests_registered += RegisterSingleTest(64, true);
+ tests_registered += RegisterSingleTest(70, false);
+ tests_registered += RegisterSingleTest(70, true);
+ return tests_registered;
+ }
+
+ private:
+ size_t rotary_emb_dim_;
+ bool interleaved_;
+};
+
+// only test float RoPE with avx2 where RopeDispatch is assigned at this moment.
+#ifdef MLAS_TARGET_AMD64
+static size_t RoPERegisterAllShortExecuteTests() {
+ return RoPEShortExecuteTest::RegisterShortExecuteTests();
+}
+
+static UNUSED_VARIABLE bool added_to_main = AddTestRegister(
+ [](bool is_short_execute) -> size_t {
+ if (is_short_execute) {
+ return RoPERegisterAllShortExecuteTests();
+ }
+ return 0;
+ });
+#endif