Skip to content

Commit

Permalink
opt
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Apr 30, 2024
1 parent 970c12a commit 3cb3aa7
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions src/layer/x86/lstm_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -2115,22 +2115,29 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d

const float* gates_data = gates.row(q);

__m128 _IFOG_4x4_0 = _mm_loadu_ps(gates_data);
__m128 _IFOG_4x4_1 = _mm_loadu_ps(gates_data + 4);
__m128 _IFOG_4x4_2 = _mm_loadu_ps(gates_data + 8);
__m128 _IFOG_4x4_3 = _mm_loadu_ps(gates_data + 12);
__m128 _IFOG_4x4_4 = _mm_loadu_ps(gates_data + 16);
__m128 _IFOG_4x4_5 = _mm_loadu_ps(gates_data + 20);
__m128 _IFOG_4x4_6 = _mm_loadu_ps(gates_data + 24);
__m128 _IFOG_4x4_7 = _mm_loadu_ps(gates_data + 28);
__m256 _IFOG_0 = _mm256_loadu_ps(gates_data);
__m256 _IFOG_2 = _mm256_loadu_ps(gates_data + 8);
__m256 _IFOG_4 = _mm256_loadu_ps(gates_data + 16);
__m256 _IFOG_6 = _mm256_loadu_ps(gates_data + 24);

_MM_TRANSPOSE4_PS(_IFOG_4x4_0, _IFOG_4x4_1, _IFOG_4x4_2, _IFOG_4x4_3);
_MM_TRANSPOSE4_PS(_IFOG_4x4_4, _IFOG_4x4_5, _IFOG_4x4_6, _IFOG_4x4_7);

__m256 _lstm_I = sigmoid_avx(_mm256_insertf128_ps(_mm256_castps128_ps256(_IFOG_4x4_0), _IFOG_4x4_4, 1));
__m256 _lstm_F = sigmoid_avx(_mm256_insertf128_ps(_mm256_castps128_ps256(_IFOG_4x4_1), _IFOG_4x4_5, 1));
__m256 _lstm_O = sigmoid_avx(_mm256_insertf128_ps(_mm256_castps128_ps256(_IFOG_4x4_2), _IFOG_4x4_6, 1));
__m256 _lstm_G = tanh_avx(_mm256_insertf128_ps(_mm256_castps128_ps256(_IFOG_4x4_3), _IFOG_4x4_7, 1));
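// unzip4: de-interleave the 32 loaded floats, laid out as 8 consecutive
// {I, F, O, G} quadruples, into four 8-lane vectors (one per gate) with
// element order preserved, i.e. _lstm_I = gates_data[0, 4, 8, ..., 28]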
__m256 _tmp0 = _mm256_permute2f128_ps(_IFOG_0, _IFOG_4, _MM_SHUFFLE(0, 2, 0, 0));
__m256 _tmp1 = _mm256_permute2f128_ps(_IFOG_2, _IFOG_6, _MM_SHUFFLE(0, 2, 0, 0));
__m256 _tmp2 = _mm256_permute2f128_ps(_IFOG_0, _IFOG_4, _MM_SHUFFLE(0, 3, 0, 1));
__m256 _tmp3 = _mm256_permute2f128_ps(_IFOG_2, _IFOG_6, _MM_SHUFFLE(0, 3, 0, 1));
__m256 _tmp4 = _mm256_unpacklo_ps(_tmp0, _tmp1);
__m256 _tmp5 = _mm256_unpacklo_ps(_tmp2, _tmp3);
__m256 _tmp6 = _mm256_unpackhi_ps(_tmp0, _tmp1);
__m256 _tmp7 = _mm256_unpackhi_ps(_tmp2, _tmp3);
__m256 _lstm_I = _mm256_unpacklo_ps(_tmp4, _tmp5);
__m256 _lstm_F = _mm256_unpackhi_ps(_tmp4, _tmp5);
__m256 _lstm_O = _mm256_unpacklo_ps(_tmp6, _tmp7);
__m256 _lstm_G = _mm256_unpackhi_ps(_tmp6, _tmp7);

_lstm_I = sigmoid_avx(_lstm_I);
_lstm_F = sigmoid_avx(_lstm_F);
_lstm_O = sigmoid_avx(_lstm_O);
_lstm_G = tanh_avx(_lstm_G);

__m256 _cell2 = _mm256_add_ps(_mm256_mul_ps(_lstm_F, _mm256_loadu_ps(cell_ptr + q)), _mm256_mul_ps(_lstm_I, _lstm_G));
__m256 _lstm_H = _mm256_mul_ps(_lstm_O, tanh_avx(_cell2));
Expand Down

0 comments on commit 3cb3aa7

Please sign in to comment.