Skip to content

Commit

Permalink
opt
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Apr 30, 2024
1 parent c8c41ed commit 970c12a
Showing 1 changed file with 108 additions and 6 deletions.
114 changes: 108 additions & 6 deletions src/layer/x86/lstm_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -939,11 +939,11 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
float* gates_data = gates.row(q);

__m512i _lstm_IFOGx0 = _mm512_setzero_si512();
int i = 0;
#if __AVX512VNNI__
__m512i _sum1 = _mm512_setzero_si512();
__m512i _sum2 = _mm512_setzero_si512();
__m512i _sum3 = _mm512_setzero_si512();
int i = 0;
#if __AVX512VNNI__
__m512i _v127 = _mm512_set1_epi8(127);
for (; i + 15 < size; i += 16)
{
Expand Down Expand Up @@ -996,10 +996,61 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
_lstm_IFOGx0 = _mm512_sub_epi32(_lstm_IFOGx0, _w_shift);
kptr += 64;
}
#else
for (; i + 7 < size; i += 8)
{
__m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr);
__m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32));
__m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64));
__m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96));
__m256i _xi = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(x + i)));

__m512i _ww0 = _mm512_cvtepi8_epi16(_w0);
__m512i _ww1 = _mm512_cvtepi8_epi16(_w1);
__m512i _ww2 = _mm512_cvtepi8_epi16(_w2);
__m512i _ww3 = _mm512_cvtepi8_epi16(_w3);
__m512i _xixi = _mm512_cvtepi8_epi16(_xi);

__m512i _xixi0 = _mm512_shuffle_epi32(_xixi, _MM_PERM_AAAA);
__m512i _xixi1 = _mm512_shuffle_epi32(_xixi, _MM_PERM_BBBB);
__m512i _xixi2 = _mm512_shuffle_epi32(_xixi, _MM_PERM_CCCC);
__m512i _xixi3 = _mm512_shuffle_epi32(_xixi, _MM_PERM_DDDD);

__m512i _s0 = _mm512_madd_epi16(_ww0, _xixi0);
__m512i _s1 = _mm512_madd_epi16(_ww1, _xixi1);
__m512i _s2 = _mm512_madd_epi16(_ww2, _xixi2);
__m512i _s3 = _mm512_madd_epi16(_ww3, _xixi3);
_lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _s0);
_sum1 = _mm512_add_epi32(_sum1, _s1);
_sum2 = _mm512_add_epi32(_sum2, _s2);
_sum3 = _mm512_add_epi32(_sum3, _s3);

kptr += 128;
}
for (; i + 3 < size; i += 4)
{
__m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr);
__m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32));
__m256i _xi = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(x + i)));

__m512i _ww0 = _mm512_cvtepi8_epi16(_w0);
__m512i _ww1 = _mm512_cvtepi8_epi16(_w1);
__m512i _xixi = _mm512_cvtepi8_epi16(_xi);

__m512i _xixi0 = _mm512_shuffle_epi32(_xixi, _MM_PERM_AAAA);
__m512i _xixi1 = _mm512_shuffle_epi32(_xixi, _MM_PERM_BBBB);

__m512i _s0 = _mm512_madd_epi16(_ww0, _xixi0);
__m512i _s1 = _mm512_madd_epi16(_ww1, _xixi1);
_lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _s0);
_sum1 = _mm512_add_epi32(_sum1, _s1);

kptr += 64;
}
#endif // __AVX512VNNI__
_lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum1);
_lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum2);
_lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum3);
#endif // __AVX512VNNI__
for (; i + 1 < size; i += 2)
{
__m256i _w = _mm256_loadu_si256((const __m256i*)kptr);
Expand Down Expand Up @@ -1029,11 +1080,11 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
}

__m512i _lstm_IFOGh0 = _mm512_setzero_si512();
i = 0;
#if __AVX512VNNI__
_sum1 = _mm512_setzero_si512();
_sum2 = _mm512_setzero_si512();
_sum3 = _mm512_setzero_si512();
i = 0;
#if __AVX512VNNI__
for (; i + 15 < num_output; i += 16)
{
__m512i _h_cont0 = _mm512_set1_epi32(((const int*)(hs + i))[0]);
Expand Down Expand Up @@ -1085,10 +1136,61 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
_lstm_IFOGh0 = _mm512_sub_epi32(_lstm_IFOGh0, _w_shift);
kptr += 64;
}
#else
for (; i + 7 < num_output; i += 8)
{
__m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr);
__m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32));
__m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64));
__m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96));
__m256i _h_cont = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(hs + i)));

__m512i _ww0 = _mm512_cvtepi8_epi16(_w0);
__m512i _ww1 = _mm512_cvtepi8_epi16(_w1);
__m512i _ww2 = _mm512_cvtepi8_epi16(_w2);
__m512i _ww3 = _mm512_cvtepi8_epi16(_w3);
__m512i _hh_cont = _mm512_cvtepi8_epi16(_h_cont);

__m512i _hh_cont0 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_AAAA);
__m512i _hh_cont1 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_BBBB);
__m512i _hh_cont2 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_CCCC);
__m512i _hh_cont3 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_DDDD);

__m512i _s0 = _mm512_madd_epi16(_ww0, _hh_cont0);
__m512i _s1 = _mm512_madd_epi16(_ww1, _hh_cont1);
__m512i _s2 = _mm512_madd_epi16(_ww2, _hh_cont2);
__m512i _s3 = _mm512_madd_epi16(_ww3, _hh_cont3);
_lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _s0);
_sum1 = _mm512_add_epi32(_sum1, _s1);
_sum2 = _mm512_add_epi32(_sum2, _s2);
_sum3 = _mm512_add_epi32(_sum3, _s3);

kptr += 128;
}
for (; i + 3 < num_output; i += 4)
{
__m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr);
__m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32));
__m256i _h_cont = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(hs + i)));

__m512i _ww0 = _mm512_cvtepi8_epi16(_w0);
__m512i _ww1 = _mm512_cvtepi8_epi16(_w1);
__m512i _hh_cont = _mm512_cvtepi8_epi16(_h_cont);

__m512i _hh_cont0 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_AAAA);
__m512i _hh_cont1 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_BBBB);

__m512i _s0 = _mm512_madd_epi16(_ww0, _hh_cont0);
__m512i _s1 = _mm512_madd_epi16(_ww1, _hh_cont1);
_lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _s0);
_sum1 = _mm512_add_epi32(_sum1, _s1);

kptr += 64;
}
#endif // __AVX512VNNI__
_lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum1);
_lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum2);
_lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum3);
#endif // __AVX512VNNI__
for (; i + 1 < num_output; i += 2)
{
__m256i _w = _mm256_loadu_si256((const __m256i*)kptr);
Expand Down

0 comments on commit 970c12a

Please sign in to comment.