From 3cb3aa7e85d890a9eb215198cc202a7709124b4c Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 30 Apr 2024 17:28:01 +0800 Subject: [PATCH] opt --- src/layer/x86/lstm_int8.h | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/src/layer/x86/lstm_int8.h b/src/layer/x86/lstm_int8.h index 22c74aca332..1b155458d34 100644 --- a/src/layer/x86/lstm_int8.h +++ b/src/layer/x86/lstm_int8.h @@ -2115,22 +2115,29 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d const float* gates_data = gates.row(q); - __m128 _IFOG_4x4_0 = _mm_loadu_ps(gates_data); - __m128 _IFOG_4x4_1 = _mm_loadu_ps(gates_data + 4); - __m128 _IFOG_4x4_2 = _mm_loadu_ps(gates_data + 8); - __m128 _IFOG_4x4_3 = _mm_loadu_ps(gates_data + 12); - __m128 _IFOG_4x4_4 = _mm_loadu_ps(gates_data + 16); - __m128 _IFOG_4x4_5 = _mm_loadu_ps(gates_data + 20); - __m128 _IFOG_4x4_6 = _mm_loadu_ps(gates_data + 24); - __m128 _IFOG_4x4_7 = _mm_loadu_ps(gates_data + 28); + __m256 _IFOG_0 = _mm256_loadu_ps(gates_data); + __m256 _IFOG_2 = _mm256_loadu_ps(gates_data + 8); + __m256 _IFOG_4 = _mm256_loadu_ps(gates_data + 16); + __m256 _IFOG_6 = _mm256_loadu_ps(gates_data + 24); - _MM_TRANSPOSE4_PS(_IFOG_4x4_0, _IFOG_4x4_1, _IFOG_4x4_2, _IFOG_4x4_3); - _MM_TRANSPOSE4_PS(_IFOG_4x4_4, _IFOG_4x4_5, _IFOG_4x4_6, _IFOG_4x4_7); - - __m256 _lstm_I = sigmoid_avx(_mm256_insertf128_ps(_mm256_castps128_ps256(_IFOG_4x4_0), _IFOG_4x4_4, 1)); - __m256 _lstm_F = sigmoid_avx(_mm256_insertf128_ps(_mm256_castps128_ps256(_IFOG_4x4_1), _IFOG_4x4_5, 1)); - __m256 _lstm_O = sigmoid_avx(_mm256_insertf128_ps(_mm256_castps128_ps256(_IFOG_4x4_2), _IFOG_4x4_6, 1)); - __m256 _lstm_G = tanh_avx(_mm256_insertf128_ps(_mm256_castps128_ps256(_IFOG_4x4_3), _IFOG_4x4_7, 1)); + // unzip4 + __m256 _tmp0 = _mm256_permute2f128_ps(_IFOG_0, _IFOG_4, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_IFOG_2, _IFOG_6, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp2 = _mm256_permute2f128_ps(_IFOG_0, _IFOG_4, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp3 = _mm256_permute2f128_ps(_IFOG_2, _IFOG_6, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp4 = _mm256_unpacklo_ps(_tmp0, _tmp1); + __m256 _tmp5 = _mm256_unpacklo_ps(_tmp2, _tmp3); + __m256 _tmp6 = _mm256_unpackhi_ps(_tmp0, _tmp1); + __m256 _tmp7 = _mm256_unpackhi_ps(_tmp2, _tmp3); + __m256 _lstm_I = _mm256_unpacklo_ps(_tmp4, _tmp5); + __m256 _lstm_F = _mm256_unpackhi_ps(_tmp4, _tmp5); + __m256 _lstm_O = _mm256_unpacklo_ps(_tmp6, _tmp7); + __m256 _lstm_G = _mm256_unpackhi_ps(_tmp6, _tmp7); + + _lstm_I = sigmoid_avx(_lstm_I); + _lstm_F = sigmoid_avx(_lstm_F); + _lstm_O = sigmoid_avx(_lstm_O); + _lstm_G = tanh_avx(_lstm_G); __m256 _cell2 = _mm256_add_ps(_mm256_mul_ps(_lstm_F, _mm256_loadu_ps(cell_ptr + q)), _mm256_mul_ps(_lstm_I, _lstm_G)); __m256 _lstm_H = _mm256_mul_ps(_lstm_O, tanh_avx(_cell2));