From 31ba013661d129e5beb5950dacd0d5a8fedc090d Mon Sep 17 00:00:00 2001 From: nihui Date: Mon, 6 May 2024 17:46:41 +0800 Subject: [PATCH] opt --- src/layer/x86/lstm_int8.h | 404 +++++++++++++++++++++++++++----------- 1 file changed, 294 insertions(+), 110 deletions(-) diff --git a/src/layer/x86/lstm_int8.h b/src/layer/x86/lstm_int8.h index bfc691cd3cb..67a9b89570e 100644 --- a/src/layer/x86/lstm_int8.h +++ b/src/layer/x86/lstm_int8.h @@ -1268,6 +1268,57 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x #if __AVXVNNI__ || __AVX512VNNI__ __m128i _w_shift = _mm_setzero_si128(); __m128i _v127 = _mm_set1_epi8(127); + __m128i _w0_shift = _mm_setzero_si128(); + __m128i _w1_shift = _mm_setzero_si128(); + __m128i _w2_shift = _mm_setzero_si128(); + __m128i _w3_shift = _mm_setzero_si128(); + for (; i + 15 < size; i += 16) + { + _mm_storeu_si128((__m128i*)kptr, _mm_loadu_si128((const __m128i*)(weight_xc_I + i))); + _mm_storeu_si128((__m128i*)(kptr + 16), _mm_loadu_si128((const __m128i*)(weight_xc_F + i))); + _mm_storeu_si128((__m128i*)(kptr + 32), _mm_loadu_si128((const __m128i*)(weight_xc_O + i))); + _mm_storeu_si128((__m128i*)(kptr + 48), _mm_loadu_si128((const __m128i*)(weight_xc_G + i))); + + __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); + __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); + __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32)); + __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48)); + _w0_shift = _mm_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm_dpbusd_epi32(_w1_shift, _v127, _w1); + _w2_shift = _mm_dpbusd_epi32(_w2_shift, _v127, _w2); + _w3_shift = _mm_dpbusd_epi32(_w3_shift, _v127, _w3); + + kptr += 64; + } + { + transpose4x4_epi32(_w0_shift, _w1_shift, _w2_shift, _w3_shift); + _w_shift = _mm_add_epi32(_w_shift, _w0_shift); + _w_shift = _mm_add_epi32(_w_shift, _w1_shift); + _w_shift = _mm_add_epi32(_w_shift, _w2_shift); + _w_shift = _mm_add_epi32(_w_shift, _w3_shift); + } + + _w0_shift = _mm_setzero_si128(); + _w1_shift = _mm_setzero_si128(); + for (; i + 7 < size; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_xc_F + i))); + _mm_storel_epi64((__m128i*)(kptr + 16), _mm_loadl_epi64((const __m128i*)(weight_xc_O + i))); + _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_xc_G + i))); + + __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); + __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); + _w0_shift = _mm_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm_dpbusd_epi32(_w1_shift, _v127, _w1); + + kptr += 32; + } + { + __m128i _tmp0 = _mm_hadd_epi32(_w0_shift, _w1_shift); + _w_shift = _mm_add_epi32(_w_shift, _tmp0); + } + for (; i + 3 < size; i += 4) { kptr[0] = weight_xc_I[i]; @@ -1295,6 +1346,35 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x _mm_storeu_si128((__m128i*)kptr, _w_shift); kptr += 16; +#else + for (; i + 7 < size; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_xc_F + i))); + _mm_storel_epi64((__m128i*)(kptr + 16), _mm_loadl_epi64((const __m128i*)(weight_xc_O + i))); + _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_xc_G + i))); + kptr += 32; + } + for (; i + 3 < size; i += 4) + { + kptr[0] = weight_xc_I[i]; + kptr[1] = 
weight_xc_I[i + 1]; + kptr[2] = weight_xc_I[i + 2]; + kptr[3] = weight_xc_I[i + 3]; + kptr[4] = weight_xc_F[i]; + kptr[5] = weight_xc_F[i + 1]; + kptr[6] = weight_xc_F[i + 2]; + kptr[7] = weight_xc_F[i + 3]; + kptr[8 + 0] = weight_xc_O[i]; + kptr[8 + 1] = weight_xc_O[i + 1]; + kptr[8 + 2] = weight_xc_O[i + 2]; + kptr[8 + 3] = weight_xc_O[i + 3]; + kptr[8 + 4] = weight_xc_G[i]; + kptr[8 + 5] = weight_xc_G[i + 1]; + kptr[8 + 6] = weight_xc_G[i + 2]; + kptr[8 + 7] = weight_xc_G[i + 3]; + kptr += 16; + } #endif // __AVXVNNI__ || __AVX512VNNI__ for (; i + 1 < size; i += 2) { @@ -1322,6 +1402,57 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x #if __SSE2__ #if __AVXVNNI__ || __AVX512VNNI__ _w_shift = _mm_setzero_si128(); + _w0_shift = _mm_setzero_si128(); + _w1_shift = _mm_setzero_si128(); + _w2_shift = _mm_setzero_si128(); + _w3_shift = _mm_setzero_si128(); + for (; i + 15 < num_output; i += 16) + { + _mm_storeu_si128((__m128i*)kptr, _mm_loadu_si128((const __m128i*)(weight_hc_I + i))); + _mm_storeu_si128((__m128i*)(kptr + 16), _mm_loadu_si128((const __m128i*)(weight_hc_F + i))); + _mm_storeu_si128((__m128i*)(kptr + 32), _mm_loadu_si128((const __m128i*)(weight_hc_O + i))); + _mm_storeu_si128((__m128i*)(kptr + 48), _mm_loadu_si128((const __m128i*)(weight_hc_G + i))); + + __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); + __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); + __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32)); + __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48)); + _w0_shift = _mm_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm_dpbusd_epi32(_w1_shift, _v127, _w1); + _w2_shift = _mm_dpbusd_epi32(_w2_shift, _v127, _w2); + _w3_shift = _mm_dpbusd_epi32(_w3_shift, _v127, _w3); + + kptr += 64; + } + { + transpose4x4_epi32(_w0_shift, _w1_shift, _w2_shift, _w3_shift); + _w_shift = _mm_add_epi32(_w_shift, _w0_shift); + _w_shift = _mm_add_epi32(_w_shift, _w1_shift); + _w_shift = _mm_add_epi32(_w_shift, _w2_shift); + _w_shift = _mm_add_epi32(_w_shift, _w3_shift); + } + + _w0_shift = _mm_setzero_si128(); + _w1_shift = _mm_setzero_si128(); + for (; i + 7 < num_output; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_hc_F + i))); + _mm_storel_epi64((__m128i*)(kptr + 16), _mm_loadl_epi64((const __m128i*)(weight_hc_O + i))); + _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_hc_G + i))); + + __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); + __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); + _w0_shift = _mm_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm_dpbusd_epi32(_w1_shift, _v127, _w1); + + kptr += 32; + } + { + __m128i _tmp0 = _mm_hadd_epi32(_w0_shift, _w1_shift); + _w_shift = _mm_add_epi32(_w_shift, _tmp0); + } + for (; i + 3 < num_output; i += 4) { kptr[0] = weight_hc_I[i]; @@ -1349,6 +1480,35 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x _mm_storeu_si128((__m128i*)kptr, _w_shift); kptr += 16; +#else + for (; i + 7 < num_output; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_hc_F + i))); + _mm_storel_epi64((__m128i*)(kptr + 16), _mm_loadl_epi64((const __m128i*)(weight_hc_O + i))); + _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_hc_G + 
i))); + kptr += 32; + } + for (; i + 3 < num_output; i += 4) + { + kptr[0] = weight_hc_I[i]; + kptr[1] = weight_hc_I[i + 1]; + kptr[2] = weight_hc_I[i + 2]; + kptr[3] = weight_hc_I[i + 3]; + kptr[4] = weight_hc_F[i]; + kptr[5] = weight_hc_F[i + 1]; + kptr[6] = weight_hc_F[i + 2]; + kptr[7] = weight_hc_F[i + 3]; + kptr[8 + 0] = weight_hc_O[i]; + kptr[8 + 1] = weight_hc_O[i + 1]; + kptr[8 + 2] = weight_hc_O[i + 2]; + kptr[8 + 3] = weight_hc_O[i + 3]; + kptr[8 + 4] = weight_hc_G[i]; + kptr[8 + 5] = weight_hc_G[i + 1]; + kptr[8 + 6] = weight_hc_G[i + 2]; + kptr[8 + 7] = weight_hc_G[i + 3]; + kptr += 16; + } #endif // __AVXVNNI__ || __AVX512VNNI__ for (; i + 1 < num_output; i += 2) { @@ -2251,6 +2411,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d #if __SSE2__ __m128i _lstm_IFOGx0 = _mm_setzero_si128(); + __m128i _sum0 = _mm_setzero_si128(); __m128i _sum1 = _mm_setzero_si128(); __m128i _sum2 = _mm_setzero_si128(); __m128i _sum3 = _mm_setzero_si128(); @@ -2259,40 +2420,47 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m128i _v127 = _mm_set1_epi8(127); for (; i + 15 < size; i += 16) { - __m128i _xi0 = _mm_castps_si128(_mm_load1_ps((const float*)(x + i))); - __m128i _xi1 = _mm_castps_si128(_mm_load1_ps((const float*)(x + i + 4))); - __m128i _xi2 = _mm_castps_si128(_mm_load1_ps((const float*)(x + i + 8))); - __m128i _xi3 = _mm_castps_si128(_mm_load1_ps((const float*)(x + i + 12))); + __m128i _xi = _mm_loadu_si128((const __m128i*)(x + i)); __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32)); __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48)); - _xi0 = _mm_add_epi8(_xi0, _v127); - _xi1 = _mm_add_epi8(_xi1, _v127); - _xi2 = _mm_add_epi8(_xi2, _v127); - _xi3 = _mm_add_epi8(_xi3, _v127); - _lstm_IFOGx0 = _mm_dpbusd_epi32(_lstm_IFOGx0, _xi0, _w0); - _sum1 = _mm_dpbusd_epi32(_sum1, _xi1, _w1); - _sum2 = _mm_dpbusd_epi32(_sum2, _xi2, _w2); - _sum3 = _mm_dpbusd_epi32(_sum3, _xi3, _w3); + _xi = _mm_add_epi8(_xi, _v127); + _sum0 = _mm_dpbusd_epi32(_sum0, _xi, _w0); + _sum1 = _mm_dpbusd_epi32(_sum1, _xi, _w1); + _sum2 = _mm_dpbusd_epi32(_sum2, _xi, _w2); + _sum3 = _mm_dpbusd_epi32(_sum3, _xi, _w3); kptr += 64; } + { + transpose4x4_epi32(_sum0, _sum1, _sum2, _sum3); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum0); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum1); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum2); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum3); + } + + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); for (; i + 7 < size; i += 8) { - __m128i _xi0 = _mm_castps_si128(_mm_load1_ps((const float*)(x + i))); - __m128i _xi1 = _mm_castps_si128(_mm_load1_ps((const float*)(x + i + 4))); + __m128i _xi = _mm_castpd_si128(_mm_load1_pd((const double*)(x + i))); __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); - _xi0 = _mm_add_epi8(_xi0, _v127); - _xi1 = _mm_add_epi8(_xi1, _v127); - _lstm_IFOGx0 = _mm_dpbusd_epi32(_lstm_IFOGx0, _xi0, _w0); - _sum1 = _mm_dpbusd_epi32(_sum1, _xi1, _w1); + _xi = _mm_add_epi8(_xi, _v127); + _sum0 = _mm_dpbusd_epi32(_sum0, _xi, _w0); + _sum1 = _mm_dpbusd_epi32(_sum1, _xi, _w1); kptr += 32; } + { + __m128i _tmp0 = _mm_hadd_epi32(_sum0, _sum1); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _tmp0); + } + for (; i + 3 < size; i += 4) { __m128i _xi = _mm_castps_si128(_mm_load1_ps((const float*)(x + 
i))); @@ -2309,49 +2477,39 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d kptr += 16; } #else -#if 1 for (; i + 7 < size; i += 8) { + __m128i _xi = _mm_castpd_si128(_mm_load1_pd((const double*)(x + i))); __m128i _w0 = _mm_loadl_epi64((const __m128i*)kptr); __m128i _w1 = _mm_loadl_epi64((const __m128i*)(kptr + 8)); __m128i _w2 = _mm_loadl_epi64((const __m128i*)(kptr + 16)); __m128i _w3 = _mm_loadl_epi64((const __m128i*)(kptr + 24)); - __m128i _xi0 = _mm_set1_epi16(((const short*)(x + i))[0]); - __m128i _xi1 = _mm_set1_epi16(((const short*)(x + i))[1]); - __m128i _xi2 = _mm_set1_epi16(((const short*)(x + i))[2]); - __m128i _xi3 = _mm_set1_epi16(((const short*)(x + i))[3]); #if __SSE4_1__ + _xi = _mm_cvtepi8_epi16(_xi); _w0 = _mm_cvtepi8_epi16(_w0); _w1 = _mm_cvtepi8_epi16(_w1); _w2 = _mm_cvtepi8_epi16(_w2); _w3 = _mm_cvtepi8_epi16(_w3); - _xi0 = _mm_cvtepi8_epi16(_xi0); - _xi1 = _mm_cvtepi8_epi16(_xi1); - _xi2 = _mm_cvtepi8_epi16(_xi2); - _xi3 = _mm_cvtepi8_epi16(_xi3); #else + _xi = _mm_unpacklo_epi8(_xi, _mm_cmpgt_epi8(_mm_setzero_si128(), _xi)); _w0 = _mm_unpacklo_epi8(_w0, _mm_cmpgt_epi8(_mm_setzero_si128(), _w0)); _w1 = _mm_unpacklo_epi8(_w1, _mm_cmpgt_epi8(_mm_setzero_si128(), _w1)); _w2 = _mm_unpacklo_epi8(_w2, _mm_cmpgt_epi8(_mm_setzero_si128(), _w2)); _w3 = _mm_unpacklo_epi8(_w3, _mm_cmpgt_epi8(_mm_setzero_si128(), _w3)); - _xi0 = _mm_unpacklo_epi8(_xi0, _mm_cmpgt_epi8(_mm_setzero_si128(), _xi0)); - _xi1 = _mm_unpacklo_epi8(_xi1, _mm_cmpgt_epi8(_mm_setzero_si128(), _xi1)); - _xi2 = _mm_unpacklo_epi8(_xi2, _mm_cmpgt_epi8(_mm_setzero_si128(), _xi2)); - _xi3 = _mm_unpacklo_epi8(_xi3, _mm_cmpgt_epi8(_mm_setzero_si128(), _xi3)); #endif #if __XOP__ - _lstm_IFOGx0 = _mm_maddd_epi16(_w0, _xi0, _lstm_IFOGx0); - _sum1 = _mm_maddd_epi16(_w1, _xi1, _sum1); - _sum2 = _mm_maddd_epi16(_w2, _xi2, _sum2); - _sum3 = _mm_maddd_epi16(_w3, _xi3, _sum3); + _sum0 = _mm_maddd_epi16(_w0, _xi, _sum0); + _sum1 = _mm_maddd_epi16(_w1, _xi, _sum1); + _sum2 = _mm_maddd_epi16(_w2, _xi, _sum2); + _sum3 = _mm_maddd_epi16(_w3, _xi, _sum3); #else - __m128i _s0 = _mm_madd_epi16(_w0, _xi0); - __m128i _s1 = _mm_madd_epi16(_w1, _xi1); - __m128i _s2 = _mm_madd_epi16(_w2, _xi2); - __m128i _s3 = _mm_madd_epi16(_w3, _xi3); - _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _s0); + __m128i _s0 = _mm_madd_epi16(_w0, _xi); + __m128i _s1 = _mm_madd_epi16(_w1, _xi); + __m128i _s2 = _mm_madd_epi16(_w2, _xi); + __m128i _s3 = _mm_madd_epi16(_w3, _xi); + _sum0 = _mm_add_epi32(_sum0, _s0); _sum1 = _mm_add_epi32(_sum1, _s1); _sum2 = _mm_add_epi32(_sum2, _s2); _sum3 = _mm_add_epi32(_sum3, _s3); @@ -2359,42 +2517,56 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d kptr += 32; } + { + transpose4x4_epi32(_sum0, _sum1, _sum2, _sum3); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum0); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum1); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum2); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum3); + } + + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); for (; i + 3 < size; i += 4) { + __m128i _xi = _mm_castps_si128(_mm_load1_ps((const float*)(x + i))); __m128i _w0 = _mm_loadl_epi64((const __m128i*)kptr); __m128i _w1 = _mm_loadl_epi64((const __m128i*)(kptr + 8)); - __m128i _xi0 = _mm_set1_epi16(((const short*)(x + i))[0]); - __m128i _xi1 = _mm_set1_epi16(((const short*)(x + i))[1]); #if __SSE4_1__ + _xi = _mm_cvtepi8_epi16(_xi); _w0 = _mm_cvtepi8_epi16(_w0); _w1 = _mm_cvtepi8_epi16(_w1); - _xi0 = _mm_cvtepi8_epi16(_xi0); 
- _xi1 = _mm_cvtepi8_epi16(_xi1); #else + _xi = _mm_unpacklo_epi8(_xi, _mm_cmpgt_epi8(_mm_setzero_si128(), _xi)); _w0 = _mm_unpacklo_epi8(_w0, _mm_cmpgt_epi8(_mm_setzero_si128(), _w0)); _w1 = _mm_unpacklo_epi8(_w1, _mm_cmpgt_epi8(_mm_setzero_si128(), _w1)); - _xi0 = _mm_unpacklo_epi8(_xi0, _mm_cmpgt_epi8(_mm_setzero_si128(), _xi0)); - _xi1 = _mm_unpacklo_epi8(_xi1, _mm_cmpgt_epi8(_mm_setzero_si128(), _xi1)); #endif #if __XOP__ - _lstm_IFOGx0 = _mm_maddd_epi16(_w0, _xi0, _lstm_IFOGx0); - _sum1 = _mm_maddd_epi16(_w1, _xi1, _sum1); + _sum0 = _mm_maddd_epi16(_w0, _xi, _sum0); + _sum1 = _mm_maddd_epi16(_w1, _xi, _sum1); #else - __m128i _s0 = _mm_madd_epi16(_w0, _xi0); - __m128i _s1 = _mm_madd_epi16(_w1, _xi1); - _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _s0); + __m128i _s0 = _mm_madd_epi16(_w0, _xi); + __m128i _s1 = _mm_madd_epi16(_w1, _xi); + _sum0 = _mm_add_epi32(_sum0, _s0); _sum1 = _mm_add_epi32(_sum1, _s1); #endif kptr += 16; } -#endif + { +#if __SSSE3__ + __m128i _tmp0 = _mm_hadd_epi32(_sum0, _sum1); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _tmp0); +#else + __m128i _tmp0 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_sum0), _mm_castsi128_ps(_sum1), _MM_SHUFFLE(2, 0, 2, 0))); + __m128i _tmp1 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_sum0), _mm_castsi128_ps(_sum1), _MM_SHUFFLE(3, 1, 3, 1))); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _tmp0); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _tmp1); +#endif // __SSSE3__ + } #endif // __AVXVNNI__ || __AVX512VNNI__ - _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum1); - _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum2); - _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum3); for (; i + 1 < size; i += 2) { __m128i _w = _mm_loadl_epi64((const __m128i*)kptr); @@ -2443,6 +2615,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d } __m128i _lstm_IFOGh0 = _mm_setzero_si128(); + _sum0 = _mm_setzero_si128(); _sum1 = _mm_setzero_si128(); _sum2 = _mm_setzero_si128(); _sum3 = _mm_setzero_si128(); @@ -2450,40 +2623,47 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d #if __AVXVNNI__ || __AVX512VNNI__ for (; i + 15 < num_output; i += 16) { - __m128i _h_cont0 = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i))); - __m128i _h_cont1 = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i + 4))); - __m128i _h_cont2 = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i + 8))); - __m128i _h_cont3 = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i + 12))); + __m128i _h_cont = _mm_loadu_si128((const __m128i*)(hs + i)); __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32)); __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48)); - _h_cont0 = _mm_add_epi8(_h_cont0, _v127); - _h_cont1 = _mm_add_epi8(_h_cont1, _v127); - _h_cont2 = _mm_add_epi8(_h_cont2, _v127); - _h_cont3 = _mm_add_epi8(_h_cont3, _v127); - _lstm_IFOGh0 = _mm_dpbusd_epi32(_lstm_IFOGh0, _h_cont0, _w0); - _sum1 = _mm_dpbusd_epi32(_sum1, _h_cont1, _w1); - _sum2 = _mm_dpbusd_epi32(_sum2, _h_cont2, _w2); - _sum3 = _mm_dpbusd_epi32(_sum3, _h_cont3, _w3); + _h_cont = _mm_add_epi8(_h_cont, _v127); + _sum0 = _mm_dpbusd_epi32(_sum0, _h_cont, _w0); + _sum1 = _mm_dpbusd_epi32(_sum1, _h_cont, _w1); + _sum2 = _mm_dpbusd_epi32(_sum2, _h_cont, _w2); + _sum3 = _mm_dpbusd_epi32(_sum3, _h_cont, _w3); kptr += 64; } + { + transpose4x4_epi32(_sum0, _sum1, _sum2, _sum3); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, 
_sum0); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum1); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum2); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum3); + } + + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); for (; i + 7 < num_output; i += 8) { - __m128i _h_cont0 = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i))); - __m128i _h_cont1 = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i + 4))); + __m128i _h_cont = _mm_castpd_si128(_mm_load1_pd((const double*)(hs + i))); __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); - _h_cont0 = _mm_add_epi8(_h_cont0, _v127); - _h_cont1 = _mm_add_epi8(_h_cont1, _v127); - _lstm_IFOGh0 = _mm_dpbusd_epi32(_lstm_IFOGh0, _h_cont0, _w0); - _sum1 = _mm_dpbusd_epi32(_sum1, _h_cont1, _w1); + _h_cont = _mm_add_epi8(_h_cont, _v127); + _sum0 = _mm_dpbusd_epi32(_sum0, _h_cont, _w0); + _sum1 = _mm_dpbusd_epi32(_sum1, _h_cont, _w1); kptr += 32; } + { + __m128i _tmp0 = _mm_hadd_epi32(_sum0, _sum1); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _tmp0); + } + for (; i + 3 < num_output; i += 4) { __m128i _h_cont = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i))); @@ -2500,49 +2680,39 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d kptr += 16; } #else -#if 0 for (; i + 7 < num_output; i += 8) { + __m128i _h_cont = _mm_castpd_si128(_mm_load1_pd((const double*)(hs + i))); __m128i _w0 = _mm_loadl_epi64((const __m128i*)kptr); __m128i _w1 = _mm_loadl_epi64((const __m128i*)(kptr + 8)); __m128i _w2 = _mm_loadl_epi64((const __m128i*)(kptr + 16)); __m128i _w3 = _mm_loadl_epi64((const __m128i*)(kptr + 24)); - __m128i _h_cont0 = _mm_set1_epi16(((const short*)(hs + i))[0]); - __m128i _h_cont1 = _mm_set1_epi16(((const short*)(hs + i))[1]); - __m128i _h_cont2 = _mm_set1_epi16(((const short*)(hs + i))[2]); - __m128i _h_cont3 = _mm_set1_epi16(((const short*)(hs + i))[3]); #if __SSE4_1__ + _h_cont = _mm_cvtepi8_epi16(_h_cont); _w0 = _mm_cvtepi8_epi16(_w0); _w1 = _mm_cvtepi8_epi16(_w1); _w2 = _mm_cvtepi8_epi16(_w2); _w3 = _mm_cvtepi8_epi16(_w3); - _h_cont0 = _mm_cvtepi8_epi16(_h_cont0); - _h_cont1 = _mm_cvtepi8_epi16(_h_cont1); - _h_cont2 = _mm_cvtepi8_epi16(_h_cont2); - _h_cont3 = _mm_cvtepi8_epi16(_h_cont3); #else + _h_cont = _mm_unpacklo_epi8(_h_cont, _mm_cmpgt_epi8(_mm_setzero_si128(), _h_cont)); _w0 = _mm_unpacklo_epi8(_w0, _mm_cmpgt_epi8(_mm_setzero_si128(), _w0)); _w1 = _mm_unpacklo_epi8(_w1, _mm_cmpgt_epi8(_mm_setzero_si128(), _w1)); _w2 = _mm_unpacklo_epi8(_w2, _mm_cmpgt_epi8(_mm_setzero_si128(), _w2)); _w3 = _mm_unpacklo_epi8(_w3, _mm_cmpgt_epi8(_mm_setzero_si128(), _w3)); - _h_cont0 = _mm_unpacklo_epi8(_h_cont0, _mm_cmpgt_epi8(_mm_setzero_si128(), _h_cont0)); - _h_cont1 = _mm_unpacklo_epi8(_h_cont1, _mm_cmpgt_epi8(_mm_setzero_si128(), _h_cont1)); - _h_cont2 = _mm_unpacklo_epi8(_h_cont2, _mm_cmpgt_epi8(_mm_setzero_si128(), _h_cont2)); - _h_cont3 = _mm_unpacklo_epi8(_h_cont3, _mm_cmpgt_epi8(_mm_setzero_si128(), _h_cont3)); #endif #if __XOP__ - _lstm_IFOGh0 = _mm_maddd_epi16(_w0, _h_cont0, _lstm_IFOGh0); - _sum1 = _mm_maddd_epi16(_w1, _h_cont1, _sum1); - _sum2 = _mm_maddd_epi16(_w2, _h_cont2, _sum2); - _sum3 = _mm_maddd_epi16(_w3, _h_cont3, _sum3); + _sum0 = _mm_maddd_epi16(_w0, _h_cont, _sum0); + _sum1 = _mm_maddd_epi16(_w1, _h_cont, _sum1); + _sum2 = _mm_maddd_epi16(_w2, _h_cont, _sum2); + _sum3 = _mm_maddd_epi16(_w3, _h_cont, _sum3); #else - __m128i _s0 = _mm_madd_epi16(_w0, _h_cont0); - __m128i _s1 = _mm_madd_epi16(_w1, _h_cont1); - __m128i _s2 
= _mm_madd_epi16(_w2, _h_cont2); - __m128i _s3 = _mm_madd_epi16(_w3, _h_cont3); - _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _s0); + __m128i _s0 = _mm_madd_epi16(_w0, _h_cont); + __m128i _s1 = _mm_madd_epi16(_w1, _h_cont); + __m128i _s2 = _mm_madd_epi16(_w2, _h_cont); + __m128i _s3 = _mm_madd_epi16(_w3, _h_cont); + _sum0 = _mm_add_epi32(_sum0, _s0); _sum1 = _mm_add_epi32(_sum1, _s1); _sum2 = _mm_add_epi32(_sum2, _s2); _sum3 = _mm_add_epi32(_sum3, _s3); @@ -2550,42 +2720,56 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d kptr += 32; } + { + transpose4x4_epi32(_sum0, _sum1, _sum2, _sum3); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum0); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum1); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum2); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum3); + } + + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); for (; i + 3 < num_output; i += 4) { + __m128i _h_cont = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i))); __m128i _w0 = _mm_loadl_epi64((const __m128i*)kptr); __m128i _w1 = _mm_loadl_epi64((const __m128i*)(kptr + 8)); - __m128i _h_cont0 = _mm_set1_epi16(((const short*)(hs + i))[0]); - __m128i _h_cont1 = _mm_set1_epi16(((const short*)(hs + i))[1]); #if __SSE4_1__ + _h_cont = _mm_cvtepi8_epi16(_h_cont); _w0 = _mm_cvtepi8_epi16(_w0); _w1 = _mm_cvtepi8_epi16(_w1); - _h_cont0 = _mm_cvtepi8_epi16(_h_cont0); - _h_cont1 = _mm_cvtepi8_epi16(_h_cont1); #else + _h_cont = _mm_unpacklo_epi8(_h_cont, _mm_cmpgt_epi8(_mm_setzero_si128(), _h_cont)); _w0 = _mm_unpacklo_epi8(_w0, _mm_cmpgt_epi8(_mm_setzero_si128(), _w0)); _w1 = _mm_unpacklo_epi8(_w1, _mm_cmpgt_epi8(_mm_setzero_si128(), _w1)); - _h_cont0 = _mm_unpacklo_epi8(_h_cont0, _mm_cmpgt_epi8(_mm_setzero_si128(), _h_cont0)); - _h_cont1 = _mm_unpacklo_epi8(_h_cont1, _mm_cmpgt_epi8(_mm_setzero_si128(), _h_cont1)); #endif #if __XOP__ - _lstm_IFOGh0 = _mm_maddd_epi16(_w0, _h_cont0, _lstm_IFOGh0); - _sum1 = _mm_maddd_epi16(_w1, _h_cont1, _sum1); + _sum0 = _mm_maddd_epi16(_w0, _h_cont, _sum0); + _sum1 = _mm_maddd_epi16(_w1, _h_cont, _sum1); #else - __m128i _s0 = _mm_madd_epi16(_w0, _h_cont0); - __m128i _s1 = _mm_madd_epi16(_w1, _h_cont1); - _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _s0); + __m128i _s0 = _mm_madd_epi16(_w0, _h_cont); + __m128i _s1 = _mm_madd_epi16(_w1, _h_cont); + _sum0 = _mm_add_epi32(_sum0, _s0); _sum1 = _mm_add_epi32(_sum1, _s1); #endif kptr += 16; } -#endif + { +#if __SSSE3__ + __m128i _tmp0 = _mm_hadd_epi32(_sum0, _sum1); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _tmp0); +#else + __m128i _tmp0 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_sum0), _mm_castsi128_ps(_sum1), _MM_SHUFFLE(2, 0, 2, 0))); + __m128i _tmp1 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_sum0), _mm_castsi128_ps(_sum1), _MM_SHUFFLE(3, 1, 3, 1))); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _tmp0); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _tmp1); +#endif // __SSSE3__ + } #endif // __AVXVNNI__ || __AVX512VNNI__ - _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum1); - _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum2); - _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum3); for (; i + 1 < num_output; i += 2) { __m128i _w = _mm_loadl_epi64((const __m128i*)kptr);
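
Note on the reduction pattern used throughout this patch: each inner loop now issues one contiguous load of the int8 activations (x or hs) and reuses it against the four gate weight rows (I, F, O, G), accumulating into a separate partial-sum register per gate; the cross-gate combine is hoisted out of the loop and done once, either with transpose4x4_epi32 followed by four adds (16-wide loops) or with _mm_hadd_epi32 (8-wide loops). The helper below is a minimal SSE2-only sketch of that post-loop reduction, not code taken from the patch; the name reduce_ifog_sse2 and the assumption that each input register holds four partial sums for a single gate are illustrative only.

    #include <emmintrin.h> // SSE2

    // Sketch: collapse four per-gate partial-sum vectors into one {I, F, O, G}
    // accumulator, equivalent to the transpose4x4_epi32 + _mm_add_epi32 step above.
    static inline __m128i reduce_ifog_sse2(__m128i sI, __m128i sF, __m128i sO, __m128i sG)
    {
        __m128i t0 = _mm_unpacklo_epi32(sI, sF); // I0 F0 I1 F1
        __m128i t1 = _mm_unpacklo_epi32(sO, sG); // O0 G0 O1 G1
        __m128i t2 = _mm_unpackhi_epi32(sI, sF); // I2 F2 I3 F3
        __m128i t3 = _mm_unpackhi_epi32(sO, sG); // O2 G2 O3 G3
        __m128i r0 = _mm_unpacklo_epi64(t0, t1); // I0 F0 O0 G0
        __m128i r1 = _mm_unpackhi_epi64(t0, t1); // I1 F1 O1 G1
        __m128i r2 = _mm_unpacklo_epi64(t2, t3); // I2 F2 O2 G2
        __m128i r3 = _mm_unpackhi_epi64(t2, t3); // I3 F3 O3 G3
        return _mm_add_epi32(_mm_add_epi32(r0, r1), _mm_add_epi32(r2, r3));
    }

In the AVX-VNNI weight-transform path the same reduction is applied to the per-gate weight-sum "shift" values (_w_shift), which compensate for biasing the signed activations by +127 before _mm_dpbusd_epi32, as seen in the _mm_add_epi8(_xi, _v127) step of the compute loops.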