Skip to content

Commit

Permalink
opt wip
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Aug 21, 2023
1 parent 67c4fd1 commit c8a3b82
Showing 1 changed file with 29 additions and 6 deletions.
35 changes: 29 additions & 6 deletions src/layer/x86/convolution_im2col_gemm_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -1233,11 +1233,13 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M
__m256i _pB = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pB));

__m512i _pA0 = _mm512_cvtepi8_epi16(_pA);
__m512i _pBBBB = _mm512_cvtepi8_epi16(_pB);
__m512i _pB0 = _mm512_cvtepi8_epi16(_pB);

// 01xx01xx01xx01xx -> 00000000... 11111111...
__m512i _pB0 = _mm512_shuffle_epi32(_pBBBB, _MM_PERM_AAAA);
__m512i _pB1 = _mm512_shuffle_epi32(_pBBBB, _MM_PERM_BBBB);
// 0123 4567 89ab cdef

// 0101 0101 0101 0101
// 1010 1010 1010 1010
__m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CDAB);

#if __AVX512VNNI__
_sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0);
Expand All @@ -1253,10 +1255,16 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M
for (; kk < max_kk; kk += 1)
{
__m128i _pA = _mm_load_si128((const __m128i*)pA);
__m256i _pB0 = _mm256_set1_epi16(pB[0]);
__m256i _pB1 = _mm256_set1_epi16(pB[1]);
__m128i _pB = _mm_set1_epi16(((const short*)pB)[0]);

__m256i _pA0 = _mm256_cvtepi8_epi16(_pA);
__m256i _pB0 = _mm256_cvtepi8_epi16(_pB);

// 01234567 89abcdef

// 01010101 01010101
// 10101010 10101010
__m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 1, 0, 1)), _MM_SHUFFLE(0, 1, 0, 1));

__m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0));
__m512i _s1 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1));
Expand All @@ -1270,6 +1278,21 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M

if (k_end)
{
// 00 11 20 31 40 51 60 71 80 91 a0 b1 c0 d1 e0 f1
// 01 10 21 30 41 50 61 70 81 90 a1 b0 c1 d0 e1 f0

_sum0 = _mm512_shuffle_epi32(_sum0, _MM_PERM_DBCA);
_sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_ACDB);

// 00 20 11 31 40 60 51 71 80 a0 91 b1 c0 e0 d1 f1
// 10 30 21 01 50 70 61 41 90 b0 a1 81 d0 f0 e1 c1

__m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum1);
__m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum1);

_sum0 = _tmp0;
_sum1 = _mm512_shuffle_epi32(_tmp1, _MM_PERM_CBAD);

if (out_elempack == 16)
{
_mm512_storeu_si512((__m512i*)outptr0, _sum0);
Expand Down

0 comments on commit c8a3b82

Please sign in to comment.