Skip to content

Commit

Permalink
opt wip
Browse files: browse the repository at this point in the history
  • Loading branch information
nihui committed Aug 28, 2023
1 parent 4d28d61 commit 7db520c
Showing 1 changed file with 183 additions and 5 deletions.
188 changes: 183 additions & 5 deletions src/layer/x86/convolution_im2col_gemm_int8.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -4027,6 +4027,12 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M
{
const signed char* pA = pAT;

#if __AVX2__
__m256i _sum0;
__m256i _sum1;
__m256i _sum2;
__m256i _sum3;
#else
__m128i _sum0;
__m128i _sum1;
__m128i _sum2;
Expand All @@ -4035,9 +4041,16 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M
__m128i _sum5;
__m128i _sum6;
__m128i _sum7;
#endif

if (k == 0)
{
#if __AVX2__
_sum0 = _mm256_setzero_si256();
_sum1 = _mm256_setzero_si256();
_sum2 = _mm256_setzero_si256();
_sum3 = _mm256_setzero_si256();
#else
_sum0 = _mm_setzero_si128();
_sum1 = _mm_setzero_si128();
_sum2 = _mm_setzero_si128();
Expand All @@ -4046,9 +4059,16 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M
_sum5 = _mm_setzero_si128();
_sum6 = _mm_setzero_si128();
_sum7 = _mm_setzero_si128();
#endif
}
else
{
#if __AVX2__
_sum0 = _mm256_load_si256((const __m256i*)outptr);
_sum1 = _mm256_load_si256((const __m256i*)(outptr + 8));
_sum2 = _mm256_load_si256((const __m256i*)(outptr + 16));
_sum3 = _mm256_load_si256((const __m256i*)(outptr + 24));
#else
_sum0 = _mm_load_si128((const __m128i*)outptr);
_sum1 = _mm_load_si128((const __m128i*)(outptr + 4));
_sum2 = _mm_load_si128((const __m128i*)(outptr + 8));
Expand All @@ -4057,14 +4077,39 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M
_sum5 = _mm_load_si128((const __m128i*)(outptr + 20));
_sum6 = _mm_load_si128((const __m128i*)(outptr + 24));
_sum7 = _mm_load_si128((const __m128i*)(outptr + 28));
#endif
}

int kk = 0;
for (; kk + 1 < max_kk; kk += 2)
{
__m128i _pA = _mm_loadl_epi64((const __m128i*)pA);
__m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA));
__m128i _pB = _mm_loadu_si128((const __m128i*)pB);

#if __AVX2__
__m256i _pA0 = _mm256_cvtepi8_epi16(_pA);
__m256i _pB0 = _mm256_cvtepi8_epi16(_pB);

// 0123 0123
// 2301 2301
__m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2));

// 0123 4567
// 1230 5674
__m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));

#if __AVXVNNI__ || __AVX512VNNI__
_sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0);
_sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1);
_sum2 = _mm256_dpwssd_epi32(_sum2, _pA1, _pB0);
_sum3 = _mm256_dpwssd_epi32(_sum3, _pA1, _pB1);
#else
_sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0));
_sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1));
_sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA1, _pB0));
_sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA1, _pB1));
#endif
#else // __AVX2__
#if __SSE4_1__
_pA = _mm_cvtepi8_epi16(_pA);
#else
Expand Down Expand Up @@ -4107,6 +4152,7 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M
_sum6 = _mm_add_epi32(_sum6, _mm_madd_epi16(_pA1, _pB2));
_sum7 = _mm_add_epi32(_sum7, _mm_madd_epi16(_pA1, _pB3));
#endif
#endif // __AVX2__

pA += 8;
pB += 16;
Expand All @@ -4124,6 +4170,27 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M
_pB = _mm_unpacklo_epi8(_pB, _mm_cmpgt_epi8(_mm_setzero_si128(), _pB));
#endif

#if __AVX2__
// 01230123
// 23012301
__m128i _pA0 = _pA;
__m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(2, 3, 0, 1));

// 01234567
// 12305674
__m128i _pB0 = _pB;
__m128i _pB1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1));

__m256i _s0 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB0));
__m256i _s1 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB1));
__m256i _s2 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB0));
__m256i _s3 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB1));

_sum0 = _mm256_add_epi32(_sum0, _s0);
_sum1 = _mm256_add_epi32(_sum1, _s1);
_sum2 = _mm256_add_epi32(_sum2, _s2);
_sum3 = _mm256_add_epi32(_sum3, _s3);
#else // __AVX2__
#if __XOP__
// 00112233
// 22330011
Expand All @@ -4147,13 +4214,13 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M
_sum6 = _mm_maccd_epi16(_pA1, _pB2, _sum6);
_sum7 = _mm_maccd_epi16(_pA1, _pB3, _sum7);
#else
// 0123 0123
// 2301 2301
// 01230123
// 23012301
__m128i _pA0 = _pA;
__m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(2, 3, 0, 1));

// 0123 4567
// 1230 5674
// 01234567
// 12305674
__m128i _pB01 = _pB;
__m128i _pB23 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1));

Expand Down Expand Up @@ -4183,6 +4250,7 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M
_sum6 = _mm_add_epi32(_sum6, _s6);
_sum7 = _mm_add_epi32(_sum7, _s7);
#endif
#endif // __AVX2__

pA += 4;
pB += 8;
Expand All @@ -4192,6 +4260,60 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M
{
if (out_elempack == 4)
{
#if __AVX2__
// 00 11 22 33 04 15 26 37
// 01 12 23 30 05 16 27 34
// 20 31 02 13 24 35 06 17
// 21 32 03 10 25 36 07 14

__m256i _tmp0 = _sum0;
__m256i _tmp1 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3));
__m256i _tmp2 = _sum2;
__m256i _tmp3 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));

// 00 11 22 33 04 15 26 37
// 10 21 32 03 14 25 36 07
// 20 31 02 13 24 35 06 17
// 30 01 12 23 34 05 16 27

_sum0 = _mm256_unpacklo_epi32(_tmp0, _tmp1);
_sum1 = _mm256_unpackhi_epi32(_tmp0, _tmp1);
_sum2 = _mm256_unpacklo_epi32(_tmp2, _tmp3);
_sum3 = _mm256_unpackhi_epi32(_tmp2, _tmp3);

// 00 10 11 21 04 14 15 25
// 22 32 33 03 26 36 37 07
// 20 30 31 01 24 34 35 05
// 02 12 13 23 06 16 17 27

_tmp0 = _mm256_unpacklo_epi64(_sum0, _sum2);
_tmp1 = _mm256_unpackhi_epi64(_sum0, _sum2);
_tmp2 = _mm256_unpacklo_epi64(_sum3, _sum1);
_tmp3 = _mm256_unpackhi_epi64(_sum3, _sum1);

// 00 10 20 30 04 14 24 34
// 11 21 31 01 15 25 35 05
// 02 12 22 32 06 16 26 36
// 13 23 33 03 17 27 37 07

_tmp1 = _mm256_shuffle_epi32(_tmp1, _MM_SHUFFLE(2, 1, 0, 3));
_tmp3 = _mm256_shuffle_epi32(_tmp3, _MM_SHUFFLE(2, 1, 0, 3));

// 00 10 20 30 04 14 24 34
// 01 11 21 31 05 15 25 35
// 02 12 22 32 06 16 26 36
// 03 13 23 33 07 17 27 37

_sum0 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0));
_sum1 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 2, 0, 0));
_sum2 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1));
_sum3 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 3, 0, 1));

_mm256_storeu_si256((__m256i*)outptr0, _sum0);
_mm256_storeu_si256((__m256i*)(outptr0 + 8), _sum1);
_mm256_storeu_si256((__m256i*)(outptr0 + 16), _sum2);
_mm256_storeu_si256((__m256i*)(outptr0 + 24), _sum3);
#else
// 00 11 22 33
// 04 15 26 37
// 01 12 23 30
Expand Down Expand Up @@ -4251,10 +4373,58 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M
_mm_store_si128((__m128i*)(outptr0 + 20), _sum5);
_mm_store_si128((__m128i*)(outptr0 + 24), _sum6);
_mm_store_si128((__m128i*)(outptr0 + 28), _sum7);
#endif // __AVX2__
outptr0 += 32;
}
if (out_elempack == 1)
{
#if __AVX2__
// 00 11 22 33 04 15 26 37
// 01 12 23 30 05 16 27 34
// 20 31 02 13 24 35 06 17
// 21 32 03 10 25 36 07 14

_sum2 = _mm256_shuffle_epi32(_sum2, _MM_SHUFFLE(1, 0, 3, 2));
_sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(1, 0, 3, 2));

// 00 11 22 33 04 15 26 37
// 01 12 23 30 05 16 27 34
// 02 13 20 31 06 17 24 35
// 03 10 21 32 07 14 25 36

__m256i _tmp0 = _mm256_unpacklo_epi32(_sum0, _sum1);
__m256i _tmp1 = _mm256_unpackhi_epi32(_sum0, _sum1);
__m256i _tmp2 = _mm256_unpacklo_epi32(_sum2, _sum3);
__m256i _tmp3 = _mm256_unpackhi_epi32(_sum2, _sum3);

// 00 01 11 12 04 05 15 16
// 22 23 33 30 26 27 37 34
// 02 03 13 10 06 07 17 14
// 20 21 31 32 24 25 35 36

_sum0 = _mm256_unpacklo_epi64(_tmp0, _tmp2);
_sum1 = _mm256_unpackhi_epi64(_tmp0, _tmp2);
_sum2 = _mm256_unpacklo_epi64(_tmp3, _tmp1);
_sum3 = _mm256_unpackhi_epi64(_tmp3, _tmp1);

// 00 01 02 03 04 05 06 07
// 11 12 13 10 15 16 17 14
// 20 21 22 23 24 25 26 27
// 31 32 33 30 35 36 37 34

_sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));
_sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3));

// 00 01 02 03 04 05 06 07
// 10 11 12 13 14 15 16 17
// 20 21 22 23 24 25 26 27
// 30 31 32 33 34 35 36 37

_mm256_storeu_si256((__m256i*)outptr0, _sum0);
_mm256_storeu_si256((__m256i*)(outptr0 + out_hstep), _sum1);
_mm256_storeu_si256((__m256i*)(outptr0 + out_hstep * 2), _sum2);
_mm256_storeu_si256((__m256i*)(outptr0 + out_hstep * 3), _sum3);
#else
// 00 11 22 33
// 04 15 26 37
// 01 12 23 30
Expand Down Expand Up @@ -4314,11 +4484,18 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M
_mm_storeu_si128((__m128i*)(outptr0 + out_hstep * 2 + 4), _sum6);
_mm_storeu_si128((__m128i*)(outptr0 + out_hstep * 3), _sum3);
_mm_storeu_si128((__m128i*)(outptr0 + out_hstep * 3 + 4), _sum7);
#endif // __AVX2__
outptr0 += 8;
}
}
else
{
#if __AVX2__
_mm256_store_si256((__m256i*)outptr, _sum0);
_mm256_store_si256((__m256i*)(outptr + 8), _sum1);
_mm256_store_si256((__m256i*)(outptr + 16), _sum2);
_mm256_store_si256((__m256i*)(outptr + 24), _sum3);
#else
_mm_store_si128((__m128i*)outptr, _sum0);
_mm_store_si128((__m128i*)(outptr + 4), _sum1);
_mm_store_si128((__m128i*)(outptr + 8), _sum2);
Expand All @@ -4327,6 +4504,7 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M
_mm_store_si128((__m128i*)(outptr + 20), _sum5);
_mm_store_si128((__m128i*)(outptr + 24), _sum6);
_mm_store_si128((__m128i*)(outptr + 28), _sum7);
#endif
}

outptr += 32;
Expand Down

0 comments on commit 7db520c

Please sign in to comment.