From 7db520c65da27984f200b4b84f101f8543d2813b Mon Sep 17 00:00:00 2001 From: nihuini Date: Mon, 28 Aug 2023 15:17:27 +0800 Subject: [PATCH] opt wip --- src/layer/x86/convolution_im2col_gemm_int8.h | 188 ++++++++++++++++++- 1 file changed, 183 insertions(+), 5 deletions(-) diff --git a/src/layer/x86/convolution_im2col_gemm_int8.h b/src/layer/x86/convolution_im2col_gemm_int8.h index 89bcd140f38..4c1912ed5d2 100644 --- a/src/layer/x86/convolution_im2col_gemm_int8.h +++ b/src/layer/x86/convolution_im2col_gemm_int8.h @@ -4027,6 +4027,12 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M { const signed char* pA = pAT; +#if __AVX2__ + __m256i _sum0; + __m256i _sum1; + __m256i _sum2; + __m256i _sum3; +#else __m128i _sum0; __m128i _sum1; __m128i _sum2; @@ -4035,9 +4041,16 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M __m128i _sum5; __m128i _sum6; __m128i _sum7; +#endif if (k == 0) { +#if __AVX2__ + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + _sum2 = _mm256_setzero_si256(); + _sum3 = _mm256_setzero_si256(); +#else _sum0 = _mm_setzero_si128(); _sum1 = _mm_setzero_si128(); _sum2 = _mm_setzero_si128(); @@ -4046,9 +4059,16 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M _sum5 = _mm_setzero_si128(); _sum6 = _mm_setzero_si128(); _sum7 = _mm_setzero_si128(); +#endif } else { +#if __AVX2__ + _sum0 = _mm256_load_si256((const __m256i*)outptr); + _sum1 = _mm256_load_si256((const __m256i*)(outptr + 8)); + _sum2 = _mm256_load_si256((const __m256i*)(outptr + 16)); + _sum3 = _mm256_load_si256((const __m256i*)(outptr + 24)); +#else _sum0 = _mm_load_si128((const __m128i*)outptr); _sum1 = _mm_load_si128((const __m128i*)(outptr + 4)); _sum2 = _mm_load_si128((const __m128i*)(outptr + 8)); @@ -4057,14 +4077,39 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M _sum5 = _mm_load_si128((const __m128i*)(outptr + 20)); _sum6 = _mm_load_si128((const __m128i*)(outptr + 24)); _sum7 = _mm_load_si128((const __m128i*)(outptr + 28)); +#endif } int kk = 0; for (; kk + 1 < max_kk; kk += 2) { - __m128i _pA = _mm_loadl_epi64((const __m128i*)pA); + __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); __m128i _pB = _mm_loadu_si128((const __m128i*)pB); +#if __AVX2__ + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + // 0123 0123 + // 2301 2301 + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + + // 0123 4567 + // 1230 5674 + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + +#if __AVXVNNI__ || __AVX512VNNI__ + _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm256_dpwssd_epi32(_sum3, _pA1, _pB1); +#else + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); + _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); + _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA1, _pB0)); + _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA1, _pB1)); +#endif +#else // __AVX2__ #if __SSE4_1__ _pA = _mm_cvtepi8_epi16(_pA); #else @@ -4107,6 +4152,7 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M _sum6 = _mm_add_epi32(_sum6, _mm_madd_epi16(_pA1, _pB2)); _sum7 = _mm_add_epi32(_sum7, _mm_madd_epi16(_pA1, _pB3)); #endif +#endif // __AVX2__ pA += 8; pB += 16; @@ -4124,6 +4170,27 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M _pB = _mm_unpacklo_epi8(_pB, _mm_cmpgt_epi8(_mm_setzero_si128(), _pB)); #endif +#if __AVX2__ + // 01230123 + // 23012301 + __m128i _pA0 = _pA; + __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(2, 3, 0, 1)); + + // 01234567 + // 12305674 + __m128i _pB0 = _pB; + __m128i _pB1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + + __m256i _s0 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB0)); + __m256i _s1 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB1)); + __m256i _s2 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB0)); + __m256i _s3 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB1)); + + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + _sum2 = _mm256_add_epi32(_sum2, _s2); + _sum3 = _mm256_add_epi32(_sum3, _s3); +#else // __AVX2__ #if __XOP__ // 00112233 // 22330011 @@ -4147,13 +4214,13 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M _sum6 = _mm_maccd_epi16(_pA1, _pB2, _sum6); _sum7 = _mm_maccd_epi16(_pA1, _pB3, _sum7); #else - // 0123 0123 - // 2301 2301 + // 01230123 + // 23012301 __m128i _pA0 = _pA; __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(2, 3, 0, 1)); - // 0123 4567 - // 1230 5674 + // 01234567 + // 12305674 __m128i _pB01 = _pB; __m128i _pB23 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); @@ -4183,6 +4250,7 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M _sum6 = _mm_add_epi32(_sum6, _s6); _sum7 = _mm_add_epi32(_sum7, _s7); #endif +#endif // __AVX2__ pA += 4; pB += 8; @@ -4192,6 +4260,60 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M { if (out_elempack == 4) { +#if __AVX2__ + // 00 11 22 33 04 15 26 37 + // 01 12 23 30 05 16 27 34 + // 20 31 02 13 24 35 06 17 + // 21 32 03 10 25 36 07 14 + + __m256i _tmp0 = _sum0; + __m256i _tmp1 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + __m256i _tmp2 = _sum2; + __m256i _tmp3 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + + // 00 11 22 33 04 15 26 37 + // 10 21 32 03 14 25 36 07 + // 20 31 02 13 24 35 06 17 + // 30 01 12 23 34 05 16 27 + + _sum0 = _mm256_unpacklo_epi32(_tmp0, _tmp1); + _sum1 = _mm256_unpackhi_epi32(_tmp0, _tmp1); + _sum2 = _mm256_unpacklo_epi32(_tmp2, _tmp3); + _sum3 = _mm256_unpackhi_epi32(_tmp2, _tmp3); + + // 00 10 11 21 04 14 15 25 + // 22 32 33 03 26 36 37 07 + // 20 30 31 01 24 34 35 05 + // 02 12 13 23 06 16 17 27 + + _tmp0 = _mm256_unpacklo_epi64(_sum0, _sum2); + _tmp1 = _mm256_unpackhi_epi64(_sum0, _sum2); + _tmp2 = _mm256_unpacklo_epi64(_sum3, _sum1); + _tmp3 = _mm256_unpackhi_epi64(_sum3, _sum1); + + // 00 10 20 30 04 14 24 34 + // 11 21 31 01 15 25 35 05 + // 02 12 22 32 06 16 26 36 + // 13 23 33 03 17 27 37 07 + + _tmp1 = _mm256_shuffle_epi32(_tmp1, _MM_SHUFFLE(2, 1, 0, 3)); + _tmp3 = _mm256_shuffle_epi32(_tmp3, _MM_SHUFFLE(2, 1, 0, 3)); + + // 00 10 20 30 04 14 24 34 + // 01 11 21 31 05 15 25 35 + // 02 12 22 32 06 16 26 36 + // 03 13 23 33 07 17 27 37 + + _sum0 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0)); + _sum1 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 2, 0, 0)); + _sum2 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); + _sum3 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 3, 0, 1)); + + _mm256_storeu_si256((__m256i*)outptr0, _sum0); + _mm256_storeu_si256((__m256i*)(outptr0 + 8), _sum1); + _mm256_storeu_si256((__m256i*)(outptr0 + 16), _sum2); + _mm256_storeu_si256((__m256i*)(outptr0 + 24), _sum3); +#else // 00 11 22 33 // 04 15 26 37 // 01 12 23 30 @@ -4251,10 +4373,58 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M _mm_store_si128((__m128i*)(outptr0 + 20), _sum5); _mm_store_si128((__m128i*)(outptr0 + 24), _sum6); _mm_store_si128((__m128i*)(outptr0 + 28), _sum7); +#endif // __AVX2__ outptr0 += 32; } if (out_elempack == 1) { +#if __AVX2__ + // 00 11 22 33 04 15 26 37 + // 01 12 23 30 05 16 27 34 + // 20 31 02 13 24 35 06 17 + // 21 32 03 10 25 36 07 14 + + _sum2 = _mm256_shuffle_epi32(_sum2, _MM_SHUFFLE(1, 0, 3, 2)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(1, 0, 3, 2)); + + // 00 11 22 33 04 15 26 37 + // 01 12 23 30 05 16 27 34 + // 02 13 20 31 06 17 24 35 + // 03 10 21 32 07 14 25 36 + + __m256i _tmp0 = _mm256_unpacklo_epi32(_sum0, _sum1); + __m256i _tmp1 = _mm256_unpackhi_epi32(_sum0, _sum1); + __m256i _tmp2 = _mm256_unpacklo_epi32(_sum2, _sum3); + __m256i _tmp3 = _mm256_unpackhi_epi32(_sum2, _sum3); + + // 00 01 11 12 04 05 15 16 + // 22 23 33 30 26 27 37 34 + // 02 03 13 10 06 07 17 14 + // 20 21 31 32 24 25 35 36 + + _sum0 = _mm256_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm256_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm256_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm256_unpackhi_epi64(_tmp3, _tmp1); + + // 00 01 02 03 04 05 06 07 + // 11 12 13 10 15 16 17 14 + // 20 21 22 23 24 25 26 27 + // 31 32 33 30 35 36 37 34 + + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + + // 00 01 02 03 04 05 06 07 + // 10 11 12 13 14 15 16 17 + // 20 21 22 23 24 25 26 27 + // 30 31 32 33 34 35 36 37 + + _mm256_storeu_si256((__m256i*)outptr0, _sum0); + _mm256_storeu_si256((__m256i*)(outptr0 + out_hstep), _sum1); + _mm256_storeu_si256((__m256i*)(outptr0 + out_hstep * 2), _sum2); + _mm256_storeu_si256((__m256i*)(outptr0 + out_hstep * 3), _sum3); +#else // 00 11 22 33 // 04 15 26 37 // 01 12 23 30 @@ -4314,11 +4484,18 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M _mm_storeu_si128((__m128i*)(outptr0 + out_hstep * 2 + 4), _sum6); _mm_storeu_si128((__m128i*)(outptr0 + out_hstep * 3), _sum3); _mm_storeu_si128((__m128i*)(outptr0 + out_hstep * 3 + 4), _sum7); +#endif // __AVX2__ outptr0 += 8; } } else { +#if __AVX2__ + _mm256_store_si256((__m256i*)outptr, _sum0); + _mm256_store_si256((__m256i*)(outptr + 8), _sum1); + _mm256_store_si256((__m256i*)(outptr + 16), _sum2); + _mm256_store_si256((__m256i*)(outptr + 24), _sum3); +#else _mm_store_si128((__m128i*)outptr, _sum0); _mm_store_si128((__m128i*)(outptr + 4), _sum1); _mm_store_si128((__m128i*)(outptr + 8), _sum2); @@ -4327,6 +4504,7 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M _mm_store_si128((__m128i*)(outptr + 20), _sum5); _mm_store_si128((__m128i*)(outptr + 24), _sum6); _mm_store_si128((__m128i*)(outptr + 28), _sum7); +#endif } outptr += 32;