From 970c12a47ef64e8ec152bc8963f61b1cd474ce25 Mon Sep 17 00:00:00 2001
From: nihui
Date: Tue, 30 Apr 2024 17:10:06 +0800
Subject: [PATCH] opt

---
 src/layer/x86/lstm_int8.h | 114 ++++++++++++++++++++++++++++++++++++--
 1 file changed, 108 insertions(+), 6 deletions(-)

diff --git a/src/layer/x86/lstm_int8.h b/src/layer/x86/lstm_int8.h
index f2e007e3ea8..22c74aca332 100644
--- a/src/layer/x86/lstm_int8.h
+++ b/src/layer/x86/lstm_int8.h
@@ -939,11 +939,11 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             float* gates_data = gates.row(q);
 
             __m512i _lstm_IFOGx0 = _mm512_setzero_si512();
-            int i = 0;
-#if __AVX512VNNI__
             __m512i _sum1 = _mm512_setzero_si512();
             __m512i _sum2 = _mm512_setzero_si512();
             __m512i _sum3 = _mm512_setzero_si512();
+            int i = 0;
+#if __AVX512VNNI__
             __m512i _v127 = _mm512_set1_epi8(127);
             for (; i + 15 < size; i += 16)
             {
@@ -996,10 +996,61 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 _lstm_IFOGx0 = _mm512_sub_epi32(_lstm_IFOGx0, _w_shift);
                 kptr += 64;
             }
+#else
+            for (; i + 7 < size; i += 8)
+            {
+                __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr);
+                __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32));
+                __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64));
+                __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96));
+                __m256i _xi = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(x + i)));
+
+                __m512i _ww0 = _mm512_cvtepi8_epi16(_w0);
+                __m512i _ww1 = _mm512_cvtepi8_epi16(_w1);
+                __m512i _ww2 = _mm512_cvtepi8_epi16(_w2);
+                __m512i _ww3 = _mm512_cvtepi8_epi16(_w3);
+                __m512i _xixi = _mm512_cvtepi8_epi16(_xi);
+
+                __m512i _xixi0 = _mm512_shuffle_epi32(_xixi, _MM_PERM_AAAA);
+                __m512i _xixi1 = _mm512_shuffle_epi32(_xixi, _MM_PERM_BBBB);
+                __m512i _xixi2 = _mm512_shuffle_epi32(_xixi, _MM_PERM_CCCC);
+                __m512i _xixi3 = _mm512_shuffle_epi32(_xixi, _MM_PERM_DDDD);
+
+                __m512i _s0 = _mm512_madd_epi16(_ww0, _xixi0);
+                __m512i _s1 = _mm512_madd_epi16(_ww1, _xixi1);
+                __m512i _s2 = _mm512_madd_epi16(_ww2, _xixi2);
+                __m512i _s3 = _mm512_madd_epi16(_ww3, _xixi3);
+                _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _s0);
+                _sum1 = _mm512_add_epi32(_sum1, _s1);
+                _sum2 = _mm512_add_epi32(_sum2, _s2);
+                _sum3 = _mm512_add_epi32(_sum3, _s3);
+
+                kptr += 128;
+            }
+            for (; i + 3 < size; i += 4)
+            {
+                __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr);
+                __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32));
+                __m256i _xi = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(x + i)));
+
+                __m512i _ww0 = _mm512_cvtepi8_epi16(_w0);
+                __m512i _ww1 = _mm512_cvtepi8_epi16(_w1);
+                __m512i _xixi = _mm512_cvtepi8_epi16(_xi);
+
+                __m512i _xixi0 = _mm512_shuffle_epi32(_xixi, _MM_PERM_AAAA);
+                __m512i _xixi1 = _mm512_shuffle_epi32(_xixi, _MM_PERM_BBBB);
+
+                __m512i _s0 = _mm512_madd_epi16(_ww0, _xixi0);
+                __m512i _s1 = _mm512_madd_epi16(_ww1, _xixi1);
+                _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _s0);
+                _sum1 = _mm512_add_epi32(_sum1, _s1);
+
+                kptr += 64;
+            }
+#endif // __AVX512VNNI__
             _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum1);
             _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum2);
             _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum3);
-#endif // __AVX512VNNI__
             for (; i + 1 < size; i += 2)
             {
                 __m256i _w = _mm256_loadu_si256((const __m256i*)kptr);
@@ -1029,11 +1080,11 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             }
 
             __m512i _lstm_IFOGh0 = _mm512_setzero_si512();
-            i = 0;
-#if __AVX512VNNI__
             _sum1 = _mm512_setzero_si512();
             _sum2 = _mm512_setzero_si512();
             _sum3 = _mm512_setzero_si512();
+            i = 0;
+#if __AVX512VNNI__
             for (; i + 15 < num_output; i += 16)
             {
                 __m512i _h_cont0 = _mm512_set1_epi32(((const int*)(hs + i))[0]);
@@ -1085,10 +1136,61 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 _lstm_IFOGh0 = _mm512_sub_epi32(_lstm_IFOGh0, _w_shift);
                 kptr += 64;
             }
+#else
+            for (; i + 7 < num_output; i += 8)
+            {
+                __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr);
+                __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32));
+                __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64));
+                __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96));
+                __m256i _h_cont = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(hs + i)));
+
+                __m512i _ww0 = _mm512_cvtepi8_epi16(_w0);
+                __m512i _ww1 = _mm512_cvtepi8_epi16(_w1);
+                __m512i _ww2 = _mm512_cvtepi8_epi16(_w2);
+                __m512i _ww3 = _mm512_cvtepi8_epi16(_w3);
+                __m512i _hh_cont = _mm512_cvtepi8_epi16(_h_cont);
+
+                __m512i _hh_cont0 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_AAAA);
+                __m512i _hh_cont1 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_BBBB);
+                __m512i _hh_cont2 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_CCCC);
+                __m512i _hh_cont3 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_DDDD);
+
+                __m512i _s0 = _mm512_madd_epi16(_ww0, _hh_cont0);
+                __m512i _s1 = _mm512_madd_epi16(_ww1, _hh_cont1);
+                __m512i _s2 = _mm512_madd_epi16(_ww2, _hh_cont2);
+                __m512i _s3 = _mm512_madd_epi16(_ww3, _hh_cont3);
+                _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _s0);
+                _sum1 = _mm512_add_epi32(_sum1, _s1);
+                _sum2 = _mm512_add_epi32(_sum2, _s2);
+                _sum3 = _mm512_add_epi32(_sum3, _s3);
+
+                kptr += 128;
+            }
+            for (; i + 3 < num_output; i += 4)
+            {
+                __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr);
+                __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32));
+                __m256i _h_cont = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(hs + i)));
+
+                __m512i _ww0 = _mm512_cvtepi8_epi16(_w0);
+                __m512i _ww1 = _mm512_cvtepi8_epi16(_w1);
+                __m512i _hh_cont = _mm512_cvtepi8_epi16(_h_cont);
+
+                __m512i _hh_cont0 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_AAAA);
+                __m512i _hh_cont1 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_BBBB);
+
+                __m512i _s0 = _mm512_madd_epi16(_ww0, _hh_cont0);
+                __m512i _s1 = _mm512_madd_epi16(_ww1, _hh_cont1);
+                _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _s0);
+                _sum1 = _mm512_add_epi32(_sum1, _s1);
+
+                kptr += 64;
+            }
+#endif // __AVX512VNNI__
             _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum1);
             _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum2);
             _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum3);
-#endif // __AVX512VNNI__
             for (; i + 1 < num_output; i += 2)
             {
                 __m256i _w = _mm256_loadu_si256((const __m256i*)kptr);
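
Note (illustration, not part of the patch): the new #else branches add an AVX-512F
fallback for CPUs without AVX512-VNNI. Where the VNNI path uses the unsigned
dot-product instruction with the +127/_w_shift bias compensation visible above, the
fallback sign-extends int8 weights and activations to int16 with
_mm512_cvtepi8_epi16 and accumulates pairwise products via _mm512_madd_epi16,
spreading the work across four accumulators (_lstm_IFOGx0 plus _sum1.._sum3) that
both paths now merge unconditionally after #endif. Below is a minimal scalar sketch
of that multi-accumulator int8 dot-product pattern; dot_s8 and its parameters are
illustrative names, not identifiers from the patch:

#include <cstdint>

// One int8 x int8 dot product widened to int32, split across four
// independent partial sums to shorten the add dependency chain,
// mirroring _lstm_IFOGx0 / _sum1 / _sum2 / _sum3 in the vector code.
static int32_t dot_s8(const int8_t* w, const int8_t* x, int n)
{
    int32_t s0 = 0, s1 = 0, s2 = 0, s3 = 0;
    int i = 0;
    for (; i + 3 < n; i += 4)
    {
        s0 += (int32_t)w[i + 0] * x[i + 0];
        s1 += (int32_t)w[i + 1] * x[i + 1];
        s2 += (int32_t)w[i + 2] * x[i + 2];
        s3 += (int32_t)w[i + 3] * x[i + 3];
    }
    for (; i < n; i++) // leftover elements, like the "i + 1 < size" tail loops
        s0 += (int32_t)w[i] * x[i];
    // fold the partials once at the end, like the _mm512_add_epi32 merge
    return s0 + s1 + s2 + s3;
}

Hoisting the _sum1.._sum3 declarations and the final merge out of the #if block is
what lets the VNNI and non-VNNI paths share the same epilogue and scalar tail loops.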