comp avxvnni
nihui committed Nov 29, 2024
1 parent f32b4b4 commit 15ebb61
Showing 3 changed files with 99 additions and 79 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -504,7 +504,7 @@ else()
check_cxx_compiler_flag("/arch:AVX512" NCNN_COMPILER_SUPPORT_X86_AVX512)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)
@@ -541,7 +541,7 @@ else()
check_cxx_compiler_flag("/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl" NCNN_COMPILER_SUPPORT_X86_AVX512)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint8")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)
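
The escaped one-liner handed to check_cxx_source_compiles is hard to read inside the diff. Expanded, the AVX_VNNI probe is roughly the translation unit below; it compiles only when the toolchain provides the AVX-VNNI spelling _mm256_dpwssd_avx_epi32 (the un-suffixed _mm256_dpwssd_epi32 is the AVX512-VNNI form), which is exactly what this change switches the detection to. This is a readable sketch of the probe, not a file in the repository:

#include <immintrin.h>

// Compile-only probe: nothing here runs, so the uninitialized vectors
// are acceptable for a configure-time feature test.
int main()
{
    __m256i _s, _a, _b;
    _s = _mm256_dpwssd_avx_epi32(_s, _a, _b); // VEX-encoded AVX-VNNI intrinsic
    return 0;
}
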
78 changes: 39 additions & 39 deletions src/layer/x86/gemm_int8.h
@@ -18490,14 +18490,14 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
__m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
__m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2));
__m256i _pB3 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 1, 0, 3));
_sum0 = _mm256_dpbusd_epi32(_sum0, _pB0, _pA0);
_sum1 = _mm256_dpbusd_epi32(_sum1, _pB1, _pA0);
_sum2 = _mm256_dpbusd_epi32(_sum2, _pB0, _pA1);
_sum3 = _mm256_dpbusd_epi32(_sum3, _pB1, _pA1);
_sum4 = _mm256_dpbusd_epi32(_sum4, _pB2, _pA0);
_sum5 = _mm256_dpbusd_epi32(_sum5, _pB3, _pA0);
_sum6 = _mm256_dpbusd_epi32(_sum6, _pB2, _pA1);
_sum7 = _mm256_dpbusd_epi32(_sum7, _pB3, _pA1);
_sum0 = _mm256_comp_dpbusd_epi32(_sum0, _pB0, _pA0);
_sum1 = _mm256_comp_dpbusd_epi32(_sum1, _pB1, _pA0);
_sum2 = _mm256_comp_dpbusd_epi32(_sum2, _pB0, _pA1);
_sum3 = _mm256_comp_dpbusd_epi32(_sum3, _pB1, _pA1);
_sum4 = _mm256_comp_dpbusd_epi32(_sum4, _pB2, _pA0);
_sum5 = _mm256_comp_dpbusd_epi32(_sum5, _pB3, _pA0);
_sum6 = _mm256_comp_dpbusd_epi32(_sum6, _pB2, _pA1);
_sum7 = _mm256_comp_dpbusd_epi32(_sum7, _pB3, _pA1);
pA += 32;
pB += 32;
}
@@ -18646,10 +18646,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
__m256i _pB0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pB), _pB, 1);
__m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
__m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
_sum0 = _mm256_dpbusd_epi32(_sum0, _pB0, _pA0);
_sum1 = _mm256_dpbusd_epi32(_sum1, _pB1, _pA0);
_sum2 = _mm256_dpbusd_epi32(_sum2, _pB0, _pA1);
_sum3 = _mm256_dpbusd_epi32(_sum3, _pB1, _pA1);
_sum0 = _mm256_comp_dpbusd_epi32(_sum0, _pB0, _pA0);
_sum1 = _mm256_comp_dpbusd_epi32(_sum1, _pB1, _pA0);
_sum2 = _mm256_comp_dpbusd_epi32(_sum2, _pB0, _pA1);
_sum3 = _mm256_comp_dpbusd_epi32(_sum3, _pB1, _pA1);
pA += 32;
pB += 16;
}
@@ -18752,8 +18752,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
__m256i _pA = _mm256_loadu_si256((const __m256i*)pA);
__m256i _pB0 = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pB));
__m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 1, 0, 1));
_sum0 = _mm256_dpbusd_epi32(_sum0, _pB0, _pA);
_sum1 = _mm256_dpbusd_epi32(_sum1, _pB1, _pA);
_sum0 = _mm256_comp_dpbusd_epi32(_sum0, _pB0, _pA);
_sum1 = _mm256_comp_dpbusd_epi32(_sum1, _pB1, _pA);
pA += 32;
pB += 8;
}
@@ -18836,7 +18836,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
{
__m256i _pA = _mm256_loadu_si256((const __m256i*)pA);
__m256i _pB = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pB));
_sum0 = _mm256_dpbusd_epi32(_sum0, _pB, _pA);
_sum0 = _mm256_comp_dpbusd_epi32(_sum0, _pB, _pA);
pA += 32;
pB += 4;
}
@@ -19057,14 +19057,14 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
__m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
__m128i _pB2 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
__m128i _pB3 = _mm_shuffle_epi32(_pB1, _MM_SHUFFLE(0, 3, 2, 1));
_sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA0);
_sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA0);
_sum2 = _mm_dpbusd_epi32(_sum2, _pB0, _pA1);
_sum3 = _mm_dpbusd_epi32(_sum3, _pB1, _pA1);
_sum4 = _mm_dpbusd_epi32(_sum4, _pB2, _pA0);
_sum5 = _mm_dpbusd_epi32(_sum5, _pB3, _pA0);
_sum6 = _mm_dpbusd_epi32(_sum6, _pB2, _pA1);
_sum7 = _mm_dpbusd_epi32(_sum7, _pB3, _pA1);
_sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB0, _pA0);
_sum1 = _mm_comp_dpbusd_epi32(_sum1, _pB1, _pA0);
_sum2 = _mm_comp_dpbusd_epi32(_sum2, _pB0, _pA1);
_sum3 = _mm_comp_dpbusd_epi32(_sum3, _pB1, _pA1);
_sum4 = _mm_comp_dpbusd_epi32(_sum4, _pB2, _pA0);
_sum5 = _mm_comp_dpbusd_epi32(_sum5, _pB3, _pA0);
_sum6 = _mm_comp_dpbusd_epi32(_sum6, _pB2, _pA1);
_sum7 = _mm_comp_dpbusd_epi32(_sum7, _pB3, _pA1);
pA += 16;
pB += 32;
}
@@ -19255,10 +19255,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
__m128i _pB0 = _mm_loadu_si128((const __m128i*)pB);
__m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
__m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
_sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA0);
_sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA0);
_sum2 = _mm_dpbusd_epi32(_sum2, _pB0, _pA1);
_sum3 = _mm_dpbusd_epi32(_sum3, _pB1, _pA1);
_sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB0, _pA0);
_sum1 = _mm_comp_dpbusd_epi32(_sum1, _pB1, _pA0);
_sum2 = _mm_comp_dpbusd_epi32(_sum2, _pB0, _pA1);
_sum3 = _mm_comp_dpbusd_epi32(_sum3, _pB1, _pA1);
pA += 16;
pB += 16;
}
@@ -19399,8 +19399,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
__m128i _pA = _mm_loadu_si128((const __m128i*)pA);
__m128i _pB0 = _mm_castpd_si128(_mm_load1_pd((const double*)pB));
__m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 3, 0, 1));
_sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA);
_sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA);
_sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB0, _pA);
_sum1 = _mm_comp_dpbusd_epi32(_sum1, _pB1, _pA);
pA += 16;
pB += 8;
}
@@ -19511,7 +19511,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
{
__m128i _pA = _mm_loadu_si128((const __m128i*)pA);
__m128i _pB = _mm_castps_si128(_mm_load1_ps((const float*)pB));
_sum0 = _mm_dpbusd_epi32(_sum0, _pB, _pA);
_sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB, _pA);
pA += 16;
pB += 4;
}
@@ -19711,10 +19711,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
__m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(2, 3, 0, 1));
__m128i _pB0 = _mm_loadu_si128((const __m128i*)pB);
__m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 16));
_sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA0);
_sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA0);
_sum2 = _mm_dpbusd_epi32(_sum2, _pB0, _pA1);
_sum3 = _mm_dpbusd_epi32(_sum3, _pB1, _pA1);
_sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB0, _pA0);
_sum1 = _mm_comp_dpbusd_epi32(_sum1, _pB1, _pA0);
_sum2 = _mm_comp_dpbusd_epi32(_sum2, _pB0, _pA1);
_sum3 = _mm_comp_dpbusd_epi32(_sum3, _pB1, _pA1);
pA += 8;
pB += 32;
}
@@ -19837,8 +19837,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
__m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA));
__m128i _pB0 = _mm_loadu_si128((const __m128i*)pB);
__m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
_sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA);
_sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA);
_sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB0, _pA);
_sum1 = _mm_comp_dpbusd_epi32(_sum1, _pB1, _pA);
pA += 8;
pB += 16;
}
@@ -20177,8 +20177,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
__m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA));
__m128i _pB0 = _mm_loadu_si128((const __m128i*)pB);
__m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 16));
_sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA);
_sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA);
_sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB0, _pA);
_sum1 = _mm_comp_dpbusd_epi32(_sum1, _pB1, _pA);
pA += 4;
pB += 32;
}
@@ -20265,7 +20265,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
{
__m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA));
__m128i _pB = _mm_loadu_si128((const __m128i*)pB);
_sum0 = _mm_dpbusd_epi32(_sum0, _pB, _pA);
_sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB, _pA);
pA += 4;
pB += 16;
}
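
Every hunk in this file is the same mechanical substitution: direct _mm256_dpbusd_epi32 / _mm_dpbusd_epi32 calls become the _comp_-prefixed wrappers added to x86_usability.h below, so the GEMM tiles no longer hard-code the AVX512-VNNI intrinsic names. Each 32-bit accumulator lane of dpbusd adds four unsigned-byte by signed-byte products to the running sum. As a reference, here is a scalar sketch (illustration only, not repository code) of one lane of _mm_comp_dpbusd_epi32(_sum, _pB, _pA) as called above, where _pB occupies the unsigned operand slot and _pA the signed one:

#include <cstdint>

// One 32-bit lane of VPDPBUSD: dst = src + sum of four u8 * s8 products.
// (The non-saturating dpbusd form is what the kernels above use.)
static int32_t dpbusd_lane(int32_t src, const uint8_t b[4], const int8_t a[4])
{
    int32_t acc = src;
    for (int j = 0; j < 4; j++)
        acc += (int32_t)b[j] * (int32_t)a[j];
    return acc;
}
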
96 changes: 58 additions & 38 deletions src/layer/x86/x86_usability.h
@@ -267,83 +267,83 @@ static NCNN_FORCEINLINE __m128i float2bfloat_sse(const __m128& v0, const __m128&
return _v;
}

#ifndef __FMA__
static NCNN_FORCEINLINE __m128 _mm_comp_fmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c)
{
return _mm_add_ps(_mm_mul_ps(_a, _b), _c);
}
static NCNN_FORCEINLINE __m128 _mm_comp_fnmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c)
{
return _mm_sub_ps(_c, _mm_mul_ps(_a, _b));
}
static NCNN_FORCEINLINE __m128 _mm_comp_fmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c)
{
return _mm_sub_ps(_mm_mul_ps(_a, _b), _c);
}
static NCNN_FORCEINLINE __m128 _mm_comp_fnmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c)
{
return _mm_sub_ps(_c, _mm_mul_ps(_mm_mul_ps(_a, _b), _mm_set1_ps(-1)));
}
#else
static NCNN_FORCEINLINE __m128 _mm_comp_fmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c)
{
#if __FMA__
return _mm_fmadd_ps(_a, _b, _c);
#else
return _mm_add_ps(_mm_mul_ps(_a, _b), _c);
#endif
}

static NCNN_FORCEINLINE __m128 _mm_comp_fnmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c)
{
// return -a * b + c
#if __FMA__
return _mm_fnmadd_ps(_a, _b, _c);
#else
return _mm_sub_ps(_c, _mm_mul_ps(_a, _b));
#endif
}

static NCNN_FORCEINLINE __m128 _mm_comp_fmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c)
{
#if __FMA__
return _mm_fmsub_ps(_a, _b, _c);
#else
return _mm_sub_ps(_mm_mul_ps(_a, _b), _c);
#endif
}

static NCNN_FORCEINLINE __m128 _mm_comp_fnmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c)
{
#if __FMA__
return _mm_fnmsub_ps(_a, _b, _c);
#else
return _mm_sub_ps(_c, _mm_mul_ps(_mm_mul_ps(_a, _b), _mm_set1_ps(-1)));
#endif
}
#endif // !__FMA__

#if __AVX__
#ifndef __FMA__
static NCNN_FORCEINLINE __m256 _mm256_comp_fmadd_ps(const __m256& _a, const __m256& _b, const __m256& _c)
{
return _mm256_add_ps(_mm256_mul_ps(_a, _b), _c);
}
static NCNN_FORCEINLINE __m256 _mm256_comp_fnmadd_ps(const __m256& _a, const __m256& _b, const __m256& _c)
{
return _mm256_sub_ps(_c, _mm256_mul_ps(_a, _b));
}
static NCNN_FORCEINLINE __m256 _mm256_comp_fmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c)
{
return _mm256_sub_ps(_mm256_mul_ps(_a, _b), _c);
}
static NCNN_FORCEINLINE __m256 _mm256_comp_fnmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c)
{
return _mm256_sub_ps(_c, _mm256_mul_ps(_mm256_mul_ps(_a, _b), _mm256_set1_ps(-1)));
}
#else
static NCNN_FORCEINLINE __m256 _mm256_comp_fmadd_ps(const __m256& _a, const __m256& _b, const __m256& _c)
{
// return a * b + c
#if __FMA__
return _mm256_fmadd_ps(_a, _b, _c);
#else
return _mm256_add_ps(_mm256_mul_ps(_a, _b), _c);
#endif
}

static NCNN_FORCEINLINE __m256 _mm256_comp_fnmadd_ps(const __m256& _a, const __m256& _b, const __m256& _c)
{
// return -a * b + c
#if __FMA__
return _mm256_fnmadd_ps(_a, _b, _c);
#else
return _mm256_sub_ps(_c, _mm256_mul_ps(_a, _b));
#endif
}

static NCNN_FORCEINLINE __m256 _mm256_comp_fmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c)
{
// return a * b - c
#if __FMA__
return _mm256_fmsub_ps(_a, _b, _c);
#else
return _mm256_sub_ps(_mm256_mul_ps(_a, _b), _c);
#endif
}

static NCNN_FORCEINLINE __m256 _mm256_comp_fnmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c)
{
// return -(a * b) - c
#if __FMA__
return _mm256_fnmsub_ps(_a, _b, _c);
#else
return _mm256_sub_ps(_c, _mm256_mul_ps(_mm256_mul_ps(_a, _b), _mm256_set1_ps(-1)));
#endif
}
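
Folding the two preprocessor branches into one definition per helper keeps a single call-site spelling, with the FMA/non-FMA choice made inside each body. A minimal usage sketch of the helpers defined above (illustrative only, assuming <immintrin.h> and this header are in scope):

// y[0..7] += alpha * x[0..7] via the _comp_ wrapper, which lowers to a
// fused multiply-add when __FMA__ is defined and to mul + add otherwise.
static void axpy8(float* y, const float* x, float alpha)
{
    __m256 _alpha = _mm256_set1_ps(alpha);
    __m256 _x = _mm256_loadu_ps(x);
    __m256 _y = _mm256_loadu_ps(y);
    _y = _mm256_comp_fmadd_ps(_alpha, _x, _y);
    _mm256_storeu_ps(y, _y);
}
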

static NCNN_FORCEINLINE __m256 _mm256_fmadd_1_ps(const __m256& a, const __m256& b, float c)
{
@@ -841,6 +841,26 @@ static NCNN_FORCEINLINE __m256i float2bfloat_avx(const __m256& v0, const __m256&
}

#if __AVX2__
#if __AVX512VNNI__ || __AVXVNNI__
static NCNN_FORCEINLINE __m128i _mm_comp_dpbusd_epi32(__m128i src, __m128i a, __m128i b)
{
#if __AVX512VNNI__
return _mm_dpbusd_epi32(src, a, b);
#else
return _mm_dpbusd_avx_epi32(src, a, b);
#endif
}

static NCNN_FORCEINLINE __m256i _mm256_comp_dpbusd_epi32(__m256i src, __m256i a, __m256i b)
{
#if __AVX512VNNI__
return _mm256_dpbusd_epi32(src, a, b);
#else
return _mm256_dpbusd_avx_epi32(src, a, b);
#endif
}
#endif // __AVX512VNNI__ || __AVXVNNI__
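
These new wrappers exist because the 256-bit VNNI dot-product intrinsic has two spellings: the EVEX/AVX512-VNNI form _mm256_dpbusd_epi32 and the VEX/AVX-VNNI form, which current toolchains expose under the _avx_-suffixed name used here. Dispatching on __AVX512VNNI__ versus __AVXVNNI__ keeps the int8 GEMM kernels above compiler-agnostic. For reference, when neither extension is available the same accumulation can be approximated with plain AVX2; a sketch (not part of this patch, and not bit-exact in the worst case because _mm256_maddubs_epi16 saturates its 16-bit pair sums):

// AVX2-only stand-in for _mm256_comp_dpbusd_epi32 (sketch): multiply the
// unsigned bytes of a by the signed bytes of b, pair-sum to 16 bits
// (with saturation), widen the pairs to 32 bits, then add into src.
static NCNN_FORCEINLINE __m256i dpbusd_avx2_sketch(__m256i src, __m256i a, __m256i b)
{
    __m256i _p16 = _mm256_maddubs_epi16(a, b);
    __m256i _p32 = _mm256_madd_epi16(_p16, _mm256_set1_epi16(1));
    return _mm256_add_epi32(src, _p32);
}
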

static NCNN_FORCEINLINE void transpose8x2_epi32(__m256i& _r0, __m256i& _r1)
{
__m256i _tmp0 = _mm256_unpacklo_epi32(_r0, _r1);