Skip to content

Commit

Permalink
opt wip
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Aug 28, 2023
1 parent 7db520c commit 41f720c
Showing 1 changed file with 157 additions and 67 deletions.
224 changes: 157 additions & 67 deletions src/layer/x86/convolution_im2col_gemm_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -3411,25 +3411,23 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M
__m256i _pB0 = _mm256_cvtepi8_epi16(_pB);

// 0123 4567
// 2301 6745
__m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2));

// 0123 0123
// 3012 3012
// 2301 2301
// 1230 1230
__m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 1, 0, 3));
__m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2));
__m256i _pB3 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
__m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));

#if __AVXVNNI__ || __AVX512VNNI__
_sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0);
_sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1);
_sum2 = _mm256_dpwssd_epi32(_sum2, _pA0, _pB2);
_sum3 = _mm256_dpwssd_epi32(_sum3, _pA0, _pB3);
_sum2 = _mm256_dpwssd_epi32(_sum2, _pA1, _pB0);
_sum3 = _mm256_dpwssd_epi32(_sum3, _pA1, _pB1);
#else
_sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0));
_sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1));
_sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA0, _pB2));
_sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA0, _pB3));
_sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA1, _pB0));
_sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA1, _pB1));
#endif

pA += 16;
Expand All @@ -3444,20 +3442,19 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M
_pB = _mm_cvtepi8_epi16(_pB);

// 01234567
// 23016745
__m128i _pA0 = _pA;
__m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(2, 3, 0, 1));

// 01230123
// 30123012
// 23012301
// 12301230
__m128i _pB0 = _pB;
__m128i _pB1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB, _MM_SHUFFLE(2, 1, 0, 3)), _MM_SHUFFLE(2, 1, 0, 3));
__m128i _pB2 = _mm_shuffle_epi32(_pB, _MM_SHUFFLE(2, 3, 0, 1));
__m128i _pB3 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1));
__m128i _pB1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1));

__m256i _s0 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA, _pB0));
__m256i _s1 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA, _pB1));
__m256i _s2 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA, _pB2));
__m256i _s3 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA, _pB3));
__m256i _s0 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB0));
__m256i _s1 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB1));
__m256i _s2 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB0));
__m256i _s3 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB1));

_sum0 = _mm256_add_epi32(_sum0, _s0);
_sum1 = _mm256_add_epi32(_sum1, _s1);
Expand All @@ -3470,48 +3467,54 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M

if (k_end)
{
// 00 11 22 33 40 51 62 73
// 03 10 21 32 43 50 61 72
// 02 13 20 31 42 53 60 71
// 01 12 23 30 41 52 63 70
if (out_elempack == 8)
{
// TODO
// 00 11 22 33 40 51 62 73
// 01 12 23 30 41 52 63 70
// 20 31 02 13 60 71 42 53
// 21 32 03 10 61 72 43 50

_sum0 = _sum0;
_sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(0, 3, 2, 1));
_sum2 = _mm256_shuffle_epi32(_sum2, _MM_SHUFFLE(1, 0, 3, 2));
_sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3));
__m256i _tmp0 = _sum0;
__m256i _tmp1 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3));
__m256i _tmp2 = _sum2;
__m256i _tmp3 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));

// 00 11 22 33 40 51 62 73
// 10 21 32 03 50 61 72 43
// 20 31 02 13 60 71 42 53
// 30 01 12 23 70 41 52 63
// 00 11 22 33 40 51 62 73
// 10 21 32 03 50 61 72 43
// 20 31 02 13 60 71 42 53
// 30 01 12 23 70 41 52 63

__m256i _sum01l = _mm256_unpacklo_epi32(_sum0, _sum1);
__m256i _sum01h = _mm256_unpackhi_epi32(_sum0, _sum1);
__m256i _sum23l = _mm256_unpacklo_epi32(_sum2, _sum3);
__m256i _sum23h = _mm256_unpackhi_epi32(_sum2, _sum3);
_sum0 = _mm256_unpacklo_epi32(_tmp0, _tmp1);
_sum1 = _mm256_unpackhi_epi32(_tmp0, _tmp1);
_sum2 = _mm256_unpacklo_epi32(_tmp2, _tmp3);
_sum3 = _mm256_unpackhi_epi32(_tmp2, _tmp3);

// 00 10 11 21 40 50 51 61
// 22 32 33 03 62 72 73 43
// 20 30 31 01 60 70 71 41
// 02 12 13 23 42 52 53 63
// 00 10 11 21 40 50 51 61
// 22 32 33 03 62 72 73 43
// 20 30 31 01 60 70 71 41
// 02 12 13 23 42 52 53 63

_sum0 = _mm256_unpacklo_epi64(_sum01l, _sum23l);
_sum1 = _mm256_unpackhi_epi64(_sum01l, _sum23l);
_sum2 = _mm256_unpacklo_epi64(_sum01h, _sum23h);
_sum3 = _mm256_unpackhi_epi64(_sum01h, _sum23h);
_tmp0 = _mm256_unpacklo_epi64(_sum0, _sum2);
_tmp1 = _mm256_unpackhi_epi64(_sum0, _sum2);
_tmp2 = _mm256_unpacklo_epi64(_sum3, _sum1);
_tmp3 = _mm256_unpackhi_epi64(_sum3, _sum1);

// 00 10 20 30 40 50 60 70
// 11 21 31 01 51 61 71 41
// 22 32 02 12 62 72 42 52
// 33 03 13 23 73 43 53 63
// 00 10 20 30 40 50 60 70
// 11 21 31 01 51 61 71 41
// 02 12 22 32 42 52 62 72
// 13 23 33 03 53 63 73 43

_sum0 = _sum0;
_sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));
_sum2 = _mm256_shuffle_epi32(_sum2, _MM_SHUFFLE(1, 0, 3, 2));
_sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(0, 3, 2, 1));
// The 64-bit unpack above left its results in _tmp0.._tmp3, so the final
// rotation must read from them; _sum0.._sum3 still hold the 32-bit unpack
// output at this point. Rows 1 and 3 come out rotated by one element
// (11 21 31 01 / 13 23 33 03) and are rotated back here so every row is in
// 0..7 order before the stores below.
_sum0 = _tmp0;
_sum1 = _mm256_shuffle_epi32(_tmp1, _MM_SHUFFLE(2, 1, 0, 3));
_sum2 = _tmp2;
_sum3 = _mm256_shuffle_epi32(_tmp3, _MM_SHUFFLE(2, 1, 0, 3));

// 00 10 20 30 40 50 60 70
// 01 11 21 31 41 51 61 71
// 02 12 22 32 42 52 62 72
// 03 13 23 33 43 53 63 73

if (out_elempack == 8)
{
_mm256_store_si256((__m256i*)outptr0, _sum0);
_mm256_store_si256((__m256i*)(outptr0 + 8), _sum1);
_mm256_store_si256((__m256i*)(outptr0 + 16), _sum2);
Expand All @@ -3520,30 +3523,117 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M
}
if (out_elempack == 4)
{
__m256i _tmp0 = _mm256_permute2x128_si256(_sum0, _sum1, _MM_SHUFFLE(0, 2, 0, 0));
__m256i _tmp1 = _mm256_permute2x128_si256(_sum2, _sum3, _MM_SHUFFLE(0, 2, 0, 0));
__m256i _tmp2 = _mm256_permute2x128_si256(_sum0, _sum1, _MM_SHUFFLE(0, 3, 0, 1));
__m256i _tmp3 = _mm256_permute2x128_si256(_sum2, _sum3, _MM_SHUFFLE(0, 3, 0, 1));
// 00 11 22 33 40 51 62 73
// 01 12 23 30 41 52 63 70
// 20 31 02 13 60 71 42 53
// 21 32 03 10 61 72 43 50

_mm256_storeu_si256((__m256i*)outptr0, _tmp0);
_mm256_storeu_si256((__m256i*)(outptr0 + 8), _tmp1);
__m256i _tmp0 = _sum0;
__m256i _tmp1 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3));
__m256i _tmp2 = _sum2;
__m256i _tmp3 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));

// 00 11 22 33 40 51 62 73
// 10 21 32 03 50 61 72 43
// 20 31 02 13 60 71 42 53
// 30 01 12 23 70 41 52 63

_sum0 = _mm256_unpacklo_epi32(_tmp0, _tmp1);
_sum1 = _mm256_unpackhi_epi32(_tmp0, _tmp1);
_sum2 = _mm256_unpacklo_epi32(_tmp2, _tmp3);
_sum3 = _mm256_unpackhi_epi32(_tmp2, _tmp3);

// 00 10 11 21 40 50 51 61
// 22 32 33 03 62 72 73 43
// 20 30 31 01 60 70 71 41
// 02 12 13 23 42 52 53 63

_tmp0 = _mm256_unpacklo_epi64(_sum0, _sum2);
_tmp1 = _mm256_unpackhi_epi64(_sum0, _sum2);
_tmp2 = _mm256_unpacklo_epi64(_sum3, _sum1);
_tmp3 = _mm256_unpackhi_epi64(_sum3, _sum1);

// 00 10 20 30 40 50 60 70
// 11 21 31 01 51 61 71 41
// 02 12 22 32 42 52 62 72
// 13 23 33 03 53 63 73 43

_tmp0 = _tmp0;
_tmp1 = _mm256_shuffle_epi32(_tmp1, _MM_SHUFFLE(2, 1, 0, 3));
_tmp2 = _tmp2;
_tmp3 = _mm256_shuffle_epi32(_tmp3, _MM_SHUFFLE(2, 1, 0, 3));

// 00 10 20 30 40 50 60 70
// 01 11 21 31 41 51 61 71
// 02 12 22 32 42 52 62 72
// 03 13 23 33 43 53 63 73

_sum0 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0));
_sum1 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 2, 0, 0));
_sum2 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1));
_sum3 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 3, 0, 1));

_mm256_storeu_si256((__m256i*)(outptr0 + out_hstep * 4), _tmp2);
_mm256_storeu_si256((__m256i*)(outptr0 + out_hstep * 4 + 8), _tmp3);
_mm256_storeu_si256((__m256i*)outptr0, _sum0);
_mm256_storeu_si256((__m256i*)(outptr0 + 8), _sum1);
_mm256_storeu_si256((__m256i*)(outptr0 + out_hstep * 4), _sum2);
_mm256_storeu_si256((__m256i*)(outptr0 + out_hstep * 4 + 8), _sum3);

outptr0 += 16;
}
if (out_elempack == 1)
{
transpose8x4_epi32(_sum0, _sum1, _sum2, _sum3);
// 00 11 22 33 40 51 62 73
// 01 12 23 30 41 52 63 70
// 20 31 02 13 60 71 42 53
// 21 32 03 10 61 72 43 50

_sum0 = _sum0;
_sum1 = _sum1;
_sum2 = _mm256_shuffle_epi32(_sum2, _MM_SHUFFLE(1, 0, 3, 2));
_sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(1, 0, 3, 2));

// 00 11 22 33 40 51 62 73
// 01 12 23 30 41 52 63 70
// 02 13 20 31 42 53 60 71
// 03 10 21 32 43 50 61 72

__m256i _tmp0 = _mm256_unpacklo_epi32(_sum0, _sum1);
__m256i _tmp1 = _mm256_unpackhi_epi32(_sum0, _sum1);
__m256i _tmp2 = _mm256_unpacklo_epi32(_sum2, _sum3);
__m256i _tmp3 = _mm256_unpackhi_epi32(_sum2, _sum3);

// 00 01 11 12 40 41 51 52
// 22 23 33 30 62 63 73 70
// 02 03 13 10 42 43 53 50
// 20 21 31 32 60 61 71 72

_sum0 = _mm256_unpacklo_epi64(_tmp0, _tmp2);
_sum1 = _mm256_unpackhi_epi64(_tmp0, _tmp2);
_sum2 = _mm256_unpacklo_epi64(_tmp3, _tmp1);
_sum3 = _mm256_unpackhi_epi64(_tmp3, _tmp1);

// 00 01 02 03 40 41 42 43
// 11 12 13 10 51 52 53 50
// 20 21 22 23 60 61 62 63
// 31 32 33 30 71 72 73 70

_sum0 = _sum0;
_sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));
_sum2 = _sum2;
_sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3));

// 00 01 02 03 40 41 42 43
// 10 11 12 13 50 51 52 53
// 20 21 22 23 60 61 62 63
// 30 31 32 33 70 71 72 73

_mm_storeu_si128((__m128i*)outptr0, _mm256_extracti128_si256(_sum0, 0));
_mm_storeu_si128((__m128i*)(outptr0 + out_hstep), _mm256_extracti128_si256(_sum0, 1));
_mm_storeu_si128((__m128i*)(outptr0 + out_hstep * 2), _mm256_extracti128_si256(_sum1, 0));
_mm_storeu_si128((__m128i*)(outptr0 + out_hstep * 3), _mm256_extracti128_si256(_sum1, 1));
_mm_storeu_si128((__m128i*)(outptr0 + out_hstep * 4), _mm256_extracti128_si256(_sum2, 0));
_mm_storeu_si128((__m128i*)(outptr0 + out_hstep * 5), _mm256_extracti128_si256(_sum2, 1));
_mm_storeu_si128((__m128i*)(outptr0 + out_hstep * 6), _mm256_extracti128_si256(_sum3, 0));
_mm_storeu_si128((__m128i*)(outptr0 + out_hstep), _mm256_extracti128_si256(_sum1, 0));
_mm_storeu_si128((__m128i*)(outptr0 + out_hstep * 2), _mm256_extracti128_si256(_sum2, 0));
_mm_storeu_si128((__m128i*)(outptr0 + out_hstep * 3), _mm256_extracti128_si256(_sum3, 0));
_mm_storeu_si128((__m128i*)(outptr0 + out_hstep * 4), _mm256_extracti128_si256(_sum0, 1));
_mm_storeu_si128((__m128i*)(outptr0 + out_hstep * 5), _mm256_extracti128_si256(_sum1, 1));
_mm_storeu_si128((__m128i*)(outptr0 + out_hstep * 6), _mm256_extracti128_si256(_sum2, 1));
_mm_storeu_si128((__m128i*)(outptr0 + out_hstep * 7), _mm256_extracti128_si256(_sum3, 1));
outptr0 += 4;
}
Expand Down

0 comments on commit 41f720c

Please sign in to comment.