Enable AVX NE CONVERT for FP16 to FP32 cast (#21183)
### Description
Implements a new cast assembly kernel that uses AVX_NE_CONVERT
instructions to accelerate casting from FP16 to FP32, and adds CPUID checks
to determine whether the ISA is supported.

### Motivation and Context
Currently, FP16 models executed on systems that lack complete FP16
operator support fall back to single precision on every node. The original
FP16 weights therefore have to be cast to FP32 before the model can run.
This change accelerates that cast with upconvert instructions and thereby
improves performance.
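
Below is a minimal usage sketch of the public entry point touched by this change (illustration only, not part of the commit). It assumes a build that links `onnxruntime_mlas` and can include the MLAS header; the FP16 literals are the IEEE half bit patterns for 1.0, 2.0, -0.5 and 65504.0:

```cpp
#include "mlas.h"   // onnxruntime/core/mlas/inc/mlas.h
#include <cstdio>

int main() {
    // IEEE FP16 bit patterns for 1.0, 2.0, -0.5 and 65504.0 (largest finite half).
    const unsigned short fp16[] = {0x3C00, 0x4000, 0xB800, 0x7BFF};
    float fp32[4];

    // Dispatches to the AVX_NE_CONVERT kernel when the CPU and toolchain support
    // it, otherwise to the SSE kernel or the scalar reference path in cast.cpp.
    MlasConvertHalfToFloatBuffer(fp16, fp32, 4);

    for (float v : fp32) std::printf("%g\n", v);  // 1 2 -0.5 65504
    return 0;
}
```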
eralmual authored Sep 10, 2024
1 parent d4d419f commit 7489bfe
Showing 10 changed files with 537 additions and 11 deletions.
19 changes: 19 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -40,6 +40,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/sqnbitgemm.cpp
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
${MLAS_SRC_DIR}/flashattn.cpp
${MLAS_SRC_DIR}/cast.cpp
)

target_sources(onnxruntime_mlas PRIVATE
@@ -212,6 +213,12 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm
${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm
)
if(MSVC_VERSION GREATER_EQUAL 1933)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm
)
endif()

if (NOT onnxruntime_ORT_MINIMAL_BUILD)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/q4gemm_avx512.cpp
@@ -522,6 +529,12 @@ else()
${MLAS_SRC_DIR}/x86_64/SconvKernelSse2.S
${MLAS_SRC_DIR}/x86_64/SpoolKernelSse2.S
)
if(NOT APPLE)
set(mlas_platform_srcs_sse2
${mlas_platform_srcs_sse2}
${MLAS_SRC_DIR}/x86_64/cvtfp16a.S
)
endif()
set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2")

set(mlas_platform_srcs_avx
@@ -555,6 +568,12 @@ else()
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
)
if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE))
set(mlas_platform_srcs_avx2
${mlas_platform_srcs_avx2}
${MLAS_SRC_DIR}/x86_64/cvtfp16Avx.S
)
endif()

message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}")
3 changes: 1 addition & 2 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -1029,14 +1029,13 @@ MlasComputeTanh(
// Half-precision floating-point routines.
//

extern "C"
void
MLASCALL
MlasConvertHalfToFloatBuffer(
const unsigned short* Source,
float* Destination,
size_t Count
);
);

//
// Transpose routines.
151 changes: 151 additions & 0 deletions onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm
@@ -0,0 +1,151 @@
;++
;
; Copyright (c) Intel Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; cvtfp16Avx.asm
;
; Abstract:
;
; This module implements routines to convert between FP16 and FP32 formats using the AVX_NE_CONVERT ISA.
;
;--

.xlist
INCLUDE mlasi.inc
.list

.const

SINGLE_SIZE equ 4
HALF_SIZE equ 2
LOW_SELECTOR equ 00100000b
HIGH_SELECTOR equ 00110001b

SUBTTL "Convert buffer of half-precision floats to single-precision floats"
;++
;
; Routine Description:
;
; This routine converts the source buffer of half-precision floats to the
; destination buffer of single-precision floats.
;
; This implementation uses AVX_NE_CONVERT conversion instructions together
; with AVX shuffles and masked stores.
;
; Arguments:
;
; Source (rcx) - Supplies the address of the source buffer of half-precision
; floats.
;
; Destination (rdx) - Supplies the address of the destination buffer of
; single-precision floats.
;
; Count (r8) - Supplies the number of elements to convert.
;
; Return Value:
;
; None.
;
;--


LEAF_ENTRY MlasCastF16ToF32KernelAvx, _TEXT

test r8, r8 ; Check if we have any elements to convert
jz ExitRoutine
cmp r8, 8
jb ConvertMaskedVectors
cmp r8, 16
jb Convert128Vectors



Convert256Vectors:
vcvtneeph2ps ymm0, ymmword PTR [rcx] ; Load even indexes
vcvtneoph2ps ymm1, ymmword PTR [rcx] ; Load odd indexes
vunpcklps ymm2, ymm0, ymm1 ; Interleave low part
vunpckhps ymm1, ymm0, ymm1 ; Interleave high part
vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR ; Fix the order
vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR ; Fix the order
vmovups ymmword PTR [rdx], ymm0 ; Store the low part
vmovups ymmword PTR [rdx + 8*SINGLE_SIZE], ymm1 ; Store the high part

add rcx, 16*HALF_SIZE ; Advance src ptr by 16 elements
add rdx, 16*SINGLE_SIZE ; Advance dest ptr by 16 elements
sub r8, 16 ; Reduce the counter by 16 elements

jz ExitRoutine ; If we are done, exit
cmp r8, 16 ; If the vector is big enough, we go again
jae Convert256Vectors



Convert128Vectors:
vcvtneeph2ps xmm2, xmmword PTR [rcx] ; Load even indexes
vcvtneoph2ps xmm1, xmmword PTR [rcx] ; Load odd indexes
vunpcklps xmm0, xmm2, xmm1 ; Interleave low part to fix order
vunpckhps xmm1, xmm2, xmm1 ; Interleave high part to fix order
vmovups xmmword PTR [rdx], xmm0 ; Store the low part
vmovups xmmword PTR [rdx + 4*SINGLE_SIZE], xmm1 ; Store the high part

add rcx, 8*HALF_SIZE ; Advance src ptr by 8 elements
add rdx, 8*SINGLE_SIZE ; Advance dest ptr by 8 elements
sub r8, 8 ; Reduce the counter by 8 elements

jz ExitRoutine ; If we are done, exit



ConvertMaskedVectors:
vcvtneeph2ps xmm2, xmmword PTR [rcx] ; Load even indexes
vcvtneoph2ps xmm1, xmmword PTR [rcx] ; Load odd indexes
vunpcklps xmm0, xmm2, xmm1 ; Interleave low part to fix order
vunpckhps xmm1, xmm2, xmm1 ; Interleave high part to fix order

cmp r8, 4 ; Check if we can store the complete lower vector
jae ConvertLowerVector

vpcmpeqw xmm2, xmm2, xmm2 ; Initialize the mask full of ones
cmp r8, 2 ; Check how many converts we need
jb ConvertLower1
ja ConvertLower3
vpsrldq xmm2, xmm2, SINGLE_SIZE*2 ; Shift the mask so only two values are stored
jmp ConvertLowerMaskedVector
ConvertLower1:
vpsrldq xmm2, xmm2, SINGLE_SIZE*3 ; Shift the mask so only one value is stored
jmp ConvertLowerMaskedVector
ConvertLower3:
vpsrldq xmm2, xmm2, SINGLE_SIZE ; Shift the mask so only three values are stored
ConvertLowerMaskedVector:
vmaskmovps xmmword PTR [rdx], xmm2, xmm0 ; Store the masked data; the mask shifts above are in byte multiples
jmp ExitRoutine ; Reaching any of the masked cases means we are done after this store
ConvertLowerVector:
vmovups xmmword PTR [rdx], xmm0 ; Store the low part
sub r8, 4 ; Check if we still need to convert
jz ExitRoutine


add rdx, 4*SINGLE_SIZE ; Advance dest ptr by 4 elements
vpcmpeqw xmm2, xmm2, xmm2 ; Initialize the mask full of ones
cmp r8, 2 ; Check how many converts we need
jb ConvertUpper1
ja ConvertUpper3
vpsrldq xmm2, xmm2, SINGLE_SIZE*2 ; Shift the mask so only two values are stored
jmp ConvertMaskedUpperVector
ConvertUpper1:
vpsrldq xmm2, xmm2, SINGLE_SIZE*3 ; Shift the mask so only one value is stored
jmp ConvertMaskedUpperVector
ConvertUpper3:
vpsrldq xmm2, xmm2, SINGLE_SIZE ; Shift the mask so only three values are stored
ConvertMaskedUpperVector:
vmaskmovps xmmword PTR [rdx], xmm2, xmm1 ; Store the masked data; the mask shifts above are in byte multiples

ExitRoutine:
ret

LEAF_END MlasCastF16ToF32KernelAvx, _TEXT

END
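
The even/odd convert instructions (vcvtneeph2ps / vcvtneoph2ps) return the even- and odd-indexed elements in separate registers, so the kernel re-interleaves them with unpack and a 128-bit permute using the LOW_SELECTOR/HIGH_SELECTOR immediates. The sketch below (illustration only, plain AVX intrinsics operating on already-converted floats) traces that fix-up and shows the original element order is restored:

```cpp
#include <immintrin.h>
#include <cstdio>

int main() {
    // Stand-ins for the outputs of vcvtneeph2ps / vcvtneoph2ps on a buffer
    // holding the values 0..15: even indexes in one register, odd in another.
    __m256 even = _mm256_setr_ps(0, 2, 4, 6, 8, 10, 12, 14);
    __m256 odd  = _mm256_setr_ps(1, 3, 5, 7, 9, 11, 13, 15);

    __m256 lo   = _mm256_unpacklo_ps(even, odd);          // vunpcklps: {0,1,2,3 | 8,9,10,11}
    __m256 hi   = _mm256_unpackhi_ps(even, odd);          // vunpckhps: {4,5,6,7 | 12,13,14,15}
    __m256 out0 = _mm256_permute2f128_ps(lo, hi, 0x20);   // LOW_SELECTOR  (00100000b): {0..7}
    __m256 out1 = _mm256_permute2f128_ps(lo, hi, 0x31);   // HIGH_SELECTOR (00110001b): {8..15}

    float result[16];
    _mm256_storeu_ps(result, out0);
    _mm256_storeu_ps(result + 8, out1);
    for (float v : result) std::printf("%g ", v);         // prints 0 1 2 ... 15 in order
    std::printf("\n");
    return 0;
}
```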
6 changes: 3 additions & 3 deletions onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm
@@ -42,7 +42,7 @@ MlasFp16MagicDenormal DD 4 DUP (38800000h)
; Source (rcx) - Supplies the address of the source buffer of half-precision
; floats.
;
; Destination (edx) - Supplies the address of the destination buffer of
; Destination (rdx) - Supplies the address of the destination buffer of
; single-precision floats.
;
; Count (r8) - Supplies the number of elements to convert.
@@ -53,7 +53,7 @@ MlasFp16MagicDenormal DD 4 DUP (38800000h)
;
;--

LEAF_ENTRY MlasConvertHalfToFloatBuffer, _TEXT
LEAF_ENTRY MlasCastF16ToF32KernelSse, _TEXT

test r8,r8
jz ExitRoutine
@@ -119,6 +119,6 @@ StoreLastElement:
ExitRoutine:
ret

LEAF_END MlasConvertHalfToFloatBuffer, _TEXT
LEAF_END MlasCastF16ToF32KernelSse, _TEXT

END
59 changes: 59 additions & 0 deletions onnxruntime/core/mlas/lib/cast.cpp
@@ -0,0 +1,59 @@
/*++
Copyright (c) Intel Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
cast.cpp
Abstract:
This module implements Half (F16) to Single (F32) precision casting.
--*/
#include "mlasi.h"

union fp32_bits {
uint32_t u;
float f;
};

void
MLASCALL
MlasConvertHalfToFloatBuffer(
const unsigned short* Source,
float* Destination,
size_t Count
)
{

if (GetMlasPlatform().CastF16ToF32Kernel == nullptr) {
// If there is no kernel use the reference implementation, adapted from mlas_float16.h.
constexpr fp32_bits magic = {113 << 23};
constexpr uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift

for (size_t i = 0; i < Count; ++i) {
fp32_bits o;
o.u = (Source[i] & 0x7fff) << 13; // exponent/mantissa bits
uint32_t exp = shifted_exp & o.u; // just the exponent
o.u += (127 - 15) << 23; // exponent adjust

// handle exponent special cases
if (exp == shifted_exp) { // Inf/NaN?
o.u += (128 - 16) << 23; // extra exp adjust
} else if (exp == 0) { // Zero/Denormal?
o.u += 1 << 23; // extra exp adjust
o.f -= magic.f; // renormalize
}

o.u |= (Source[i] & 0x8000) << 16; // sign bit
Destination[i] = o.f;
}

} else {
// If the kernel is available, use it to perform the conversion.
GetMlasPlatform().CastF16ToF32Kernel(Source, Destination, Count);
}
}
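
As a worked example of the reference path above (illustration only): for FP16 1.0, whose bit pattern is 0x3C00, the normal-number branch produces the FP32 bit pattern 0x3F800000. The sign bit is zero and neither the Inf/NaN nor the denormal branch fires for this input.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    uint32_t u = (0x3C00u & 0x7fffu) << 13;  // 0x07800000: exponent/mantissa moved into FP32 position
    u += (127u - 15u) << 23;                 // + 0x38000000: rebias exponent from 15 to 127 -> 0x3F800000
    float f;
    std::memcpy(&f, &u, sizeof(f));          // reinterpret the bits as a float
    std::printf("%f\n", f);                  // prints 1.000000
    return 0;
}
```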
14 changes: 14 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -610,6 +610,13 @@ void
size_t N
);

typedef
void(MLASCALL MLAS_CAST_F16_TO_F32_KERNEL)(
const unsigned short* Source,
float* Destination,
size_t Count
);

typedef
void
(MLASCALL MLAS_QLINEAR_BINARY_OP_S8_KERNEL)(
@@ -870,6 +877,11 @@ extern "C" {
MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL MlasReduceMinimumMaximumF32KernelAvx;
#endif

#if defined(MLAS_TARGET_AMD64)
MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelSse;
MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx;
#endif

}

//
@@ -1151,6 +1163,8 @@ struct MLAS_PLATFORM {
const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr};

const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr};

MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel;
};

inline
14 changes: 14 additions & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -244,6 +244,7 @@ Return Value:
this->ConvDepthwiseU8U8Kernel = MlasConvDepthwiseKernel<uint8_t, uint8_t>;
this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernel<int8_t, int8_t>;
this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernel<int8_t, uint8_t>;
this->CastF16ToF32Kernel = nullptr;

#if defined(MLAS_TARGET_AMD64_IX86)

@@ -283,6 +284,9 @@ Return Value:
this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel;
this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel;
this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel;
#ifndef __APPLE__
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse;
#endif // __APPLE__

this->NchwcBlockSize = 8;
this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT;
@@ -469,6 +473,16 @@ Return Value:
}

#ifndef __APPLE__
#if (defined(_MSC_VER) && (_MSC_VER >= 1933)) || (defined(__GNUC__) && (__GNUC__ >= 13))
//
// Check if the processor supports AVX NE CONVERT.
//
if ((Cpuid7_1[3] & (0b1 << 5)) != 0) {
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx;
}
#endif // (defined(_MSC_VER) && (_MSC_VER >= 1933)) || (defined(__GNUC__) && (__GNUC__ >= 13))


//
// Check if the processor supports AMX-TILE and AMX-INT8
// features.
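
For reference, AVX_NE_CONVERT support is reported in CPUID leaf 7, sub-leaf 1, EDX bit 5, which is the bit the platform.cpp change above tests in Cpuid7_1[3]. Below is a stand-alone sketch of the same check (illustration only, not the MLAS code; a production check would also confirm the baseline AVX and OS support that platform.cpp has already verified by this point):

```cpp
#include <cstdio>
#if defined(_MSC_VER)
#include <intrin.h>
#else
#include <cpuid.h>
#endif

static bool CpuHasAvxNeConvert() {
#if defined(_MSC_VER)
    int regs[4];                                  // EAX, EBX, ECX, EDX
    __cpuidex(regs, 7, 1);                        // leaf 7, sub-leaf 1
    return (regs[3] & (1 << 5)) != 0;             // EDX bit 5 = AVX_NE_CONVERT
#else
    unsigned int eax, ebx, ecx, edx;
    if (!__get_cpuid_count(7, 1, &eax, &ebx, &ecx, &edx)) return false;
    return (edx & (1u << 5)) != 0;                // EDX bit 5 = AVX_NE_CONVERT
#endif
}

int main() {
    std::printf("AVX_NE_CONVERT supported: %s\n", CpuHasAvxNeConvert() ? "yes" : "no");
    return 0;
}
```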