diff --git a/src/layer/arm/lstm_int8.h b/src/layer/arm/lstm_int8.h index 9ad7de33551..c29417b36da 100644 --- a/src/layer/arm/lstm_int8.h +++ b/src/layer/arm/lstm_int8.h @@ -78,6 +78,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x float* descales_ptr = weight_data_tm_int8_descales_dr.row(q); int i = 0; +#if __ARM_NEON #if __ARM_FEATURE_DOTPROD for (; i + 3 < size; i += 4) { @@ -112,6 +113,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x kptr[7] = weight_xc_G[i + 1]; kptr += 8; } +#endif // __ARM_NEON for (; i < size; i++) { kptr[0] = weight_xc_I[i]; @@ -122,6 +124,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x } i = 0; +#if __ARM_NEON #if __ARM_FEATURE_DOTPROD for (; i + 3 < num_output; i += 4) { @@ -156,6 +159,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x kptr[7] = weight_hc_G[i + 1]; kptr += 8; } +#endif // __ARM_NEON for (; i < num_output; i++) { kptr[0] = weight_hc_I[i]; diff --git a/src/layer/x86/lstm_int8.h b/src/layer/x86/lstm_int8.h new file mode 100644 index 00000000000..125ba296621 --- /dev/null +++ b/src/layer/x86/lstm_int8.h @@ -0,0 +1,473 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
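The new x86 `lstm_int8.h` below packs the four gate weight rows (I, F, O, G) of each output unit into one interleaved int8 matrix, so a single kernel can accumulate all four gate dot products side by side. On SSE2, two input columns are interleaved per gate (I0 I1 F0 F1 O0 O1 G0 G1) so that `_mm_madd_epi16` folds two int16 products into each 32-bit lane. A minimal standalone sketch of that packing order, using hypothetical raw `int8_t` arrays in place of `ncnn::Mat` (function and parameter names are illustrative, not from the patch):

```cpp
// Standalone sketch (not part of the patch) of the SSE2 weight interleaving
// performed by lstm_transform_weight_int8 below.
#include <cstdint>
#include <vector>

static std::vector<int8_t> pack_ifog_sse2(const int8_t* I, const int8_t* F,
                                          const int8_t* O, const int8_t* G,
                                          int size)
{
    std::vector<int8_t> packed;
    packed.reserve(size * 4);

    int i = 0;
    for (; i + 1 < size; i += 2)
    {
        // two columns of each gate, adjacent so _mm_madd_epi16 can
        // accumulate both products into one 32-bit lane per gate
        const int8_t block[8] = {I[i], I[i + 1], F[i], F[i + 1],
                                 O[i], O[i + 1], G[i], G[i + 1]};
        packed.insert(packed.end(), block, block + 8);
    }
    for (; i < size; i++)
    {
        // leftover column keeps the plain IFOG order
        const int8_t block[4] = {I[i], F[i], O[i], G[i]};
        packed.insert(packed.end(), block, block + 4);
    }
    return packed;
}
```

With this layout the per-gate descales need to be stored only once per output row, as 8 floats (4 for `weight_xc`, 4 for `weight_hc`), which is exactly what `weight_data_tm_int8_descales` holds.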
+ +static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, int hidden_size, const Option& opt) +{ + // TODO dispatch + + weight_data_tm.create(size + num_output, hidden_size, num_directions, 4u, 4); + weight_data_tm_int8_descales.create(4 + 4, hidden_size, num_directions); + bias_c_tm.create(hidden_size, 1, num_directions, 16u, 4); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int dr = 0; dr < num_directions; dr++) + { + const Mat weight_xc_dr = weight_xc.channel(dr); + const Mat weight_hc_dr = weight_hc.channel(dr); + const Mat bias_c_dr = bias_c.channel(dr); + const float* weight_xc_int8_scales_ptr = weight_xc_int8_scales.row(dr); + const float* weight_hc_int8_scales_ptr = weight_hc_int8_scales.row(dr); + + Mat weight_data_tm_dr = weight_data_tm.channel(dr); + Mat bias_c_tm_dr = bias_c_tm.channel(dr); + Mat weight_data_tm_int8_descales_dr = weight_data_tm_int8_descales.channel(dr); + + const float* bias_c_I = bias_c_dr.row(0); + const float* bias_c_F = bias_c_dr.row(1); + const float* bias_c_O = bias_c_dr.row(2); + const float* bias_c_G = bias_c_dr.row(3); + + float* bias_c_IFOG = bias_c_tm_dr.row(0); + + int q = 0; + for (; q < hidden_size; q++) + { + bias_c_IFOG[0] = bias_c_I[q]; + bias_c_IFOG[1] = bias_c_F[q]; + bias_c_IFOG[2] = bias_c_O[q]; + bias_c_IFOG[3] = bias_c_G[q]; + + bias_c_IFOG += 4; + + const signed char* weight_xc_I = weight_xc_dr.row(hidden_size * 0 + q); + const signed char* weight_xc_F = weight_xc_dr.row(hidden_size * 1 + q); + const signed char* weight_xc_O = weight_xc_dr.row(hidden_size * 2 + q); + const signed char* weight_xc_G = weight_xc_dr.row(hidden_size * 3 + q); + + const signed char* weight_hc_I = weight_hc_dr.row(hidden_size * 0 + q); + const signed char* weight_hc_F = weight_hc_dr.row(hidden_size * 1 + q); + const signed char* weight_hc_O = weight_hc_dr.row(hidden_size * 2 + q); + const signed char* weight_hc_G = weight_hc_dr.row(hidden_size * 3 + q); + + signed char* kptr = weight_data_tm_dr.row(q); + float* descales_ptr = weight_data_tm_int8_descales_dr.row(q); + + int i = 0; +#if __SSE2__ + for (; i + 1 < size; i += 2) + { + kptr[0] = weight_xc_I[i]; + kptr[1] = weight_xc_I[i + 1]; + kptr[2] = weight_xc_F[i]; + kptr[3] = weight_xc_F[i + 1]; + kptr[4] = weight_xc_O[i]; + kptr[5] = weight_xc_O[i + 1]; + kptr[6] = weight_xc_G[i]; + kptr[7] = weight_xc_G[i + 1]; + kptr += 8; + } +#endif // __SSE2__ + for (; i < size; i++) + { + kptr[0] = weight_xc_I[i]; + kptr[1] = weight_xc_F[i]; + kptr[2] = weight_xc_O[i]; + kptr[3] = weight_xc_G[i]; + kptr += 4; + } + + i = 0; +#if __SSE2__ + for (; i + 1 < num_output; i += 2) + { + kptr[0] = weight_hc_I[i]; + kptr[1] = weight_hc_I[i + 1]; + kptr[2] = weight_hc_F[i]; + kptr[3] = weight_hc_F[i + 1]; + kptr[4] = weight_hc_O[i]; + kptr[5] = weight_hc_O[i + 1]; + kptr[6] = weight_hc_G[i]; + kptr[7] = weight_hc_G[i + 1]; + kptr += 8; + } +#endif // __SSE2__ + for (; i < num_output; i++) + { + kptr[0] = weight_hc_I[i]; + kptr[1] = weight_hc_F[i]; + kptr[2] = weight_hc_O[i]; + kptr[3] = weight_hc_G[i]; + kptr += 4; + } + + descales_ptr[0] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 0 + q]; + descales_ptr[1] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 1 + q]; + descales_ptr[2] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 2 + q]; + descales_ptr[3] = 1.f / 
weight_xc_int8_scales_ptr[hidden_size * 3 + q]; + descales_ptr[4] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 0 + q]; + descales_ptr[5] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 1 + q]; + descales_ptr[6] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 2 + q]; + descales_ptr[7] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 3 + q]; + } + } +} + +static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +{ + // TODO dispatch + + int size = bottom_blob_int8.w; + int T = bottom_blob_int8.h; + + int num_output = top_blob.w; + int hidden_size = cell_state.w; + + // 4 x hidden_size + Mat gates(4, hidden_size, 4u, opt.workspace_allocator); + + Mat tmp_hidden_state; + if (num_output != hidden_size) + { + tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); + } + + Mat hidden_state_int8(num_output, (size_t)1u, 1, opt.workspace_allocator); + float hidden_state_int8_scale = 1.f; + float hidden_state_int8_descale = 1.f; + + // unroll + for (int t = 0; t < T; t++) + { + int ti = reverse ? T - 1 - t : t; + + // dynamic quantize hidden_state + { + float absmax = 0.f; + for (int i = 0; i < num_output; i++) + { + absmax = std::max(absmax, (float)fabs(hidden_state[i])); + } + + if (absmax == 0.f) + { + hidden_state_int8.fill(0); + } + else + { + hidden_state_int8_scale = 127.f / absmax; + hidden_state_int8_descale = absmax / 127.f; + + signed char* hs = hidden_state_int8; + for (int i = 0; i < num_output; i++) + { + hs[i] = float2int8(hidden_state[i] * hidden_state_int8_scale); + } + } + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < hidden_size; q++) + { + const signed char* x = bottom_blob_int8.row(ti); + const signed char* hs = hidden_state_int8; + const float descale_x = bottom_blob_int8_descales[ti]; + const float descale_h = hidden_state_int8_descale; + + // gate reset update + const float* bias_c_IFOG = (const float*)bias_c + q * 4; + + const signed char* kptr = weight_data_tm.row(q); + const float* descales_ptr = weight_data_tm_int8_descales.row(q); + + float* gates_data = gates.row(q); + +#if __SSE2__ + __m128i _lstm_IFOGx0 = _mm_setzero_si128(); + int i = 0; + for (; i + 1 < size; i += 2) + { + __m128i _xi = _mm_set1_epi16(((const short*)(x + i))[0]); + __m128i _w = _mm_loadl_epi64((const __m128i*)kptr); + +#if __SSE4_1__ + _w = _mm_cvtepi8_epi16(_w); + _xi = _mm_cvtepi8_epi16(_xi); +#else + _w = _mm_unpacklo_epi8(_w, _mm_cmpgt_epi8(_mm_setzero_si128(), _w)); + _xi = _mm_unpacklo_epi8(_xi, _mm_cmpgt_epi8(_mm_setzero_si128(), _xi)); +#endif + +#if __XOP__ + _lstm_IFOGx0 = _mm_maddd_epi16(_w, _xi, _lstm_IFOGx0); +#else + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _mm_madd_epi16(_w, _xi)); +#endif + + kptr += 8; + } + for (; i < size; i++) + { + __m128i _xi = _mm_set1_epi16(x[i]); + __m128i _w = _mm_loadl_epi64((const __m128i*)kptr); + +#if __SSE4_1__ + _w = _mm_cvtepi8_epi16(_w); +#else + _w = _mm_unpacklo_epi8(_w, _mm_cmpgt_epi8(_mm_setzero_si128(), _w)); +#endif + +#if __XOP__ + _w = _mm_unpacklo_epi16(_w, _w); + + _lstm_IFOGx0 = _mm_maccd_epi16(_w, _xi, _lstm_IFOGx0); +#else + __m128i _sl = _mm_mullo_epi16(_w, _xi); + __m128i _sh = _mm_mulhi_epi16(_w, _xi); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _s0); +#endif + + kptr += 4; + } + + __m128i _lstm_IFOGh0 = 
_mm_setzero_si128(); + i = 0; + for (; i + 1 < num_output; i += 2) + { + __m128i _h_cont = _mm_set1_epi16(((const short*)(hs + i))[0]); + __m128i _w = _mm_loadl_epi64((const __m128i*)kptr); + +#if __SSE4_1__ + _w = _mm_cvtepi8_epi16(_w); + _h_cont = _mm_cvtepi8_epi16(_h_cont); +#else + _w = _mm_unpacklo_epi8(_w, _mm_cmpgt_epi8(_mm_setzero_si128(), _w)); + _h_cont = _mm_unpacklo_epi8(_h_cont, _mm_cmpgt_epi8(_mm_setzero_si128(), _h_cont)); +#endif + +#if __XOP__ + _lstm_IFOGh0 = _mm_maddd_epi16(_w, _h_cont, _lstm_IFOGh0); +#else + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _mm_madd_epi16(_w, _h_cont)); +#endif + + kptr += 8; + } + for (; i < num_output; i++) + { + __m128i _h_cont = _mm_set1_epi16(hs[i]); + __m128i _w = _mm_loadl_epi64((const __m128i*)kptr); + +#if __SSE4_1__ + _w = _mm_cvtepi8_epi16(_w); +#else + _w = _mm_unpacklo_epi8(_w, _mm_cmpgt_epi8(_mm_setzero_si128(), _w)); +#endif + +#if __XOP__ + _w = _mm_unpacklo_epi16(_w, _w); + + _lstm_IFOGh0 = _mm_maccd_epi16(_w, _h_cont, _lstm_IFOGh0); +#else + __m128i _sl = _mm_mullo_epi16(_w, _h_cont); + __m128i _sh = _mm_mulhi_epi16(_w, _h_cont); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _s0); +#endif + + kptr += 4; + } + + __m128 _descale_x = _mm_set1_ps(descale_x); + __m128 _descale_h = _mm_set1_ps(descale_h); + + __m128 _lstm_IFOG0 = _mm_loadu_ps(bias_c_IFOG); + + __m128 _descale_xc_IFOG = _mm_loadu_ps(descales_ptr); + + _lstm_IFOG0 = _mm_add_ps(_lstm_IFOG0, _mm_mul_ps(_mm_cvtepi32_ps(_lstm_IFOGx0), _mm_mul_ps(_descale_x, _descale_xc_IFOG))); + + __m128 _descale_hc_IFOG = _mm_loadu_ps(descales_ptr + 4); + + _lstm_IFOG0 = _mm_add_ps(_lstm_IFOG0, _mm_mul_ps(_mm_cvtepi32_ps(_lstm_IFOGh0), _mm_mul_ps(_descale_h, _descale_hc_IFOG))); + + _mm_storeu_ps(gates_data, _lstm_IFOG0); +#else + int Ix = 0; + int Fx = 0; + int Ox = 0; + int Gx = 0; + for (int i = 0; i < size; i++) + { + signed char xi = x[i]; + + Ix += kptr[0] * xi; + Fx += kptr[1] * xi; + Ox += kptr[2] * xi; + Gx += kptr[3] * xi; + + kptr += 4; + } + + int Ih = 0; + int Fh = 0; + int Oh = 0; + int Gh = 0; + for (int i = 0; i < num_output; i++) + { + signed char h_cont = hs[i]; + + Ih += kptr[0] * h_cont; + Fh += kptr[1] * h_cont; + Oh += kptr[2] * h_cont; + Gh += kptr[3] * h_cont; + + kptr += 4; + } + + const float descale_xc_I = descales_ptr[0]; + const float descale_xc_F = descales_ptr[1]; + const float descale_xc_O = descales_ptr[2]; + const float descale_xc_G = descales_ptr[3]; + const float descale_hc_I = descales_ptr[4]; + const float descale_hc_F = descales_ptr[5]; + const float descale_hc_O = descales_ptr[6]; + const float descale_hc_G = descales_ptr[7]; + + float I = bias_c_IFOG[0] + Ix * (descale_x * descale_xc_I) + Ih * (descale_h * descale_hc_I); + float F = bias_c_IFOG[1] + Fx * (descale_x * descale_xc_F) + Fh * (descale_h * descale_hc_F); + float O = bias_c_IFOG[2] + Ox * (descale_x * descale_xc_O) + Oh * (descale_h * descale_hc_O); + float G = bias_c_IFOG[3] + Gx * (descale_x * descale_xc_G) + Gh * (descale_h * descale_hc_G); + + gates_data[0] = I; + gates_data[1] = F; + gates_data[2] = O; + gates_data[3] = G; +#endif // __SSE2__ + } + + // lstm unit + // sigmoid(I) + // sigmoid(F) + // sigmoid(O) + // tanh(G) + // c_t := f_t .* c_{t-1} + i_t .* g_t + // h_t := o_t .* tanh[c_t] + float* output_data = top_blob.row(ti); + + float* cell_ptr = cell_state; + float* hidden_ptr = hidden_state; + float* tmp_hidden_ptr = tmp_hidden_state; + + int remain_hidden_size_start = 0; +#if __SSE2__ + int nn_hidden_size = hidden_size >> 2; 
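+        // gates stores one IFOG quad per hidden unit; the loop below loads
+        // four quads, transposes them with _MM_TRANSPOSE4_PS into per-gate
+        // vectors, and runs the activations on four hidden units at once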
+ remain_hidden_size_start = nn_hidden_size << 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_hidden_size; qq++) + { + int q = qq * 4; + + const float* gates_data = gates.row(q); + + __m128 _IFOG_4x4_0 = _mm_loadu_ps(gates_data); + __m128 _IFOG_4x4_1 = _mm_loadu_ps(gates_data + 4); + __m128 _IFOG_4x4_2 = _mm_loadu_ps(gates_data + 8); + __m128 _IFOG_4x4_3 = _mm_loadu_ps(gates_data + 12); + + _MM_TRANSPOSE4_PS(_IFOG_4x4_0, _IFOG_4x4_1, _IFOG_4x4_2, _IFOG_4x4_3); + + __m128 _lstm_I = sigmoid_sse(_IFOG_4x4_0); + __m128 _lstm_F = sigmoid_sse(_IFOG_4x4_1); + __m128 _lstm_O = sigmoid_sse(_IFOG_4x4_2); + __m128 _lstm_G = tanh_sse(_IFOG_4x4_3); + + __m128 _cell2 = _mm_add_ps(_mm_mul_ps(_lstm_F, _mm_loadu_ps(cell_ptr + q)), _mm_mul_ps(_lstm_I, _lstm_G)); + __m128 _lstm_H = _mm_mul_ps(_lstm_O, tanh_sse(_cell2)); + + _mm_storeu_ps(cell_ptr + q, _cell2); + + if (num_output == hidden_size) + { + _mm_storeu_ps(hidden_ptr + q, _lstm_H); + _mm_storeu_ps(output_data + q, _lstm_H); + } + else + { + _mm_storeu_ps(tmp_hidden_ptr + q, _lstm_H); + } + } +#endif // __SSE2__ + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_hidden_size_start; q < hidden_size; q++) + { + const float* gates_data = gates.row(q); + + float I = gates_data[0]; + float F = gates_data[1]; + float O = gates_data[2]; + float G = gates_data[3]; + + I = 1.f / (1.f + expf(-I)); + F = 1.f / (1.f + expf(-F)); + O = 1.f / (1.f + expf(-O)); + G = tanhf(G); + + float cell2 = F * cell_ptr[q] + I * G; + float H = O * tanhf(cell2); + + cell_ptr[q] = cell2; + if (num_output == hidden_size) + { + hidden_ptr[q] = H; + output_data[q] = H; + } + else + { + tmp_hidden_ptr[q] = H; + } + } + + if (num_output != hidden_size) + { + // int nn_num_output = num_output >> 2; + // int remain_num_output_start = nn_num_output << 2; + // #pragma omp parallel for num_threads(opt.num_threads) + // for (int qq = 0; qq < nn_num_output; qq++) + // { + // int q = qq * 4; + // + // } + int remain_num_output_start = 0; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + const float* tmp_hidden_ptr = tmp_hidden_state; + + float H = 0; + for (int i = 0; i < hidden_size; i++) + { + H += tmp_hidden_ptr[i] * hr[i]; + } + + hidden_ptr[q] = H; + output_data[q] = H; + } + } + } +} diff --git a/src/layer/x86/lstm_x86.cpp b/src/layer/x86/lstm_x86.cpp index 480963a3ba5..e82d5d2de2f 100644 --- a/src/layer/x86/lstm_x86.cpp +++ b/src/layer/x86/lstm_x86.cpp @@ -28,6 +28,8 @@ namespace ncnn { +#include "lstm_int8.h" + LSTM_x86::LSTM_x86() { one_blob_only = false; @@ -565,676 +567,230 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w return 0; } -#if NCNN_INT8 -static int lstm_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { - int size = bottom_blob.w; +#if NCNN_INT8 + if (int8_scale_term) + { + return forward_int8(bottom_blob, top_blob, opt); + } +#endif + int T = bottom_blob.h; - int num_output = top_blob.w; - int hidden_size = cell_state.w; + int num_directions = direction == 2 ? 
2 : 1; - // 4 x hidden_size - Mat gates(4, hidden_size, 4u, opt.workspace_allocator); - if (gates.empty()) + // initial hidden state + Mat hidden(num_output, 4u, opt.workspace_allocator); + if (hidden.empty()) return -100; + hidden.fill(0.f); - Mat tmp_hidden_state; - if (num_output != hidden_size) + Mat cell(hidden_size, 4u, opt.workspace_allocator); + if (cell.empty()) + return -100; + cell.fill(0.f); + + top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + // Uni directional + if (direction == 0 || direction == 1) { - tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); - if (tmp_hidden_state.empty()) - return -100; + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; } - // unroll - for (int t = 0; t < T; t++) + if (direction == 2) { - // clip hidden by continuation indicator - // h_cont_{t-1} = cont_t * h_{t-1} - // h_cont_{t-1} = h_{t-1} if cont_t == 1 - // 0 otherwise - // calculate hidden - // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c + Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; - int ti = reverse ? T - 1 - t : t; + Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; -#if __AVX__ - int nn_hidden_size = hidden_size >> 1; - int remain_hidden_size_start = nn_hidden_size << 1; - #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_hidden_size; qq++) { - int q = qq * 2; + int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } - const float* bias_c_IFOG = (const float*)bias_c + q * 4; + hidden.fill(0.0f); + cell.fill(0.0f); - // gate I F O G - const signed char* weight_xc_int8_IFOG = weight_xc_int8.row(q / 2); - const signed char* weight_hc_int8_IFOG = weight_hc_int8.row(q / 2); + { + int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } - const float* weight_xc_int8_descales_IFOG = weight_xc_int8_descales.row(q / 2); - const float* weight_hc_int8_descales_IFOG = weight_hc_int8_descales.row(q / 2); + // concat w + for (int i = 0; i < T; i++) + { + const float* pf = top_blob_forward.row(i); + const float* pr = top_blob_reverse.row(i); + float* ptr = top_blob.row(i); - __m256 _descale_xc_IFOG = _mm256_loadu_ps(weight_xc_int8_descales_IFOG); - __m256 _descale_hc_IFOG = _mm256_loadu_ps(weight_hc_int8_descales_IFOG); + memcpy(ptr, pf, num_output * sizeof(float)); + memcpy(ptr + num_output, pr, num_output * sizeof(float)); + } + } - __m256 _IFOG = _mm256_loadu_ps(bias_c_IFOG); - __m256 _sum1 = _mm256_setzero_ps(); - __m256 _sum2 = _mm256_setzero_ps(); - __m256 _sum3 = _mm256_setzero_ps(); + return 0; +} - const float* x = bottom_blob.row(ti); +int LSTM_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ +#if NCNN_INT8 + if (int8_scale_term) + { + return forward_int8(bottom_blobs, top_blobs, opt); + } +#endif - int i = 0; - for (; i + 3 < size; i += 4) - { - __m256 _xi0 = _mm256_broadcast_ss(x); - __m256 _xi1 = _mm256_broadcast_ss(x + 1); - __m256 _xi2 = _mm256_broadcast_ss(x + 2); - __m256 _xi3 = _mm256_broadcast_ss(x + 3); + const Mat& bottom_blob = bottom_blobs[0]; + int T = bottom_blob.h; + int num_directions = direction == 2 ? 2 : 1; - __m128i _weight_xc_IFOG0l = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)weight_xc_int8_IFOG)); - __m128i _weight_xc_IFOG0h = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_xc_int8_IFOG + 4))); - __m128i _weight_xc_IFOG1l = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_xc_int8_IFOG + 8))); - __m128i _weight_xc_IFOG1h = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_xc_int8_IFOG + 12))); - __m128i _weight_xc_IFOG2l = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_xc_int8_IFOG + 16))); - __m128i _weight_xc_IFOG2h = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_xc_int8_IFOG + 20))); - __m128i _weight_xc_IFOG3l = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_xc_int8_IFOG + 24))); - __m128i _weight_xc_IFOG3h = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_xc_int8_IFOG + 28))); - __m256 _weight_xc_IFOG0 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_xc_IFOG0l), _weight_xc_IFOG0h, 1)); - __m256 _weight_xc_IFOG1 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_xc_IFOG1l), _weight_xc_IFOG1h, 1)); - __m256 _weight_xc_IFOG2 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_xc_IFOG2l), _weight_xc_IFOG2h, 1)); - __m256 _weight_xc_IFOG3 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_xc_IFOG3l), _weight_xc_IFOG3h, 1)); - _weight_xc_IFOG0 = _mm256_mul_ps(_weight_xc_IFOG0, _descale_xc_IFOG); - _weight_xc_IFOG1 = _mm256_mul_ps(_weight_xc_IFOG1, _descale_xc_IFOG); - _weight_xc_IFOG2 = _mm256_mul_ps(_weight_xc_IFOG2, _descale_xc_IFOG); - _weight_xc_IFOG3 = _mm256_mul_ps(_weight_xc_IFOG3, _descale_xc_IFOG); + Mat hidden; + Mat cell; + Allocator* hidden_cell_allocator = top_blobs.size() == 3 ? 
opt.blob_allocator : opt.workspace_allocator; + if (bottom_blobs.size() == 3) + { + hidden = bottom_blobs[1].clone(hidden_cell_allocator); + cell = bottom_blobs[2].clone(hidden_cell_allocator); + } + else + { + hidden.create(num_output, num_directions, 4u, hidden_cell_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); - _IFOG = _mm256_comp_fmadd_ps(_weight_xc_IFOG0, _xi0, _IFOG); - _sum1 = _mm256_comp_fmadd_ps(_weight_xc_IFOG1, _xi1, _sum1); - _sum2 = _mm256_comp_fmadd_ps(_weight_xc_IFOG2, _xi2, _sum2); - _sum3 = _mm256_comp_fmadd_ps(_weight_xc_IFOG3, _xi3, _sum3); + cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator); + if (cell.empty()) + return -100; + cell.fill(0.f); + } - x += 4; - weight_xc_int8_IFOG += 32; - } - for (; i < size; i++) - { - __m256 _xi = _mm256_broadcast_ss(x); + Mat& top_blob = top_blobs[0]; + top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); + if (top_blob.empty()) + return -100; - __m128i _weight_xc_IFOGl = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)weight_xc_int8_IFOG)); - __m128i _weight_xc_IFOGh = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_xc_int8_IFOG + 4))); - __m256 _weight_xc_IFOG = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_xc_IFOGl), _weight_xc_IFOGh, 1)); - _weight_xc_IFOG = _mm256_mul_ps(_weight_xc_IFOG, _descale_xc_IFOG); + // Uni directional + if (direction == 0 || direction == 1) + { + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } - _IFOG = _mm256_comp_fmadd_ps(_weight_xc_IFOG, _xi, _IFOG); + if (direction == 2) + { + Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; - x += 1; - weight_xc_int8_IFOG += 8; - } + Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; - const float* hidden_ptr = hidden_state; + Mat hidden0 = hidden.row_range(0, 1); + Mat cell0 = cell.row_range(0, 1); + { + int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); + if (ret != 0) + return ret; + } - i = 0; - for (; i + 3 < num_output; i += 4) - { - __m256 _h_cont0 = _mm256_broadcast_ss(hidden_ptr); - __m256 _h_cont1 = _mm256_broadcast_ss(hidden_ptr + 1); - __m256 _h_cont2 = _mm256_broadcast_ss(hidden_ptr + 2); - __m256 _h_cont3 = _mm256_broadcast_ss(hidden_ptr + 3); + Mat hidden1 = hidden.row_range(1, 1); + Mat cell1 = cell.row_range(1, 1); + { + int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); + if (ret != 0) + return ret; + } - __m128i _weight_hc_IFOG0l = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)weight_hc_int8_IFOG)); - __m128i _weight_hc_IFOG0h = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_hc_int8_IFOG + 4))); - __m128i _weight_hc_IFOG1l = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_hc_int8_IFOG + 8))); - __m128i _weight_hc_IFOG1h = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_hc_int8_IFOG + 12))); - __m128i _weight_hc_IFOG2l = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_hc_int8_IFOG + 16))); - __m128i _weight_hc_IFOG2h = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_hc_int8_IFOG + 20))); - __m128i _weight_hc_IFOG3l = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_hc_int8_IFOG + 24))); - __m128i _weight_hc_IFOG3h = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_hc_int8_IFOG + 28))); - __m256 _weight_hc_IFOG0 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_hc_IFOG0l), _weight_hc_IFOG0h, 1)); - __m256 _weight_hc_IFOG1 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_hc_IFOG1l), _weight_hc_IFOG1h, 1)); - __m256 _weight_hc_IFOG2 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_hc_IFOG2l), _weight_hc_IFOG2h, 1)); - __m256 _weight_hc_IFOG3 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_hc_IFOG3l), _weight_hc_IFOG3h, 1)); - _weight_hc_IFOG0 = _mm256_mul_ps(_weight_hc_IFOG0, _descale_hc_IFOG); - _weight_hc_IFOG1 = _mm256_mul_ps(_weight_hc_IFOG1, _descale_hc_IFOG); - _weight_hc_IFOG2 = _mm256_mul_ps(_weight_hc_IFOG2, _descale_hc_IFOG); - _weight_hc_IFOG3 = _mm256_mul_ps(_weight_hc_IFOG3, _descale_hc_IFOG); + // concat w + for (int i = 0; i < T; i++) + { + const float* pf = top_blob_forward.row(i); + const float* pr = top_blob_reverse.row(i); + float* ptr = top_blob.row(i); - _IFOG = _mm256_comp_fmadd_ps(_weight_hc_IFOG0, _h_cont0, _IFOG); - _sum1 = _mm256_comp_fmadd_ps(_weight_hc_IFOG1, _h_cont1, _sum1); - _sum2 = _mm256_comp_fmadd_ps(_weight_hc_IFOG2, _h_cont2, _sum2); - _sum3 = _mm256_comp_fmadd_ps(_weight_hc_IFOG3, _h_cont3, _sum3); + memcpy(ptr, pf, num_output * sizeof(float)); + memcpy(ptr + num_output, pr, num_output * sizeof(float)); + } + } - hidden_ptr += 4; - weight_hc_int8_IFOG += 32; - } - for (; i < num_output; i++) - { - __m256 _h_cont = _mm256_broadcast_ss(hidden_ptr); + if (top_blobs.size() == 3) + { + top_blobs[1] = hidden; + top_blobs[2] = cell; + } - __m128i _weight_hc_IFOGl = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)weight_hc_int8_IFOG)); - __m128i _weight_hc_IFOGh = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_hc_int8_IFOG + 4))); - __m256 _weight_hc_IFOG = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_hc_IFOGl), _weight_hc_IFOGh, 1)); - _weight_hc_IFOG = _mm256_mul_ps(_weight_hc_IFOG, _descale_hc_IFOG); + return 0; +} - _IFOG = _mm256_comp_fmadd_ps(_weight_hc_IFOG, _h_cont, _IFOG); +#if NCNN_INT8 +int LSTM_x86::create_pipeline_int8(const Option& opt) +{ + // pack IFOG + const int num_directions = direction == 2 ? 
2 : 1; + const int size = weight_data_size / num_directions / hidden_size / 4; - hidden_ptr += 1; - weight_hc_int8_IFOG += 8; - } + lstm_transform_weight_int8(weight_xc_data, weight_xc_data_int8_scales, weight_hc_data, weight_hc_data_int8_scales, bias_c_data, weight_data_tm, weight_data_tm_int8_descales, bias_c_data_packed, size, num_output, num_directions, hidden_size, opt); - float* gates_data = gates.row(q); + if (opt.lightmode) + { + weight_xc_data.release(); + bias_c_data.release(); + weight_hc_data.release(); + weight_xc_data_int8_scales.release(); + weight_hc_data_int8_scales.release(); + } - _IFOG = _mm256_add_ps(_IFOG, _sum1); - _sum2 = _mm256_add_ps(_sum2, _sum3); - _IFOG = _mm256_add_ps(_IFOG, _sum2); + return 0; +} - _mm256_storeu_ps(gates_data, _IFOG); - } -#else - int nn_hidden_size = 0; - int remain_hidden_size_start = 0; -#endif // __AVX__ +void LSTM_x86::dynamic_quantize(const Mat& bottom_blob, Mat& bottom_blob_int8, Mat& bottom_blob_int8_descales, const Option& opt) const +{ + int size = bottom_blob.w; + int T = bottom_blob.h; - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_hidden_size_start; q < hidden_size; q++) + // dynamic quantize bottom_blob + bottom_blob_int8_descales.create(T, (size_t)4u, 1, opt.blob_allocator); + + Mat bottom_blob_int8_scales(T, (size_t)4u, 1, opt.blob_allocator); + + // fp32 + for (int t = 0; t < T; t++) + { + const float* x = bottom_blob.row(t); + + float absmax = 0.f; + for (int i = 0; i < size; i++) { - const float* bias_c_IFOG = (const float*)bias_c + q * 4; + absmax = std::max(absmax, (float)fabs(x[i])); + } - // gate I F O G -#if __AVX__ - const signed char* weight_xc_int8_IFOG = weight_xc_int8.row(q / 2 + q % 2); - const signed char* weight_hc_int8_IFOG = weight_hc_int8.row(q / 2 + q % 2); - const float* weight_xc_int8_descales_IFOG = weight_xc_int8_descales.row(q / 2 + q % 2); - const float* weight_hc_int8_descales_IFOG = weight_hc_int8_descales.row(q / 2 + q % 2); -#else - const signed char* weight_xc_int8_IFOG = weight_xc_int8.row(q); - const signed char* weight_hc_int8_IFOG = weight_hc_int8.row(q); - const float* weight_xc_int8_descales_IFOG = weight_xc_int8_descales.row(q); - const float* weight_hc_int8_descales_IFOG = weight_hc_int8_descales.row(q); -#endif + bottom_blob_int8_scales[t] = 127.f / absmax; + bottom_blob_int8_descales[t] = absmax / 127.f; + } -#if __SSE2__ - __m128 _descale_xc_IFOG = _mm_loadu_ps(weight_xc_int8_descales_IFOG); - __m128 _descale_hc_IFOG = _mm_loadu_ps(weight_hc_int8_descales_IFOG); + quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt); +} - __m128 _IFOG = _mm_loadu_ps(bias_c_IFOG); - __m128 _sum1 = _mm_setzero_ps(); - __m128 _sum2 = _mm_setzero_ps(); - __m128 _sum3 = _mm_setzero_ps(); -#else // __SSE2__ - const float descale_xc_I = weight_xc_int8_descales_IFOG[0]; - const float descale_xc_F = weight_xc_int8_descales_IFOG[1]; - const float descale_xc_O = weight_xc_int8_descales_IFOG[2]; - const float descale_xc_G = weight_xc_int8_descales_IFOG[3]; +int LSTM_x86::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + int T = bottom_blob.h; - const float descale_hc_I = weight_hc_int8_descales_IFOG[0]; - const float descale_hc_F = weight_hc_int8_descales_IFOG[1]; - const float descale_hc_O = weight_hc_int8_descales_IFOG[2]; - const float descale_hc_G = weight_hc_int8_descales_IFOG[3]; - - float I = bias_c_IFOG[0]; - float F = bias_c_IFOG[1]; - float O = bias_c_IFOG[2]; - float G = bias_c_IFOG[3]; -#endif // __SSE2__ - - const 
float* x = bottom_blob.row(ti); - - int i = 0; -#if __SSE2__ - for (; i + 3 < size; i += 4) - { - __m128 _xi0 = _mm_load1_ps(x); - __m128 _xi1 = _mm_load1_ps(x + 1); - __m128 _xi2 = _mm_load1_ps(x + 2); - __m128 _xi3 = _mm_load1_ps(x + 3); - - __m128i _weight_xc_IFOG = _mm_loadu_si128((const __m128i*)weight_xc_int8_IFOG); - __m128i _ext8 = _mm_cmpgt_epi8(_mm_setzero_si128(), _weight_xc_IFOG); - __m128i _weight_xc_IFOG01 = _mm_unpacklo_epi8(_weight_xc_IFOG, _ext8); - __m128i _weight_xc_IFOG23 = _mm_unpackhi_epi8(_weight_xc_IFOG, _ext8); - __m128i _ext16l = _mm_cmpgt_epi16(_mm_setzero_si128(), _weight_xc_IFOG01); - __m128i _ext16h = _mm_cmpgt_epi16(_mm_setzero_si128(), _weight_xc_IFOG23); - __m128 _weight_xc_IFOG0 = _mm_cvtepi32_ps(_mm_unpacklo_epi16(_weight_xc_IFOG01, _ext16l)); - __m128 _weight_xc_IFOG1 = _mm_cvtepi32_ps(_mm_unpackhi_epi16(_weight_xc_IFOG01, _ext16l)); - __m128 _weight_xc_IFOG2 = _mm_cvtepi32_ps(_mm_unpacklo_epi16(_weight_xc_IFOG23, _ext16h)); - __m128 _weight_xc_IFOG3 = _mm_cvtepi32_ps(_mm_unpackhi_epi16(_weight_xc_IFOG23, _ext16h)); - _weight_xc_IFOG0 = _mm_mul_ps(_weight_xc_IFOG0, _descale_xc_IFOG); - _weight_xc_IFOG1 = _mm_mul_ps(_weight_xc_IFOG1, _descale_xc_IFOG); - _weight_xc_IFOG2 = _mm_mul_ps(_weight_xc_IFOG2, _descale_xc_IFOG); - _weight_xc_IFOG3 = _mm_mul_ps(_weight_xc_IFOG3, _descale_xc_IFOG); - - _IFOG = _mm_comp_fmadd_ps(_weight_xc_IFOG0, _xi0, _IFOG); - _sum1 = _mm_comp_fmadd_ps(_weight_xc_IFOG1, _xi1, _sum1); - _sum2 = _mm_comp_fmadd_ps(_weight_xc_IFOG2, _xi2, _sum2); - _sum3 = _mm_comp_fmadd_ps(_weight_xc_IFOG3, _xi3, _sum3); - - x += 4; - weight_xc_int8_IFOG += 16; - } -#endif // __SSE2__ - for (; i < size; i++) - { -#if __SSE2__ - __m128 _xi = _mm_load1_ps(x); - - __m128i _weight_xc_IFOG = _mm_castpd_si128(_mm_load1_pd((const double*)weight_xc_int8_IFOG)); - _weight_xc_IFOG = _mm_unpacklo_epi8(_weight_xc_IFOG, _mm_cmpgt_epi8(_mm_setzero_si128(), _weight_xc_IFOG)); - _weight_xc_IFOG = _mm_unpacklo_epi16(_weight_xc_IFOG, _mm_cmpgt_epi16(_mm_setzero_si128(), _weight_xc_IFOG)); - __m128 _weight_xc_IFOG0 = _mm_mul_ps(_mm_cvtepi32_ps(_weight_xc_IFOG), _descale_xc_IFOG); - - _IFOG = _mm_comp_fmadd_ps(_weight_xc_IFOG0, _xi, _IFOG); -#else // __SSE2__ - float xi = x[0]; - I += xi * weight_xc_int8_IFOG[0] * descale_xc_I; - F += xi * weight_xc_int8_IFOG[1] * descale_xc_F; - O += xi * weight_xc_int8_IFOG[2] * descale_xc_O; - G += xi * weight_xc_int8_IFOG[3] * descale_xc_G; -#endif // __SSE2__ - - x += 1; - weight_xc_int8_IFOG += 4; - } - - const float* hidden_ptr = hidden_state; - - i = 0; -#if __SSE2__ - for (; i + 3 < num_output; i += 4) - { - __m128 _h_cont0 = _mm_load1_ps(hidden_ptr); - __m128 _h_cont1 = _mm_load1_ps(hidden_ptr + 1); - __m128 _h_cont2 = _mm_load1_ps(hidden_ptr + 2); - __m128 _h_cont3 = _mm_load1_ps(hidden_ptr + 3); - - __m128i _weight_hc_IFOG = _mm_loadu_si128((const __m128i*)weight_hc_int8_IFOG); - __m128i _ext8 = _mm_cmpgt_epi8(_mm_setzero_si128(), _weight_hc_IFOG); - __m128i _weight_hc_IFOG01 = _mm_unpacklo_epi8(_weight_hc_IFOG, _ext8); - __m128i _weight_hc_IFOG23 = _mm_unpackhi_epi8(_weight_hc_IFOG, _ext8); - __m128i _ext16l = _mm_cmpgt_epi16(_mm_setzero_si128(), _weight_hc_IFOG01); - __m128i _ext16h = _mm_cmpgt_epi16(_mm_setzero_si128(), _weight_hc_IFOG23); - __m128 _weight_hc_IFOG0 = _mm_cvtepi32_ps(_mm_unpacklo_epi16(_weight_hc_IFOG01, _ext16l)); - __m128 _weight_hc_IFOG1 = _mm_cvtepi32_ps(_mm_unpackhi_epi16(_weight_hc_IFOG01, _ext16l)); - __m128 _weight_hc_IFOG2 = _mm_cvtepi32_ps(_mm_unpacklo_epi16(_weight_hc_IFOG23, _ext16h)); - __m128 
_weight_hc_IFOG3 = _mm_cvtepi32_ps(_mm_unpackhi_epi16(_weight_hc_IFOG23, _ext16h)); - _weight_hc_IFOG0 = _mm_mul_ps(_weight_hc_IFOG0, _descale_hc_IFOG); - _weight_hc_IFOG1 = _mm_mul_ps(_weight_hc_IFOG1, _descale_hc_IFOG); - _weight_hc_IFOG2 = _mm_mul_ps(_weight_hc_IFOG2, _descale_hc_IFOG); - _weight_hc_IFOG3 = _mm_mul_ps(_weight_hc_IFOG3, _descale_hc_IFOG); - - _IFOG = _mm_comp_fmadd_ps(_weight_hc_IFOG0, _h_cont0, _IFOG); - _sum1 = _mm_comp_fmadd_ps(_weight_hc_IFOG1, _h_cont1, _sum1); - _sum2 = _mm_comp_fmadd_ps(_weight_hc_IFOG2, _h_cont2, _sum2); - _sum3 = _mm_comp_fmadd_ps(_weight_hc_IFOG3, _h_cont3, _sum3); - - hidden_ptr += 4; - weight_hc_int8_IFOG += 16; - } -#endif // __SSE2__ - for (; i < num_output; i++) - { -#if __SSE2__ - __m128 _h_cont = _mm_load1_ps(hidden_ptr); - - __m128i _weight_hc_IFOG = _mm_castpd_si128(_mm_load1_pd((const double*)weight_hc_int8_IFOG)); - _weight_hc_IFOG = _mm_unpacklo_epi8(_weight_hc_IFOG, _mm_cmpgt_epi8(_mm_setzero_si128(), _weight_hc_IFOG)); - _weight_hc_IFOG = _mm_unpacklo_epi16(_weight_hc_IFOG, _mm_cmpgt_epi16(_mm_setzero_si128(), _weight_hc_IFOG)); - __m128 _weight_hc_IFOG0 = _mm_mul_ps(_mm_cvtepi32_ps(_weight_hc_IFOG), _descale_hc_IFOG); - - _IFOG = _mm_comp_fmadd_ps(_weight_hc_IFOG0, _h_cont, _IFOG); -#else // __SSE2__ - float h_cont = hidden_ptr[0]; - I += h_cont * weight_hc_int8_IFOG[0] * descale_hc_I; - F += h_cont * weight_hc_int8_IFOG[1] * descale_hc_F; - O += h_cont * weight_hc_int8_IFOG[2] * descale_hc_O; - G += h_cont * weight_hc_int8_IFOG[3] * descale_hc_G; -#endif // __SSE2__ - - hidden_ptr += 1; - weight_hc_int8_IFOG += 4; - } - - float* gates_data = gates.row(q); - -#if __SSE2__ - _IFOG = _mm_add_ps(_IFOG, _sum1); - _sum2 = _mm_add_ps(_sum2, _sum3); - _IFOG = _mm_add_ps(_IFOG, _sum2); - - _mm_storeu_ps(gates_data, _IFOG); -#else // __SSE2__ - gates_data[0] = I; - gates_data[1] = F; - gates_data[2] = O; - gates_data[3] = G; -#endif // __SSE2__ - } - - // lstm unit - // sigmoid(I) - // sigmoid(F) - // sigmoid(O) - // tanh(G) - // c_t := f_t .* c_{t-1} + i_t .* g_t - // h_t := o_t .* tanh[c_t] - float* output_data = top_blob.row(ti); - - float* cell_ptr = cell_state; - float* hidden_ptr = hidden_state; - float* tmp_hidden_ptr = tmp_hidden_state; - -#if __SSE2__ - nn_hidden_size = hidden_size >> 2; - remain_hidden_size_start = nn_hidden_size << 2; - #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_hidden_size; qq++) - { - int q = qq * 4; - - const float* gates_data = gates.row(q); - - __m128 _IFOG_4x4_0 = _mm_loadu_ps(gates_data); - __m128 _IFOG_4x4_1 = _mm_loadu_ps(gates_data + 4); - __m128 _IFOG_4x4_2 = _mm_loadu_ps(gates_data + 8); - __m128 _IFOG_4x4_3 = _mm_loadu_ps(gates_data + 12); - - _MM_TRANSPOSE4_PS(_IFOG_4x4_0, _IFOG_4x4_1, _IFOG_4x4_2, _IFOG_4x4_3); - - __m128 _lstm_I = sigmoid_sse(_IFOG_4x4_0); - __m128 _lstm_F = sigmoid_sse(_IFOG_4x4_1); - __m128 _lstm_O = sigmoid_sse(_IFOG_4x4_2); - __m128 _lstm_G = tanh_sse(_IFOG_4x4_3); - - __m128 _cell2 = _mm_add_ps(_mm_mul_ps(_lstm_F, _mm_loadu_ps(cell_ptr + q)), _mm_mul_ps(_lstm_I, _lstm_G)); - __m128 _lstm_H = _mm_mul_ps(_lstm_O, tanh_sse(_cell2)); - - _mm_storeu_ps(cell_ptr + q, _cell2); - - if (num_output == hidden_size) - { - _mm_storeu_ps(hidden_ptr + q, _lstm_H); - _mm_storeu_ps(output_data + q, _lstm_H); - } - else - { - _mm_storeu_ps(tmp_hidden_ptr + q, _lstm_H); - } - } -#else // __SSE2__ - remain_hidden_size_start = 0; -#endif // __SSE2__ - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_hidden_size_start; q < 
hidden_size; q++) - { - const float* gates_data = gates.row(q); - - float I = gates_data[0]; - float F = gates_data[1]; - float O = gates_data[2]; - float G = gates_data[3]; - - I = 1.f / (1.f + expf(-I)); - F = 1.f / (1.f + expf(-F)); - O = 1.f / (1.f + expf(-O)); - G = tanhf(G); - - float cell2 = F * cell_ptr[q] + I * G; - float H = O * tanhf(cell2); - - cell_ptr[q] = cell2; - if (num_output == hidden_size) - { - hidden_ptr[q] = H; - output_data[q] = H; - } - else - { - tmp_hidden_ptr[q] = H; - } - } - - if (num_output != hidden_size) - { - // int nn_num_output = num_output >> 2; - // int remain_num_output_start = nn_num_output << 2; - // #pragma omp parallel for num_threads(opt.num_threads) - // for (int qq = 0; qq < nn_num_output; qq++) - // { - // int q = qq * 4; - // - // } - int remain_num_output_start = 0; - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) - { - const float* hr = weight_hr.row(q); - const float* tmp_hidden_ptr = tmp_hidden_state; - - float H = 0; - for (int i = 0; i < hidden_size; i++) - { - H += tmp_hidden_ptr[i] * hr[i]; - } - - output_data[q] = H; - hidden_ptr[q] = H; - } - } - } - - return 0; -} - -int LSTM_x86::create_pipeline_int8(const Option& opt) -{ - // pack IFOG - const int num_directions = direction == 2 ? 2 : 1; - const int size = weight_data_size / num_directions / hidden_size / 4; - -#if __AVX__ - weight_xc_data_packed.create(size, hidden_size / 2 + hidden_size % 2, num_directions, 8u, 8); - bias_c_data_packed.create(hidden_size, 1, num_directions, 16u, 4); - weight_hc_data_packed.create(num_output, hidden_size / 2 + hidden_size % 2, num_directions, 8u, 8); - weight_xc_data_int8_descales_packed.create(8, hidden_size / 2 + hidden_size % 2, num_directions); - weight_hc_data_int8_descales_packed.create(8, hidden_size / 2 + hidden_size % 2, num_directions); -#else - weight_xc_data_packed.create(size, hidden_size, num_directions, 4u, 4); - bias_c_data_packed.create(hidden_size, 1, num_directions, 16u, 4); - weight_hc_data_packed.create(num_output, hidden_size, num_directions, 4u, 4); - weight_xc_data_int8_descales_packed.create(4, hidden_size, num_directions); - weight_hc_data_int8_descales_packed.create(4, hidden_size, num_directions); -#endif - - #pragma omp parallel for num_threads(opt.num_threads) - for (int dr = 0; dr < num_directions; dr++) - { - const Mat weight_xc = weight_xc_data.channel(dr); - const Mat bias_c = bias_c_data.channel(dr); - const Mat weight_hc = weight_hc_data.channel(dr); - const float* weight_xc_int8_scales = weight_xc_data_int8_scales.row(dr); - const float* weight_hc_int8_scales = weight_hc_data_int8_scales.row(dr); - - Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr); - Mat bias_c_data_packed_dr = bias_c_data_packed.channel(dr); - Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr); - Mat weight_xc_data_int8_descales_packed_dr = weight_xc_data_int8_descales_packed.channel(dr); - Mat weight_hc_data_int8_descales_packed_dr = weight_hc_data_int8_descales_packed.channel(dr); - - const float* bias_c_I = bias_c.row(0); - const float* bias_c_F = bias_c.row(1); - const float* bias_c_O = bias_c.row(2); - const float* bias_c_G = bias_c.row(3); - - float* bias_c_IFOG = bias_c_data_packed_dr.row(0); - - int q = 0; -#if __AVX__ - for (; q + 1 < hidden_size; q += 2) - { - bias_c_IFOG[0] = bias_c_I[q]; - bias_c_IFOG[1] = bias_c_F[q]; - bias_c_IFOG[2] = bias_c_O[q]; - bias_c_IFOG[3] = bias_c_G[q]; - bias_c_IFOG[4] = bias_c_I[q + 1]; - 
bias_c_IFOG[5] = bias_c_F[q + 1]; - bias_c_IFOG[6] = bias_c_O[q + 1]; - bias_c_IFOG[7] = bias_c_G[q + 1]; - - bias_c_IFOG += 8; - - const signed char* weight_xc_I = weight_xc.row(hidden_size * 0 + q); - const signed char* weight_xc_F = weight_xc.row(hidden_size * 1 + q); - const signed char* weight_xc_O = weight_xc.row(hidden_size * 2 + q); - const signed char* weight_xc_G = weight_xc.row(hidden_size * 3 + q); - const signed char* weight_xc_I_1 = weight_xc.row(hidden_size * 0 + q + 1); - const signed char* weight_xc_F_1 = weight_xc.row(hidden_size * 1 + q + 1); - const signed char* weight_xc_O_1 = weight_xc.row(hidden_size * 2 + q + 1); - const signed char* weight_xc_G_1 = weight_xc.row(hidden_size * 3 + q + 1); - - const signed char* weight_hc_I = weight_hc.row(hidden_size * 0 + q); - const signed char* weight_hc_F = weight_hc.row(hidden_size * 1 + q); - const signed char* weight_hc_O = weight_hc.row(hidden_size * 2 + q); - const signed char* weight_hc_G = weight_hc.row(hidden_size * 3 + q); - const signed char* weight_hc_I_1 = weight_hc.row(hidden_size * 0 + q + 1); - const signed char* weight_hc_F_1 = weight_hc.row(hidden_size * 1 + q + 1); - const signed char* weight_hc_O_1 = weight_hc.row(hidden_size * 2 + q + 1); - const signed char* weight_hc_G_1 = weight_hc.row(hidden_size * 3 + q + 1); - - signed char* weight_xc_IFOG = weight_xc_data_packed_dr.row(q / 2); - signed char* weight_hc_IFOG = weight_hc_data_packed_dr.row(q / 2); - float* weight_xc_int8_descales_IFOG = weight_xc_data_int8_descales_packed_dr.row(q / 2); - float* weight_hc_int8_descales_IFOG = weight_hc_data_int8_descales_packed_dr.row(q / 2); - - for (int i = 0; i < size; i++) - { - weight_xc_IFOG[0] = weight_xc_I[i]; - weight_xc_IFOG[1] = weight_xc_F[i]; - weight_xc_IFOG[2] = weight_xc_O[i]; - weight_xc_IFOG[3] = weight_xc_G[i]; - weight_xc_IFOG[4] = weight_xc_I_1[i]; - weight_xc_IFOG[5] = weight_xc_F_1[i]; - weight_xc_IFOG[6] = weight_xc_O_1[i]; - weight_xc_IFOG[7] = weight_xc_G_1[i]; - - weight_xc_IFOG += 8; - } - - for (int i = 0; i < num_output; i++) - { - weight_hc_IFOG[0] = weight_hc_I[i]; - weight_hc_IFOG[1] = weight_hc_F[i]; - weight_hc_IFOG[2] = weight_hc_O[i]; - weight_hc_IFOG[3] = weight_hc_G[i]; - weight_hc_IFOG[4] = weight_hc_I_1[i]; - weight_hc_IFOG[5] = weight_hc_F_1[i]; - weight_hc_IFOG[6] = weight_hc_O_1[i]; - weight_hc_IFOG[7] = weight_hc_G_1[i]; - - weight_hc_IFOG += 8; - } - - weight_xc_int8_descales_IFOG[0] = 1.f / weight_xc_int8_scales[hidden_size * 0 + q]; - weight_xc_int8_descales_IFOG[1] = 1.f / weight_xc_int8_scales[hidden_size * 1 + q]; - weight_xc_int8_descales_IFOG[2] = 1.f / weight_xc_int8_scales[hidden_size * 2 + q]; - weight_xc_int8_descales_IFOG[3] = 1.f / weight_xc_int8_scales[hidden_size * 3 + q]; - weight_xc_int8_descales_IFOG[4] = 1.f / weight_xc_int8_scales[hidden_size * 0 + q + 1]; - weight_xc_int8_descales_IFOG[5] = 1.f / weight_xc_int8_scales[hidden_size * 1 + q + 1]; - weight_xc_int8_descales_IFOG[6] = 1.f / weight_xc_int8_scales[hidden_size * 2 + q + 1]; - weight_xc_int8_descales_IFOG[7] = 1.f / weight_xc_int8_scales[hidden_size * 3 + q + 1]; - - weight_hc_int8_descales_IFOG[0] = 1.f / weight_hc_int8_scales[hidden_size * 0 + q]; - weight_hc_int8_descales_IFOG[1] = 1.f / weight_hc_int8_scales[hidden_size * 1 + q]; - weight_hc_int8_descales_IFOG[2] = 1.f / weight_hc_int8_scales[hidden_size * 2 + q]; - weight_hc_int8_descales_IFOG[3] = 1.f / weight_hc_int8_scales[hidden_size * 3 + q]; - weight_hc_int8_descales_IFOG[4] = 1.f / weight_hc_int8_scales[hidden_size * 0 + q + 1]; - 
weight_hc_int8_descales_IFOG[5] = 1.f / weight_hc_int8_scales[hidden_size * 1 + q + 1]; - weight_hc_int8_descales_IFOG[6] = 1.f / weight_hc_int8_scales[hidden_size * 2 + q + 1]; - weight_hc_int8_descales_IFOG[7] = 1.f / weight_hc_int8_scales[hidden_size * 3 + q + 1]; - } -#endif // __AVX__ - for (; q < hidden_size; q++) - { - bias_c_IFOG[0] = bias_c_I[q]; - bias_c_IFOG[1] = bias_c_F[q]; - bias_c_IFOG[2] = bias_c_O[q]; - bias_c_IFOG[3] = bias_c_G[q]; - - bias_c_IFOG += 4; - - const signed char* weight_xc_I = weight_xc.row(hidden_size * 0 + q); - const signed char* weight_xc_F = weight_xc.row(hidden_size * 1 + q); - const signed char* weight_xc_O = weight_xc.row(hidden_size * 2 + q); - const signed char* weight_xc_G = weight_xc.row(hidden_size * 3 + q); - - const signed char* weight_hc_I = weight_hc.row(hidden_size * 0 + q); - const signed char* weight_hc_F = weight_hc.row(hidden_size * 1 + q); - const signed char* weight_hc_O = weight_hc.row(hidden_size * 2 + q); - const signed char* weight_hc_G = weight_hc.row(hidden_size * 3 + q); - -#if __AVX__ - signed char* weight_xc_IFOG = weight_xc_data_packed_dr.row(q / 2 + q % 2); - signed char* weight_hc_IFOG = weight_hc_data_packed_dr.row(q / 2 + q % 2); - float* weight_xc_int8_descales_IFOG = weight_xc_data_int8_descales_packed_dr.row(q / 2 + q % 2); - float* weight_hc_int8_descales_IFOG = weight_hc_data_int8_descales_packed_dr.row(q / 2 + q % 2); -#else - signed char* weight_xc_IFOG = weight_xc_data_packed_dr.row(q); - signed char* weight_hc_IFOG = weight_hc_data_packed_dr.row(q); - float* weight_xc_int8_descales_IFOG = weight_xc_data_int8_descales_packed_dr.row(q); - float* weight_hc_int8_descales_IFOG = weight_hc_data_int8_descales_packed_dr.row(q); -#endif - - for (int i = 0; i < size; i++) - { - weight_xc_IFOG[0] = weight_xc_I[i]; - weight_xc_IFOG[1] = weight_xc_F[i]; - weight_xc_IFOG[2] = weight_xc_O[i]; - weight_xc_IFOG[3] = weight_xc_G[i]; - - weight_xc_IFOG += 4; - } - - for (int i = 0; i < num_output; i++) - { - weight_hc_IFOG[0] = weight_hc_I[i]; - weight_hc_IFOG[1] = weight_hc_F[i]; - weight_hc_IFOG[2] = weight_hc_O[i]; - weight_hc_IFOG[3] = weight_hc_G[i]; - - weight_hc_IFOG += 4; - } - - weight_xc_int8_descales_IFOG[0] = 1.f / weight_xc_int8_scales[hidden_size * 0 + q]; - weight_xc_int8_descales_IFOG[1] = 1.f / weight_xc_int8_scales[hidden_size * 1 + q]; - weight_xc_int8_descales_IFOG[2] = 1.f / weight_xc_int8_scales[hidden_size * 2 + q]; - weight_xc_int8_descales_IFOG[3] = 1.f / weight_xc_int8_scales[hidden_size * 3 + q]; - - weight_hc_int8_descales_IFOG[0] = 1.f / weight_hc_int8_scales[hidden_size * 0 + q]; - weight_hc_int8_descales_IFOG[1] = 1.f / weight_hc_int8_scales[hidden_size * 1 + q]; - weight_hc_int8_descales_IFOG[2] = 1.f / weight_hc_int8_scales[hidden_size * 2 + q]; - weight_hc_int8_descales_IFOG[3] = 1.f / weight_hc_int8_scales[hidden_size * 3 + q]; - } - } - - if (opt.lightmode) - { - weight_xc_data.release(); - bias_c_data.release(); - weight_hc_data.release(); - weight_xc_data_int8_scales.release(); - weight_hc_data_int8_scales.release(); - } - - return 0; -} -#endif // NCNN_INT8 - -int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const -{ - int T = bottom_blob.h; - - int num_directions = direction == 2 ? 2 : 1; + int num_directions = direction == 2 ? 
2 : 1; // initial hidden state Mat hidden(num_output, 4u, opt.workspace_allocator); @@ -1251,23 +807,20 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (top_blob.empty()) return -100; + // dynamic quantize bottom_blob + Mat bottom_blob_int8; + Mat bottom_blob_int8_descales; + { + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + dynamic_quantize(bottom_blob, bottom_blob_int8, bottom_blob_int8_descales, opt_quant); + } + // Uni directional if (direction == 0 || direction == 1) { -#if NCNN_INT8 - if (int8_scale_term) - { - int ret = lstm_int8(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; - } - else -#endif - { - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; - } + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, direction, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); } if (direction == 2) @@ -1280,37 +833,15 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (top_blob_reverse.empty()) return -100; -#if NCNN_INT8 - if (int8_scale_term) { - int ret = lstm_int8(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; - } - else -#endif - { - int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_forward, 0, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); } - hidden.fill(0.0f); + hidden.fill(0.f); cell.fill(0.0f); -#if NCNN_INT8 - if (int8_scale_term) { - int ret = lstm_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), weight_xc_data_int8_descales_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), weight_hc_data_int8_descales_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); - if (ret != 0) - return ret; - } - else -#endif - { - int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(1), hidden, cell, opt); - if (ret != 0) - return ret; + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_reverse, 1, weight_data_tm.channel(1), weight_data_tm_int8_descales.channel(1), bias_c_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); } // concat w @@ -1328,9 +859,10 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) return 0; } -int LSTM_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +int LSTM_x86::forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { const Mat& bottom_blob = bottom_blobs[0]; + int T = bottom_blob.h; int num_directions = direction == 2 ? 2 : 1; @@ -1360,23 +892,20 @@ int LSTM_x86::forward(const std::vector& bottom_blobs, std::vector& to if (top_blob.empty()) return -100; + // dynamic quantize bottom_blob + Mat bottom_blob_int8; + Mat bottom_blob_int8_descales; + { + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + dynamic_quantize(bottom_blob, bottom_blob_int8, bottom_blob_int8_descales, opt_quant); + } + // Uni directional if (direction == 0 || direction == 1) { -#if NCNN_INT8 - if (int8_scale_term) - { - int ret = lstm_int8(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; - } - else -#endif - { - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; - } + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, direction, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); } if (direction == 2) @@ -1391,36 +920,14 @@ int LSTM_x86::forward(const std::vector& bottom_blobs, std::vector& to Mat hidden0 = hidden.row_range(0, 1); Mat cell0 = cell.row_range(0, 1); -#if NCNN_INT8 - if (int8_scale_term) { - int ret = lstm_int8(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); - if (ret != 0) - return ret; - } - else -#endif - { - int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); - if (ret != 0) - return ret; + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_forward, 0, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); } Mat hidden1 = hidden.row_range(1, 1); Mat cell1 = cell.row_range(1, 1); -#if NCNN_INT8 - if (int8_scale_term) - { - int ret = lstm_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), weight_xc_data_int8_descales_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), weight_hc_data_int8_descales_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); - if (ret != 0) - return ret; - } - else -#endif { - int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); - if (ret != 0) - return ret; + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_reverse, 1, weight_data_tm.channel(1), weight_data_tm_int8_descales.channel(1), bias_c_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); } // concat w @@ -1443,5 +950,6 @@ int LSTM_x86::forward(const std::vector& bottom_blobs, std::vector& to return 0; } +#endif // NCNN_INT8 } // namespace ncnn diff --git a/src/layer/x86/lstm_x86.h b/src/layer/x86/lstm_x86.h index f0785fe4425..d31b7377ccf 100644 --- a/src/layer/x86/lstm_x86.h +++ b/src/layer/x86/lstm_x86.h @@ -33,6 +33,9 @@ class LSTM_x86 : public LSTM protected: #if NCNN_INT8 int create_pipeline_int8(const Option& opt); + void dynamic_quantize(const Mat& bottom_blob, Mat& bottom_blob_int8, Mat& bottom_blob_int8_descales, const Option& opt) const; + int forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; + int forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; #endif public: @@ -40,9 +43,10 @@ class LSTM_x86 : public LSTM Mat bias_c_data_packed; Mat weight_hc_data_packed; + Mat weight_data_tm; + #if NCNN_INT8 - Mat weight_hc_data_int8_descales_packed; - Mat weight_xc_data_int8_descales_packed; + Mat weight_data_tm_int8_descales; #endif };
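For reference, the descale bookkeeping shared by `dynamic_quantize` and `lstm_int8` reduces to a plain identity: each int32 gate accumulator is converted back to fp32 with a single multiply by the product of the activation descale (`absmax / 127`) and the per-gate weight descale (`1 / scale`). A minimal scalar sketch of that round trip, with made-up values (nothing here is taken from the patch; ncnn's `float2int8` additionally saturates to [-127, 127], which these values never hit):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

int main()
{
    const float x[4] = {0.5f, -1.25f, 2.0f, -0.75f}; // one input row
    const float w[4] = {0.1f, 0.2f, -0.3f, 0.4f};    // one fp32 weight row
    const float w_scale = 127.f / 0.4f;              // per-row weight scale

    // dynamic quantization of the input row: scale = 127 / absmax
    float absmax = 0.f;
    for (int i = 0; i < 4; i++)
        absmax = std::max(absmax, std::fabs(x[i]));
    const float x_scale = 127.f / absmax;
    const float x_descale = absmax / 127.f;

    // int8 dot product, then one fused dequantize multiply,
    // mirroring Ix * (descale_x * descale_xc_I) in lstm_int8
    int sum = 0;
    for (int i = 0; i < 4; i++)
    {
        const int xi = (int)std::lround(x[i] * x_scale);
        const int wi = (int)std::lround(w[i] * w_scale);
        sum += xi * wi;
    }
    const float approx = sum * (x_descale * (1.f / w_scale));

    float exact = 0.f;
    for (int i = 0; i < 4; i++)
        exact += x[i] * w[i];

    printf("exact %f approx %f\n", exact, approx); // approx tracks exact
    return 0;
}
```

Because both descales are constant per row, the kernel can keep the inner loops in pure int8/int32 arithmetic and defer all floating-point work to one multiply-add per gate, which is where the speedup over the removed fp32-dequantizing `lstm_int8` in `lstm_x86.cpp` comes from.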