Rope embedding kernel to use AVX2 #23694

Open
wants to merge 12 commits into main
7 changes: 7 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -181,6 +181,9 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/dgemm.cpp
${mlas_platform_srcs_avx}
${mlas_platform_srcs_avx2}
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2_fp32.cpp
${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp
${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
@@ -226,6 +229,7 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm
${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm
)

if(MSVC_VERSION GREATER_EQUAL 1933)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm
@@ -594,6 +598,9 @@ else()
${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2_fp32.cpp
)
if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE))
set(mlas_platform_srcs_avx2
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -1058,6 +1058,7 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni;
//
struct MLAS_ROPE_DISPATCH;
extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon;
extern const MLAS_ROPE_DISPATCH MlasRopeDispatchAvx2;

//
// half gemm dispatch structure
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -401,6 +401,7 @@ Return Value:
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2;
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2;
this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2;
this->RopeDispatch = &MlasRopeDispatchAvx2;


//
27 changes: 27 additions & 0 deletions onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp
@@ -0,0 +1,27 @@
/*++

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

Module Name:

rotary_embedding_kernel_avx2.cpp

Abstract:

This module implements the rotary embedding kernels for AVX2-capable hardware.

--*/

#include "rotary_embedding.h"
#include "rotary_embedding_kernel_avx2.h"

//
// Kernel dispatch structure definition.
//
const MLAS_ROPE_DISPATCH MlasRopeDispatchAvx2 = []() {
MLAS_ROPE_DISPATCH d;
d.SRope = rope_avx2::RopeKernel_Avx2;
return d;
}();
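For context, this follows MLAS's usual per-platform dispatch pattern: platform.cpp points its RopeDispatch slot at the object above, and fp32 rows reach the kernel through the SRope member. A minimal sketch of an equivalent direct call through that member, assuming only the signature declared in rotary_embedding_kernel_avx2.h below (the wrapper name is hypothetical and not part of this change):

// Hypothetical sketch: invoking the AVX2 rope kernel via the dispatch entry.
// The argument order mirrors rope_avx2::RopeKernel_Avx2.
static void RopeOneRowFp32(const float* input, const float* sin_cache,
                           const float* cos_cache, size_t dim,
                           bool interleaved, float* output) {
    MlasRopeDispatchAvx2.SRope(input, sin_cache, cos_cache, dim, interleaved, output);
}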
37 changes: 37 additions & 0 deletions onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.h
@@ -0,0 +1,37 @@
/*++

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

Module Name:

rotary_embedding_kernel_avx2.h

Abstract:

This module includes function declarations and common helper functions for
rotary embedding on AVX2-capable hardware.

--*/

#pragma once



#include "mlasi.h"

namespace rope_avx2 {

// Rotary embedding kernel for FP32. Embed one hidden state vector.
void
RopeKernel_Avx2(
const float* input,
const float* sin,
const float* cos,
size_t dim,
bool interleaved,
float* output
);

} // namespace rope_avx2
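As a reading aid, here is a scalar sketch of the rotation this kernel vectorizes, derived from the fp32 implementation in rotary_embedding_kernel_avx2_fp32.cpp below; the helper name is made up and the code is for illustration only. In the non-interleaved layout the first dim/2 elements of a row hold the real parts and the last dim/2 hold the imaginary parts; in the interleaved layout real and imaginary values alternate.

// Scalar reference for RopeKernel_Avx2 (illustration only, not part of this change).
inline void RopeKernel_ScalarRef(const float* input, const float* sin,
                                 const float* cos, size_t dim,
                                 bool interleaved, float* output) {
    const size_t half_dim = dim / 2;
    for (size_t k = 0; k < half_dim; ++k) {
        const size_t re = interleaved ? 2 * k : k;                 // index of the real part
        const size_t im = interleaved ? 2 * k + 1 : k + half_dim;  // index of the imaginary part
        const float r = input[re];
        const float m = input[im];
        output[re] = r * cos[k] - m * sin[k];  // real_out = real*cos - imag*sin
        output[im] = r * sin[k] + m * cos[k];  // imag_out = real*sin + imag*cos
    }
}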
166 changes: 166 additions & 0 deletions onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp
@@ -0,0 +1,166 @@
/*++

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

Module Name:

rotary_embedding_kernel_avx2_fp32.cpp

Abstract:

This module implements the fp32 rotary embedding kernels using AVX2.

--*/

#include <cassert>

#include "rotary_embedding.h"
#include "rotary_embedding_kernel_avx2.h"

namespace rope_avx2 {

namespace {

typedef __m256 float32x8_t;

template <bool interleaved>
void
RopeKernel_Avx2_Impl(
const float* input,
const float* sin,
const float* cos,
size_t dim,
float* output
);

template <>
void
RopeKernel_Avx2_Impl<false>(
const float* input,
const float* sin,
const float* cos,
size_t dim,
float* output
) {
const size_t half_dim = dim >> 1;
size_t i = 0, j = half_dim;
for (; i + 7 < half_dim; i += 8, j += 8) {
float32x8_t real = _mm256_loadu_ps(input + i);
float32x8_t imag = _mm256_loadu_ps(input + j);
float32x8_t sin_val = _mm256_loadu_ps(sin + i);
float32x8_t cos_val = _mm256_loadu_ps(cos + i);
//Compute Real and Imaginary output values
float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val));
float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val));
//Store back into non interleaved format
_mm256_storeu_ps(output + i, real_out);
_mm256_storeu_ps(output + j, imag_out);
}
if (half_dim - i != 0) {
size_t rem = half_dim - i;
static const int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0};
const __m256i mask = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - rem));
//Use a mask to load the remaining input values
float32x8_t real = _mm256_maskload_ps(input + i, mask);
float32x8_t imag = _mm256_maskload_ps(input + j, mask);
float32x8_t sin_val = _mm256_maskload_ps(sin + i, mask);
float32x8_t cos_val = _mm256_maskload_ps(cos + i, mask);
//Compute Real and Imaginary output values
float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val));
float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val));
//Store back into non interleaved format
_mm256_maskstore_ps(output + i, mask, real_out);
_mm256_maskstore_ps(output + j, mask, imag_out);
}
}

template <>
void
RopeKernel_Avx2_Impl<true>(
const float* input,
const float* sin,
const float* cos,
size_t dim,
float* output
) {
size_t i = 0;
for (; i + 15 < dim; i += 16) {
float32x8_t x0 = _mm256_loadu_ps(input + i);
float32x8_t x1 = _mm256_loadu_ps(input + i + 8);
//Load imaginary and real values to separate non-interleaved vectors
float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000);
float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101);
__m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec);
float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec);
float32x8_t sin_val = _mm256_loadu_ps(sin + i / 2);
float32x8_t cos_val = _mm256_loadu_ps(cos + i / 2);
//Compute Real and Imaginary output values
float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val));
float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val));
//Store back into interleaved format
__m256i out_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
float32x8_t real_out_s = _mm256_permutevar8x32_ps(real_out, out_mask_vec);
float32x8_t imag_out_s = _mm256_permutevar8x32_ps(imag_out, out_mask_vec);
float32x8_t y0 = _mm256_unpacklo_ps(real_out_s, imag_out_s);
float32x8_t y1 = _mm256_unpackhi_ps(real_out_s, imag_out_s);
_mm256_storeu_ps(output + i, y0);
_mm256_storeu_ps(output + i + 8, y1);
}
if (dim - i != 0) {
size_t rem = dim - i;
static const int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0};
const __m256i mask0 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?8:rem)));
const __m256i mask1 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?(rem-8):0)));
float32x8_t x0 = _mm256_maskload_ps(input + i, mask0); //Load the first set of data using mask
float32x8_t x1 = _mm256_maskload_ps(input + i + 8, mask1); //Load the remainder of the data using a second mask
//Load imaginary and real values to separate non-interleaved vectors
float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000);
float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101);
__m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec);
float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec);
//Mask the sin/cos loads as well so they do not read past the end of their caches
const __m256i mask_sincos = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - rem / 2));
float32x8_t sin_val = _mm256_maskload_ps(sin + i / 2, mask_sincos);
float32x8_t cos_val = _mm256_maskload_ps(cos + i / 2, mask_sincos);
//Compute Real and Imaginary output values
float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val));
float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val));
//Store back into interleaved format
__m256i out_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
float32x8_t real_out_s = _mm256_permutevar8x32_ps(real_out, out_mask_vec);
float32x8_t imag_out_s = _mm256_permutevar8x32_ps(imag_out, out_mask_vec);
float32x8_t y0 = _mm256_unpacklo_ps(real_out_s, imag_out_s);
float32x8_t y1 = _mm256_unpackhi_ps(real_out_s, imag_out_s);
_mm256_maskstore_ps(output + i, mask0, y0);
_mm256_maskstore_ps(output + i + 8, mask1, y1);
}
}

} // namespace

void
RopeKernel_Avx2(
const float* input,
const float* sin,
const float* cos,
size_t dim,
bool interleaved,
float* output
) {
// real part and imaginary part must be paired
assert(dim % 2 == 0);

if (interleaved) {
RopeKernel_Avx2_Impl<true>(input, sin, cos, dim, output);
} else {
RopeKernel_Avx2_Impl<false>(input, sin, cos, dim, output);
}
}

}  // namespace rope_avx2
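Both remainder paths above build their lane masks by sliding a window over a static buffer of eight -1 values followed by eight zeros. A standalone sketch of that trick (illustration only; the names here are not part of the change):

#include <immintrin.h>
#include <cstdint>

// kMaskBuffer holds eight all-ones lanes followed by eight zero lanes, so
// (kMaskBuffer + 8 - rem) starts a window whose first `rem` lanes are -1
// (enabled for maskload/maskstore) and whose remaining lanes are 0.
static const int32_t kMaskBuffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1,
                                        0, 0, 0, 0, 0, 0, 0, 0};

inline __m256i TailMask(size_t rem) {  // rem in [0, 8]
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(kMaskBuffer + 8 - rem));
}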