From 1a2d242d93eadea88f1de3b93eb307d39cb1acd3 Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 23 Apr 2024 17:35:41 +0800 Subject: [PATCH] lstm arm int8 --- src/layer/arm/lstm_arm.cpp | 1393 ++++++++++++++++++++++------ src/layer/arm/lstm_arm.h | 8 + src/layer/arm/lstm_arm_asimdhp.cpp | 1159 ++++++++++++++++++++++- 3 files changed, 2236 insertions(+), 324 deletions(-) diff --git a/src/layer/arm/lstm_arm.cpp b/src/layer/arm/lstm_arm.cpp index 62a4860aaf3..c41aaacd17b 100644 --- a/src/layer/arm/lstm_arm.cpp +++ b/src/layer/arm/lstm_arm.cpp @@ -40,45 +40,6 @@ LSTM_arm::LSTM_arm() int LSTM_arm::create_pipeline(const Option& opt) { -#if NCNN_INT8 - if (int8_scale_term) - { - const int num_directions = direction == 2 ? 2 : 1; - const int size = weight_data_size / num_directions / hidden_size / 4; - - // TODO fuse weight de-scale into kernel - Mat weight_xc_data_fp32(size, hidden_size * 4, num_directions); - Mat weight_hc_data_fp32(num_output, hidden_size * 4, num_directions); - for (int d = 0; d < num_directions; d++) - { - for (int q = 0; q < hidden_size * 4; q++) - { - const signed char* weight_xc_ptr = weight_xc_data.channel(d).row(q); - const signed char* weight_hc_ptr = weight_hc_data.channel(d).row(q); - - float* weight_xc_fp32_ptr = weight_xc_data_fp32.channel(d).row(q); - float* weight_hc_fp32_ptr = weight_hc_data_fp32.channel(d).row(q); - - const float descale_xc = 1.f / weight_xc_data_int8_scales.row(d)[q]; - const float descale_hc = 1.f / weight_hc_data_int8_scales.row(d)[q]; - - for (int i = 0; i < size; i++) - { - weight_xc_fp32_ptr[i] = weight_xc_ptr[i] * descale_xc; - } - - for (int i = 0; i < num_output; i++) - { - weight_hc_fp32_ptr[i] = weight_hc_ptr[i] * descale_hc; - } - } - } - - weight_xc_data = weight_xc_data_fp32; - weight_hc_data = weight_hc_data_fp32; - } -#endif // NCNN_INT8 - #if NCNN_ARM82 if (support_fp16_storage && opt.use_fp16_storage) { @@ -93,6 +54,13 @@ int LSTM_arm::create_pipeline(const Option& opt) } #endif +#if NCNN_INT8 + if (int8_scale_term) + { + return create_pipeline_int8(opt); + } +#endif + // pack IFOG int num_directions = direction == 2 ? 2 : 1; int size = weight_data_size / num_directions / hidden_size / 4; @@ -347,7 +315,7 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w float* tmp_hidden_ptr = tmp_hidden_state; int remain_hidden_size_start = 0; -#if 0 //__ARM_NEON TODO test_lstm failed for precision loss +#if __ARM_NEON int nn_hidden_size = hidden_size >> 2; remain_hidden_size_start = nn_hidden_size << 2; @@ -443,190 +411,8 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w return 0; } -int LSTM_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const -{ - int elembits = bottom_blob.elembits(); - -#if NCNN_ARM82 - if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) - { - if (opt.use_fp16_arithmetic) - return forward_fp16sa(bottom_blob, top_blob, opt); - else - return forward_fp16s(bottom_blob, top_blob, opt); - } -#endif - -#if NCNN_BF16 - if (opt.use_bf16_storage && elembits == 16) - return forward_bf16s(bottom_blob, top_blob, opt); -#endif - - int T = bottom_blob.h; - - int num_directions = direction == 2 ? 
2 : 1; - - // initial hidden state - Mat hidden(num_output, 4u, opt.workspace_allocator); - if (hidden.empty()) - return -100; - hidden.fill(0.f); - - Mat cell(hidden_size, 4u, opt.workspace_allocator); - if (cell.empty()) - return -100; - cell.fill(0.f); - - top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); - if (top_blob.empty()) - return -100; - - // Uni directional - if (direction == 0 || direction == 1) - { - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; - } - - if (direction == 2) - { - Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); - if (top_blob_forward.empty()) - return -100; - - Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); - if (top_blob_reverse.empty()) - return -100; - - int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret0 != 0) - return ret0; - - hidden.fill(0.0f); - cell.fill(0.0f); - - int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); - if (ret1 != 0) - return ret1; - - // concat w - for (int i = 0; i < T; i++) - { - const float* pf = top_blob_forward.row(i); - const float* pr = top_blob_reverse.row(i); - float* ptr = top_blob.row(i); - - memcpy(ptr, pf, num_output * sizeof(float)); - memcpy(ptr + num_output, pr, num_output * sizeof(float)); - } - } - - return 0; -} - -int LSTM_arm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const -{ - const Mat& bottom_blob = bottom_blobs[0]; - int elembits = bottom_blob.elembits(); - -#if NCNN_ARM82 - if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) - { - if (opt.use_fp16_arithmetic) - return forward_fp16sa(bottom_blobs, top_blobs, opt); - else - return forward_fp16s(bottom_blobs, top_blobs, opt); - } -#endif - -#if NCNN_BF16 - if (opt.use_bf16_storage && elembits == 16) - return forward_bf16s(bottom_blobs, top_blobs, opt); -#endif - - int T = bottom_blob.h; - int num_directions = direction == 2 ? 2 : 1; - - Mat hidden; - Mat cell; - Allocator* hidden_cell_allocator = top_blobs.size() == 3 ? opt.blob_allocator : opt.workspace_allocator; - if (bottom_blobs.size() == 3) - { - hidden = bottom_blobs[1].clone(hidden_cell_allocator); - cell = bottom_blobs[2].clone(hidden_cell_allocator); - } - else - { - hidden.create(num_output, num_directions, 4u, hidden_cell_allocator); - if (hidden.empty()) - return -100; - hidden.fill(0.f); - - cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator); - if (cell.empty()) - return -100; - cell.fill(0.f); - } - - Mat& top_blob = top_blobs[0]; - top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); - if (top_blob.empty()) - return -100; - - // Uni directional - if (direction == 0 || direction == 1) - { - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; - } - - if (direction == 2) - { - Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); - if (top_blob_forward.empty()) - return -100; - - Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); - if (top_blob_reverse.empty()) - return -100; - - Mat hidden0 = hidden.row_range(0, 1); - Mat cell0 = cell.row_range(0, 1); - int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); - if (ret0 != 0) - return ret0; - - Mat hidden1 = hidden.row_range(1, 1); - Mat cell1 = cell.row_range(1, 1); - int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); - if (ret1 != 0) - return ret1; - - // concat w - for (int i = 0; i < T; i++) - { - const float* pf = top_blob_forward.row(i); - const float* pr = top_blob_reverse.row(i); - float* ptr = top_blob.row(i); - - memcpy(ptr, pf, num_output * sizeof(float)); - memcpy(ptr + num_output, pr, num_output * sizeof(float)); - } - } - - if (top_blobs.size() == 3) - { - top_blobs[1] = hidden; - top_blobs[2] = cell; - } - - return 0; -} - -#if NCNN_BF16 -static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +#if NCNN_INT8 +static int lstm_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) { int size = bottom_blob.w; int T = bottom_blob.h; @@ -659,39 +445,60 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const int ti = reverse ? 
T - 1 - t : t; - const unsigned short* x = bottom_blob.row(ti); + const float* x = bottom_blob.row(ti); #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < hidden_size; q++) { - const unsigned short* bias_c_IFOG = (const unsigned short*)bias_c + q * 4; + const float* bias_c_IFOG = (const float*)bias_c + q * 4; // gate I F O G - const unsigned short* weight_xc_IFOG = weight_xc.row(q); - - const unsigned short* weight_hc_IFOG = weight_hc.row(q); + const signed char* weight_xc_int8_IFOG = weight_xc_int8.row(q); + const signed char* weight_hc_int8_IFOG = weight_hc_int8.row(q); + const float* weight_xc_int8_descales_IFOG = weight_xc_int8_descales.row(q); + const float* weight_hc_int8_descales_IFOG = weight_hc_int8_descales.row(q); #if __ARM_NEON - float32x4_t _IFOG = bfloat2float(vld1_u16(bias_c_IFOG)); + float32x4_t _descale_xc_IFOG = vld1q_f32(weight_xc_int8_descales_IFOG); + float32x4_t _descale_hc_IFOG = vld1q_f32(weight_hc_int8_descales_IFOG); + + float32x4_t _IFOG = vld1q_f32(bias_c_IFOG); float32x4_t _sum1 = vdupq_n_f32(0.f); float32x4_t _sum2 = vdupq_n_f32(0.f); float32x4_t _sum3 = vdupq_n_f32(0.f); #else - float I = bfloat16_to_float32(bias_c_IFOG[0]); - float F = bfloat16_to_float32(bias_c_IFOG[1]); - float O = bfloat16_to_float32(bias_c_IFOG[2]); - float G = bfloat16_to_float32(bias_c_IFOG[3]); + const float descale_xc_I = weight_xc_int8_descales_IFOG[0]; + const float descale_xc_F = weight_xc_int8_descales_IFOG[1]; + const float descale_xc_O = weight_xc_int8_descales_IFOG[2]; + const float descale_xc_G = weight_xc_int8_descales_IFOG[3]; + + const float descale_hc_I = weight_hc_int8_descales_IFOG[0]; + const float descale_hc_F = weight_hc_int8_descales_IFOG[1]; + const float descale_hc_O = weight_hc_int8_descales_IFOG[2]; + const float descale_hc_G = weight_hc_int8_descales_IFOG[3]; + + float I = bias_c_IFOG[0]; + float F = bias_c_IFOG[1]; + float O = bias_c_IFOG[2]; + float G = bias_c_IFOG[3]; #endif // __ARM_NEON int i = 0; #if __ARM_NEON for (; i + 3 < size; i += 4) { - float32x4_t _xi = bfloat2float(vld1_u16(x + i)); + float32x4_t _xi = vld1q_f32(x + i); - float32x4_t _weight_xc_IFOG_0 = bfloat2float(vld1_u16(weight_xc_IFOG)); - float32x4_t _weight_xc_IFOG_1 = bfloat2float(vld1_u16(weight_xc_IFOG + 4)); - float32x4_t _weight_xc_IFOG_2 = bfloat2float(vld1_u16(weight_xc_IFOG + 8)); - float32x4_t _weight_xc_IFOG_3 = bfloat2float(vld1_u16(weight_xc_IFOG + 12)); + int8x16_t _weight_xc_IFOG = vld1q_s8(weight_xc_int8_IFOG); + int16x8_t _weight_xc_IFOG01 = vmovl_s8(vget_low_s8(_weight_xc_IFOG)); + int16x8_t _weight_xc_IFOG23 = vmovl_s8(vget_high_s8(_weight_xc_IFOG)); + float32x4_t _weight_xc_IFOG_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_IFOG01))); + float32x4_t _weight_xc_IFOG_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_IFOG01))); + float32x4_t _weight_xc_IFOG_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_IFOG23))); + float32x4_t _weight_xc_IFOG_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_IFOG23))); + _weight_xc_IFOG_0 = vmulq_f32(_weight_xc_IFOG_0, _descale_xc_IFOG); + _weight_xc_IFOG_1 = vmulq_f32(_weight_xc_IFOG_1, _descale_xc_IFOG); + _weight_xc_IFOG_2 = vmulq_f32(_weight_xc_IFOG_2, _descale_xc_IFOG); + _weight_xc_IFOG_3 = vmulq_f32(_weight_xc_IFOG_3, _descale_xc_IFOG); #if __aarch64__ _IFOG = vfmaq_laneq_f32(_IFOG, _weight_xc_IFOG_0, _xi, 0); @@ -705,27 +512,26 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_IFOG_3, vget_high_f32(_xi), 1); #endif - 
weight_xc_IFOG += 16; + weight_xc_int8_IFOG += 16; } #endif // __ARM_NEON for (; i < size; i++) { -#if __ARM_NEON - unsigned short xi = x[i]; + float xi = x[i]; - float32x4_t _xi = bfloat2float(vdup_n_u16(xi)); - float32x4_t _weight_xc_IFOG = bfloat2float(vld1_u16(weight_xc_IFOG)); +#if __ARM_NEON + float32x4_t _xi = vdupq_n_f32(xi); + float32x4_t _weight_xc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_IFOG))))); + _weight_xc_IFOG = vmulq_f32(_weight_xc_IFOG, _descale_xc_IFOG); _IFOG = vmlaq_f32(_IFOG, _weight_xc_IFOG, _xi); #else - float xi = bfloat16_to_float32(x[i]); - - I += bfloat16_to_float32(weight_xc_IFOG[0]) * xi; - F += bfloat16_to_float32(weight_xc_IFOG[1]) * xi; - O += bfloat16_to_float32(weight_xc_IFOG[2]) * xi; - G += bfloat16_to_float32(weight_xc_IFOG[3]) * xi; + I += weight_xc_int8_IFOG[0] * descale_xc_I * xi; + F += weight_xc_int8_IFOG[1] * descale_xc_F * xi; + O += weight_xc_int8_IFOG[2] * descale_xc_O * xi; + G += weight_xc_int8_IFOG[3] * descale_xc_G * xi; #endif // __ARM_NEON - weight_xc_IFOG += 4; + weight_xc_int8_IFOG += 4; } i = 0; @@ -734,10 +540,17 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const { float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); - float32x4_t _weight_hc_IFOG_0 = bfloat2float(vld1_u16(weight_hc_IFOG)); - float32x4_t _weight_hc_IFOG_1 = bfloat2float(vld1_u16(weight_hc_IFOG + 4)); - float32x4_t _weight_hc_IFOG_2 = bfloat2float(vld1_u16(weight_hc_IFOG + 8)); - float32x4_t _weight_hc_IFOG_3 = bfloat2float(vld1_u16(weight_hc_IFOG + 12)); + int8x16_t _weight_hc_IFOG = vld1q_s8(weight_hc_int8_IFOG); + int16x8_t _weight_hc_IFOG01 = vmovl_s8(vget_low_s8(_weight_hc_IFOG)); + int16x8_t _weight_hc_IFOG23 = vmovl_s8(vget_high_s8(_weight_hc_IFOG)); + float32x4_t _weight_hc_IFOG_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_IFOG01))); + float32x4_t _weight_hc_IFOG_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_IFOG01))); + float32x4_t _weight_hc_IFOG_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_IFOG23))); + float32x4_t _weight_hc_IFOG_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_IFOG23))); + _weight_hc_IFOG_0 = vmulq_f32(_weight_hc_IFOG_0, _descale_hc_IFOG); + _weight_hc_IFOG_1 = vmulq_f32(_weight_hc_IFOG_1, _descale_hc_IFOG); + _weight_hc_IFOG_2 = vmulq_f32(_weight_hc_IFOG_2, _descale_hc_IFOG); + _weight_hc_IFOG_3 = vmulq_f32(_weight_hc_IFOG_3, _descale_hc_IFOG); #if __aarch64__ _IFOG = vfmaq_laneq_f32(_IFOG, _weight_hc_IFOG_0, _h_cont, 0); @@ -751,7 +564,7 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_IFOG_3, vget_high_f32(_h_cont), 1); #endif - weight_hc_IFOG += 16; + weight_hc_int8_IFOG += 16; } #endif // __ARM_NEON for (; i < num_output; i++) @@ -760,16 +573,17 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const #if __ARM_NEON float32x4_t _h_cont = vdupq_n_f32(h_cont); - float32x4_t _weight_hc_IFOG = bfloat2float(vld1_u16(weight_hc_IFOG)); + float32x4_t _weight_hc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_IFOG))))); + _weight_hc_IFOG = vmulq_f32(_weight_hc_IFOG, _descale_hc_IFOG); _IFOG = vmlaq_f32(_IFOG, _weight_hc_IFOG, _h_cont); #else - I += bfloat16_to_float32(weight_hc_IFOG[0]) * h_cont; - F += bfloat16_to_float32(weight_hc_IFOG[1]) * h_cont; - O += bfloat16_to_float32(weight_hc_IFOG[2]) * h_cont; - G += bfloat16_to_float32(weight_hc_IFOG[3]) * h_cont; + I += weight_hc_int8_IFOG[0] * descale_hc_I * h_cont; 
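
Note: the int8 kernels in this patch dequantize weights in-register instead of ahead of time. Each group of 16 packed IFOG weights is widened s8 -> s16 -> s32, converted to float, and multiplied by the per-gate descale vector before entering the FMA chain. Below is a minimal standalone sketch of that idiom; the helper name is hypothetical and not part of the patch:

    #include <arm_neon.h>

    // Hypothetical helper (not in the patch): dequantize 16 packed int8
    // IFOG weights (4 input lanes x 4 gates) into four fp32 vectors,
    // applying the per-gate reciprocal scales. Same widening sequence
    // as lstm_int8(): s8 -> s16 -> s32 -> f32, then one vmulq_f32 each.
    static inline void dequant_ifog16(const signed char* w,
                                      float32x4_t descale, // 1/scale for I,F,O,G
                                      float32x4_t out[4])
    {
        int8x16_t _w = vld1q_s8(w);
        int16x8_t _w01 = vmovl_s8(vget_low_s8(_w));
        int16x8_t _w23 = vmovl_s8(vget_high_s8(_w));
        out[0] = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_w01))), descale);
        out[1] = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_w01))), descale);
        out[2] = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_w23))), descale);
        out[3] = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_w23))), descale);
    }

Keeping the weights in int8 halves memory traffic relative to fp16 storage, and folding the descale into a single vmulq_f32 per vector keeps the dequantization cost off the FMA-accumulation path.
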
+ F += weight_hc_int8_IFOG[1] * descale_hc_F * h_cont; + O += weight_hc_int8_IFOG[2] * descale_hc_O * h_cont; + G += weight_hc_int8_IFOG[3] * descale_hc_G * h_cont; #endif // __ARM_NEON - weight_hc_IFOG += 4; + weight_hc_int8_IFOG += 4; } float* gates_data = gates.row(q); @@ -795,7 +609,7 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const // tanh(G) // c_t := f_t .* c_{t-1} + i_t .* g_t // h_t := o_t .* tanh[c_t] - unsigned short* output_data = top_blob.row(ti); + float* output_data = top_blob.row(ti); float* cell_ptr = cell_state; float* hidden_ptr = hidden_state; @@ -828,7 +642,7 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const if (num_output == hidden_size) { vst1q_f32(hidden_ptr + q, _lstm_H); - vst1_u16(output_data + q, float2bfloat(_lstm_H)); + vst1q_f32(output_data + q, _lstm_H); } else { @@ -858,7 +672,7 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const if (num_output == hidden_size) { hidden_ptr[q] = H; - output_data[q] = float32_to_bfloat16(H); + output_data[q] = H; } else { @@ -890,7 +704,7 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const } hidden_ptr[q] = H; - output_data[q] = float32_to_bfloat16(H); + output_data[q] = H; } } } @@ -898,10 +712,951 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const return 0; } -int LSTM_arm::create_pipeline_bf16s(const Option& opt) +int LSTM_arm::create_pipeline_int8(const Option& opt) { // pack IFOG - int num_directions = direction == 2 ? 2 : 1; + const int num_directions = direction == 2 ? 2 : 1; + const int size = weight_data_size / num_directions / hidden_size / 4; + + weight_xc_data_packed.create(size, hidden_size, num_directions, 4u, 4); + bias_c_data_packed.create(hidden_size, 1, num_directions, 16u, 4); + weight_hc_data_packed.create(num_output, hidden_size, num_directions, 4u, 4); + weight_xc_data_int8_descales_packed.create(4, hidden_size, num_directions); + weight_hc_data_int8_descales_packed.create(4, hidden_size, num_directions); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int dr = 0; dr < num_directions; dr++) + { + const Mat weight_xc = weight_xc_data.channel(dr); + const Mat bias_c = bias_c_data.channel(dr); + const Mat weight_hc = weight_hc_data.channel(dr); + const float* weight_xc_int8_scales = weight_xc_data_int8_scales.row(dr); + const float* weight_hc_int8_scales = weight_hc_data_int8_scales.row(dr); + + Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr); + Mat bias_c_data_packed_dr = bias_c_data_packed.channel(dr); + Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr); + Mat weight_xc_data_int8_descales_packed_dr = weight_xc_data_int8_descales_packed.channel(dr); + Mat weight_hc_data_int8_descales_packed_dr = weight_hc_data_int8_descales_packed.channel(dr); + + const float* bias_c_I = bias_c.row(0); + const float* bias_c_F = bias_c.row(1); + const float* bias_c_O = bias_c.row(2); + const float* bias_c_G = bias_c.row(3); + + float* bias_c_IFOG = bias_c_data_packed_dr.row(0); + + for (int q = 0; q < hidden_size; q++) + { + bias_c_IFOG[0] = bias_c_I[q]; + bias_c_IFOG[1] = bias_c_F[q]; + bias_c_IFOG[2] = bias_c_O[q]; + bias_c_IFOG[3] = bias_c_G[q]; + + bias_c_IFOG += 4; + + const signed char* weight_xc_I = weight_xc.row(hidden_size * 0 + q); + const signed char* weight_xc_F = weight_xc.row(hidden_size * 1 + q); + const signed char* weight_xc_O = weight_xc.row(hidden_size * 2 + q); + const signed char* weight_xc_G = 
weight_xc.row(hidden_size * 3 + q); + + const signed char* weight_hc_I = weight_hc.row(hidden_size * 0 + q); + const signed char* weight_hc_F = weight_hc.row(hidden_size * 1 + q); + const signed char* weight_hc_O = weight_hc.row(hidden_size * 2 + q); + const signed char* weight_hc_G = weight_hc.row(hidden_size * 3 + q); + + signed char* weight_xc_IFOG = weight_xc_data_packed_dr.row(q); + signed char* weight_hc_IFOG = weight_hc_data_packed_dr.row(q); + float* weight_xc_int8_descales_IFOG = weight_xc_data_int8_descales_packed_dr.row(q); + float* weight_hc_int8_descales_IFOG = weight_hc_data_int8_descales_packed_dr.row(q); + + for (int i = 0; i < size; i++) + { + weight_xc_IFOG[0] = weight_xc_I[i]; + weight_xc_IFOG[1] = weight_xc_F[i]; + weight_xc_IFOG[2] = weight_xc_O[i]; + weight_xc_IFOG[3] = weight_xc_G[i]; + + weight_xc_IFOG += 4; + } + + for (int i = 0; i < num_output; i++) + { + weight_hc_IFOG[0] = weight_hc_I[i]; + weight_hc_IFOG[1] = weight_hc_F[i]; + weight_hc_IFOG[2] = weight_hc_O[i]; + weight_hc_IFOG[3] = weight_hc_G[i]; + + weight_hc_IFOG += 4; + } + + weight_xc_int8_descales_IFOG[0] = 1.f / weight_xc_int8_scales[hidden_size * 0 + q]; + weight_xc_int8_descales_IFOG[1] = 1.f / weight_xc_int8_scales[hidden_size * 1 + q]; + weight_xc_int8_descales_IFOG[2] = 1.f / weight_xc_int8_scales[hidden_size * 2 + q]; + weight_xc_int8_descales_IFOG[3] = 1.f / weight_xc_int8_scales[hidden_size * 3 + q]; + + weight_hc_int8_descales_IFOG[0] = 1.f / weight_hc_int8_scales[hidden_size * 0 + q]; + weight_hc_int8_descales_IFOG[1] = 1.f / weight_hc_int8_scales[hidden_size * 1 + q]; + weight_hc_int8_descales_IFOG[2] = 1.f / weight_hc_int8_scales[hidden_size * 2 + q]; + weight_hc_int8_descales_IFOG[3] = 1.f / weight_hc_int8_scales[hidden_size * 3 + q]; + } + } + + if (opt.lightmode) + { + weight_xc_data.release(); + bias_c_data.release(); + weight_hc_data.release(); + weight_xc_data_int8_scales.release(); + weight_hc_data_int8_scales.release(); + } + + return 0; +} +#endif // NCNN_INT8 + +int LSTM_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + int elembits = bottom_blob.elembits(); + +#if NCNN_ARM82 + if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) + { + if (opt.use_fp16_arithmetic) + return forward_fp16sa(bottom_blob, top_blob, opt); + else + return forward_fp16s(bottom_blob, top_blob, opt); + } +#endif + +#if NCNN_BF16 + if (opt.use_bf16_storage && elembits == 16) + return forward_bf16s(bottom_blob, top_blob, opt); +#endif + + int T = bottom_blob.h; + + int num_directions = direction == 2 ? 2 : 1; + + // initial hidden state + Mat hidden(num_output, 4u, opt.workspace_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + + Mat cell(hidden_size, 4u, opt.workspace_allocator); + if (cell.empty()) + return -100; + cell.fill(0.f); + + top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + // Uni directional + if (direction == 0 || direction == 1) + { +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + } + + if (direction == 2) + { + Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + + hidden.fill(0.0f); + cell.fill(0.0f); + +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), weight_xc_data_int8_descales_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), weight_hc_data_int8_descales_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } + + // concat w + for (int i = 0; i < T; i++) + { + const float* pf = top_blob_forward.row(i); + const float* pr = top_blob_reverse.row(i); + float* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(float)); + memcpy(ptr + num_output, pr, num_output * sizeof(float)); + } + } + + return 0; +} + +int LSTM_arm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& bottom_blob = bottom_blobs[0]; + int elembits = bottom_blob.elembits(); + +#if NCNN_ARM82 + if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) + { + if (opt.use_fp16_arithmetic) + return forward_fp16sa(bottom_blobs, top_blobs, opt); + else + return forward_fp16s(bottom_blobs, top_blobs, opt); + } +#endif + +#if NCNN_BF16 + if (opt.use_bf16_storage && elembits == 16) + return forward_bf16s(bottom_blobs, top_blobs, opt); +#endif + + int T = bottom_blob.h; + int num_directions = direction == 2 ? 2 : 1; + + Mat hidden; + Mat cell; + Allocator* hidden_cell_allocator = top_blobs.size() == 3 ? 
opt.blob_allocator : opt.workspace_allocator; + if (bottom_blobs.size() == 3) + { + hidden = bottom_blobs[1].clone(hidden_cell_allocator); + cell = bottom_blobs[2].clone(hidden_cell_allocator); + } + else + { + hidden.create(num_output, num_directions, 4u, hidden_cell_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + + cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator); + if (cell.empty()) + return -100; + cell.fill(0.f); + } + + Mat& top_blob = top_blobs[0]; + top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + // Uni directional + if (direction == 0 || direction == 1) + { +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + } + + if (direction == 2) + { + Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + Mat hidden0 = hidden.row_range(0, 1); + Mat cell0 = cell.row_range(0, 1); +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); + if (ret != 0) + return ret; + } + + Mat hidden1 = hidden.row_range(1, 1); + Mat cell1 = cell.row_range(1, 1); +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), weight_xc_data_int8_descales_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), weight_hc_data_int8_descales_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); + if (ret != 0) + return ret; + } + + // concat w + for (int i = 0; i < T; i++) + { + const float* pf = top_blob_forward.row(i); + const float* pr = top_blob_reverse.row(i); + float* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(float)); + memcpy(ptr + num_output, pr, num_output * sizeof(float)); + } + } + + if (top_blobs.size() == 3) + { + top_blobs[1] = hidden; + top_blobs[2] = cell; + } + + return 0; +} + +#if NCNN_BF16 +static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +{ + int size = bottom_blob.w; + int T = bottom_blob.h; + + int num_output = top_blob.w; + int hidden_size = cell_state.w; + + // 4 x hidden_size + Mat gates(4, hidden_size, 4u, opt.workspace_allocator); + if (gates.empty()) + return -100; + + Mat tmp_hidden_state; + if (num_output != hidden_size) + { + tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); + if (tmp_hidden_state.empty()) + return -100; + } + + // unroll + for (int t = 0; t < T; t++) + { + // clip hidden by continuation indicator + // h_cont_{t-1} = cont_t * h_{t-1} + // h_cont_{t-1} = h_{t-1} if cont_t == 1 + // 0 otherwise + // calculate hidden + // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c + + int ti = reverse ? T - 1 - t : t; + + const unsigned short* x = bottom_blob.row(ti); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < hidden_size; q++) + { + const unsigned short* bias_c_IFOG = (const unsigned short*)bias_c + q * 4; + + // gate I F O G + const unsigned short* weight_xc_IFOG = weight_xc.row(q); + + const unsigned short* weight_hc_IFOG = weight_hc.row(q); + +#if __ARM_NEON + float32x4_t _IFOG = bfloat2float(vld1_u16(bias_c_IFOG)); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); +#else + float I = bfloat16_to_float32(bias_c_IFOG[0]); + float F = bfloat16_to_float32(bias_c_IFOG[1]); + float O = bfloat16_to_float32(bias_c_IFOG[2]); + float G = bfloat16_to_float32(bias_c_IFOG[3]); +#endif // __ARM_NEON + + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) + { + float32x4_t _xi = bfloat2float(vld1_u16(x + i)); + + float32x4_t _weight_xc_IFOG_0 = bfloat2float(vld1_u16(weight_xc_IFOG)); + float32x4_t _weight_xc_IFOG_1 = bfloat2float(vld1_u16(weight_xc_IFOG + 4)); + float32x4_t _weight_xc_IFOG_2 = bfloat2float(vld1_u16(weight_xc_IFOG + 8)); + float32x4_t _weight_xc_IFOG_3 = bfloat2float(vld1_u16(weight_xc_IFOG + 12)); + +#if __aarch64__ + _IFOG = vfmaq_laneq_f32(_IFOG, _weight_xc_IFOG_0, _xi, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_IFOG_1, _xi, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_IFOG_2, _xi, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_IFOG_3, _xi, 3); +#else + _IFOG = vmlaq_lane_f32(_IFOG, _weight_xc_IFOG_0, vget_low_f32(_xi), 0); + _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_IFOG_1, vget_low_f32(_xi), 1); + _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_IFOG_2, vget_high_f32(_xi), 0); + _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_IFOG_3, vget_high_f32(_xi), 1); +#endif + + weight_xc_IFOG += 16; + } +#endif // __ARM_NEON + for (; i < size; i++) + { +#if __ARM_NEON + unsigned short xi = x[i]; + + float32x4_t _xi = bfloat2float(vdup_n_u16(xi)); + float32x4_t _weight_xc_IFOG = bfloat2float(vld1_u16(weight_xc_IFOG)); + _IFOG = vmlaq_f32(_IFOG, _weight_xc_IFOG, _xi); +#else + 
float xi = bfloat16_to_float32(x[i]); + + I += bfloat16_to_float32(weight_xc_IFOG[0]) * xi; + F += bfloat16_to_float32(weight_xc_IFOG[1]) * xi; + O += bfloat16_to_float32(weight_xc_IFOG[2]) * xi; + G += bfloat16_to_float32(weight_xc_IFOG[3]) * xi; +#endif // __ARM_NEON + + weight_xc_IFOG += 4; + } + + i = 0; +#if __ARM_NEON + for (; i + 3 < num_output; i += 4) + { + float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); + + float32x4_t _weight_hc_IFOG_0 = bfloat2float(vld1_u16(weight_hc_IFOG)); + float32x4_t _weight_hc_IFOG_1 = bfloat2float(vld1_u16(weight_hc_IFOG + 4)); + float32x4_t _weight_hc_IFOG_2 = bfloat2float(vld1_u16(weight_hc_IFOG + 8)); + float32x4_t _weight_hc_IFOG_3 = bfloat2float(vld1_u16(weight_hc_IFOG + 12)); + +#if __aarch64__ + _IFOG = vfmaq_laneq_f32(_IFOG, _weight_hc_IFOG_0, _h_cont, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_IFOG_1, _h_cont, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_IFOG_2, _h_cont, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_IFOG_3, _h_cont, 3); +#else + _IFOG = vmlaq_lane_f32(_IFOG, _weight_hc_IFOG_0, vget_low_f32(_h_cont), 0); + _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_IFOG_1, vget_low_f32(_h_cont), 1); + _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_IFOG_2, vget_high_f32(_h_cont), 0); + _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_IFOG_3, vget_high_f32(_h_cont), 1); +#endif + + weight_hc_IFOG += 16; + } +#endif // __ARM_NEON + for (; i < num_output; i++) + { + float h_cont = hidden_state[i]; + +#if __ARM_NEON + float32x4_t _h_cont = vdupq_n_f32(h_cont); + float32x4_t _weight_hc_IFOG = bfloat2float(vld1_u16(weight_hc_IFOG)); + _IFOG = vmlaq_f32(_IFOG, _weight_hc_IFOG, _h_cont); +#else + I += bfloat16_to_float32(weight_hc_IFOG[0]) * h_cont; + F += bfloat16_to_float32(weight_hc_IFOG[1]) * h_cont; + O += bfloat16_to_float32(weight_hc_IFOG[2]) * h_cont; + G += bfloat16_to_float32(weight_hc_IFOG[3]) * h_cont; +#endif // __ARM_NEON + + weight_hc_IFOG += 4; + } + + float* gates_data = gates.row(q); + +#if __ARM_NEON + _IFOG = vaddq_f32(_IFOG, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + _IFOG = vaddq_f32(_IFOG, _sum2); + + vst1q_f32(gates_data, _IFOG); +#else + gates_data[0] = I; + gates_data[1] = F; + gates_data[2] = O; + gates_data[3] = G; +#endif // __ARM_NEON + } + + // lstm unit + // sigmoid(I) + // sigmoid(F) + // sigmoid(O) + // tanh(G) + // c_t := f_t .* c_{t-1} + i_t .* g_t + // h_t := o_t .* tanh[c_t] + unsigned short* output_data = top_blob.row(ti); + + float* cell_ptr = cell_state; + float* hidden_ptr = hidden_state; + float* tmp_hidden_ptr = tmp_hidden_state; + + int remain_hidden_size_start = 0; +#if __ARM_NEON + int nn_hidden_size = hidden_size >> 2; + remain_hidden_size_start = nn_hidden_size << 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_hidden_size; qq++) + { + int q = qq * 4; + + const float* gates_data = gates.row(q); + + float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data); + + float32x4_t _lstm_I = sigmoid_ps(_IFOG_4x4.val[0]); + float32x4_t _lstm_F = sigmoid_ps(_IFOG_4x4.val[1]); + float32x4_t _lstm_O = sigmoid_ps(_IFOG_4x4.val[2]); + float32x4_t _lstm_G = tanh_ps(_IFOG_4x4.val[3]); + + float32x4_t _cell2 = vaddq_f32(vmulq_f32(_lstm_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_lstm_I, _lstm_G)); + float32x4_t _lstm_H = vmulq_f32(_lstm_O, tanh_ps(_cell2)); + + vst1q_f32(cell_ptr + q, _cell2); + + if (num_output == hidden_size) + { + vst1q_f32(hidden_ptr + q, _lstm_H); + vst1_u16(output_data + q, float2bfloat(_lstm_H)); + } + else + { + vst1q_f32(tmp_hidden_ptr + q, _lstm_H); + } + } 
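
Note: every precision path in this file shares the same activation stage. The packed IFOG row is de-interleaved (vld4q_f32 in the vector loop above), I, F and O pass through a sigmoid, G through tanh, and the cell and hidden states are updated. A scalar reference of that update, equivalent to the tail loop that follows; the function name is illustrative only:

    #include <math.h>

    // Reference LSTM unit update for one hidden element (illustrative,
    // not in the patch):
    //   c_t = sigmoid(F) * c_{t-1} + sigmoid(I) * tanh(G)
    //   h_t = sigmoid(O) * tanh(c_t)
    static inline float lstm_unit_ref(const float ifog[4], float* cell)
    {
        float I = 1.f / (1.f + expf(-ifog[0]));
        float F = 1.f / (1.f + expf(-ifog[1]));
        float O = 1.f / (1.f + expf(-ifog[2]));
        float G = tanhf(ifog[3]);

        float c = F * *cell + I * G;
        *cell = c;           // c_t
        return O * tanhf(c); // h_t
    }
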
+#endif // __ARM_NEON + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_hidden_size_start; q < hidden_size; q++) + { + const float* gates_data = gates.row(q); + + float I = gates_data[0]; + float F = gates_data[1]; + float O = gates_data[2]; + float G = gates_data[3]; + + I = 1.f / (1.f + expf(-I)); + F = 1.f / (1.f + expf(-F)); + O = 1.f / (1.f + expf(-O)); + G = tanhf(G); + + float cell2 = F * cell_ptr[q] + I * G; + float H = O * tanhf(cell2); + + cell_ptr[q] = cell2; + if (num_output == hidden_size) + { + hidden_ptr[q] = H; + output_data[q] = float32_to_bfloat16(H); + } + else + { + tmp_hidden_ptr[q] = H; + } + } + + if (num_output != hidden_size) + { + // int nn_num_output = num_output >> 2; + // int remain_num_output_start = nn_num_output << 2; + // #pragma omp parallel for num_threads(opt.num_threads) + // for (int qq = 0; qq < nn_num_output; qq++) + // { + // int q = qq * 4; + // + // } + int remain_num_output_start = 0; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + const float* tmp_hidden_ptr = tmp_hidden_state; + + float H = 0; + for (int i = 0; i < hidden_size; i++) + { + H += tmp_hidden_ptr[i] * hr[i]; + } + + hidden_ptr[q] = H; + output_data[q] = float32_to_bfloat16(H); + } + } + } + + return 0; +} + +#if NCNN_INT8 +static int lstm_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +{ + int size = bottom_blob.w; + int T = bottom_blob.h; + + int num_output = top_blob.w; + int hidden_size = cell_state.w; + + // 4 x hidden_size + Mat gates(4, hidden_size, 4u, opt.workspace_allocator); + if (gates.empty()) + return -100; + + Mat tmp_hidden_state; + if (num_output != hidden_size) + { + tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); + if (tmp_hidden_state.empty()) + return -100; + } + + // unroll + for (int t = 0; t < T; t++) + { + // clip hidden by continuation indicator + // h_cont_{t-1} = cont_t * h_{t-1} + // h_cont_{t-1} = h_{t-1} if cont_t == 1 + // 0 otherwise + // calculate hidden + // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c + + int ti = reverse ? 
T - 1 - t : t; + + const unsigned short* x = bottom_blob.row(ti); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < hidden_size; q++) + { + const unsigned short* bias_c_IFOG = (const unsigned short*)bias_c + q * 4; + + // gate I F O G + const signed char* weight_xc_int8_IFOG = weight_xc_int8.row(q); + const signed char* weight_hc_int8_IFOG = weight_hc_int8.row(q); + const float* weight_xc_int8_descales_IFOG = weight_xc_int8_descales.row(q); + const float* weight_hc_int8_descales_IFOG = weight_hc_int8_descales.row(q); + +#if __ARM_NEON + float32x4_t _descale_xc_IFOG = vld1q_f32(weight_xc_int8_descales_IFOG); + float32x4_t _descale_hc_IFOG = vld1q_f32(weight_hc_int8_descales_IFOG); + + float32x4_t _IFOG = bfloat2float(vld1_u16(bias_c_IFOG)); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); +#else + const float descale_xc_I = weight_xc_int8_descales_IFOG[0]; + const float descale_xc_F = weight_xc_int8_descales_IFOG[1]; + const float descale_xc_O = weight_xc_int8_descales_IFOG[2]; + const float descale_xc_G = weight_xc_int8_descales_IFOG[3]; + + const float descale_hc_I = weight_hc_int8_descales_IFOG[0]; + const float descale_hc_F = weight_hc_int8_descales_IFOG[1]; + const float descale_hc_O = weight_hc_int8_descales_IFOG[2]; + const float descale_hc_G = weight_hc_int8_descales_IFOG[3]; + + float I = bfloat16_to_float32(bias_c_IFOG[0]); + float F = bfloat16_to_float32(bias_c_IFOG[1]); + float O = bfloat16_to_float32(bias_c_IFOG[2]); + float G = bfloat16_to_float32(bias_c_IFOG[3]); +#endif // __ARM_NEON + + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) + { + float32x4_t _xi = bfloat2float(vld1_u16(x + i)); + + int8x16_t _weight_xc_IFOG = vld1q_s8(weight_xc_int8_IFOG); + int16x8_t _weight_xc_IFOG01 = vmovl_s8(vget_low_s8(_weight_xc_IFOG)); + int16x8_t _weight_xc_IFOG23 = vmovl_s8(vget_high_s8(_weight_xc_IFOG)); + float32x4_t _weight_xc_IFOG_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_IFOG01))); + float32x4_t _weight_xc_IFOG_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_IFOG01))); + float32x4_t _weight_xc_IFOG_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_IFOG23))); + float32x4_t _weight_xc_IFOG_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_IFOG23))); + _weight_xc_IFOG_0 = vmulq_f32(_weight_xc_IFOG_0, _descale_xc_IFOG); + _weight_xc_IFOG_1 = vmulq_f32(_weight_xc_IFOG_1, _descale_xc_IFOG); + _weight_xc_IFOG_2 = vmulq_f32(_weight_xc_IFOG_2, _descale_xc_IFOG); + _weight_xc_IFOG_3 = vmulq_f32(_weight_xc_IFOG_3, _descale_xc_IFOG); + +#if __aarch64__ + _IFOG = vfmaq_laneq_f32(_IFOG, _weight_xc_IFOG_0, _xi, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_IFOG_1, _xi, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_IFOG_2, _xi, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_IFOG_3, _xi, 3); +#else + _IFOG = vmlaq_lane_f32(_IFOG, _weight_xc_IFOG_0, vget_low_f32(_xi), 0); + _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_IFOG_1, vget_low_f32(_xi), 1); + _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_IFOG_2, vget_high_f32(_xi), 0); + _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_IFOG_3, vget_high_f32(_xi), 1); +#endif + + weight_xc_int8_IFOG += 16; + } +#endif // __ARM_NEON + for (; i < size; i++) + { +#if __ARM_NEON + unsigned short xi = x[i]; + + float32x4_t _xi = bfloat2float(vdup_n_u16(xi)); + float32x4_t _weight_xc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_IFOG))))); + _weight_xc_IFOG = vmulq_f32(_weight_xc_IFOG, _descale_xc_IFOG); + _IFOG = 
vmlaq_f32(_IFOG, _weight_xc_IFOG, _xi); +#else + float xi = bfloat16_to_float32(x[i]); + + I += weight_xc_int8_IFOG[0] * descale_xc_I * xi; + F += weight_xc_int8_IFOG[1] * descale_xc_F * xi; + O += weight_xc_int8_IFOG[2] * descale_xc_O * xi; + G += weight_xc_int8_IFOG[3] * descale_xc_G * xi; +#endif // __ARM_NEON + + weight_xc_int8_IFOG += 4; + } + + i = 0; +#if __ARM_NEON + for (; i + 3 < num_output; i += 4) + { + float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); + + int8x16_t _weight_hc_IFOG = vld1q_s8(weight_hc_int8_IFOG); + int16x8_t _weight_hc_IFOG01 = vmovl_s8(vget_low_s8(_weight_hc_IFOG)); + int16x8_t _weight_hc_IFOG23 = vmovl_s8(vget_high_s8(_weight_hc_IFOG)); + float32x4_t _weight_hc_IFOG_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_IFOG01))); + float32x4_t _weight_hc_IFOG_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_IFOG01))); + float32x4_t _weight_hc_IFOG_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_IFOG23))); + float32x4_t _weight_hc_IFOG_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_IFOG23))); + _weight_hc_IFOG_0 = vmulq_f32(_weight_hc_IFOG_0, _descale_hc_IFOG); + _weight_hc_IFOG_1 = vmulq_f32(_weight_hc_IFOG_1, _descale_hc_IFOG); + _weight_hc_IFOG_2 = vmulq_f32(_weight_hc_IFOG_2, _descale_hc_IFOG); + _weight_hc_IFOG_3 = vmulq_f32(_weight_hc_IFOG_3, _descale_hc_IFOG); + +#if __aarch64__ + _IFOG = vfmaq_laneq_f32(_IFOG, _weight_hc_IFOG_0, _h_cont, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_IFOG_1, _h_cont, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_IFOG_2, _h_cont, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_IFOG_3, _h_cont, 3); +#else + _IFOG = vmlaq_lane_f32(_IFOG, _weight_hc_IFOG_0, vget_low_f32(_h_cont), 0); + _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_IFOG_1, vget_low_f32(_h_cont), 1); + _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_IFOG_2, vget_high_f32(_h_cont), 0); + _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_IFOG_3, vget_high_f32(_h_cont), 1); +#endif + + weight_hc_int8_IFOG += 16; + } +#endif // __ARM_NEON + for (; i < num_output; i++) + { + float h_cont = hidden_state[i]; + +#if __ARM_NEON + float32x4_t _h_cont = vdupq_n_f32(h_cont); + float32x4_t _weight_hc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_IFOG))))); + _weight_hc_IFOG = vmulq_f32(_weight_hc_IFOG, _descale_hc_IFOG); + _IFOG = vmlaq_f32(_IFOG, _weight_hc_IFOG, _h_cont); +#else + I += weight_hc_int8_IFOG[0] * descale_hc_I * h_cont; + F += weight_hc_int8_IFOG[1] * descale_hc_F * h_cont; + O += weight_hc_int8_IFOG[2] * descale_hc_O * h_cont; + G += weight_hc_int8_IFOG[3] * descale_hc_G * h_cont; +#endif // __ARM_NEON + + weight_hc_int8_IFOG += 4; + } + + float* gates_data = gates.row(q); + +#if __ARM_NEON + _IFOG = vaddq_f32(_IFOG, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + _IFOG = vaddq_f32(_IFOG, _sum2); + + vst1q_f32(gates_data, _IFOG); +#else + gates_data[0] = I; + gates_data[1] = F; + gates_data[2] = O; + gates_data[3] = G; +#endif // __ARM_NEON + } + + // lstm unit + // sigmoid(I) + // sigmoid(F) + // sigmoid(O) + // tanh(G) + // c_t := f_t .* c_{t-1} + i_t .* g_t + // h_t := o_t .* tanh[c_t] + unsigned short* output_data = top_blob.row(ti); + + float* cell_ptr = cell_state; + float* hidden_ptr = hidden_state; + float* tmp_hidden_ptr = tmp_hidden_state; + + int remain_hidden_size_start = 0; +#if __ARM_NEON + int nn_hidden_size = hidden_size >> 2; + remain_hidden_size_start = nn_hidden_size << 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_hidden_size; qq++) + { + int q = qq * 4; + + 
const float* gates_data = gates.row(q); + + float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data); + + float32x4_t _lstm_I = sigmoid_ps(_IFOG_4x4.val[0]); + float32x4_t _lstm_F = sigmoid_ps(_IFOG_4x4.val[1]); + float32x4_t _lstm_O = sigmoid_ps(_IFOG_4x4.val[2]); + float32x4_t _lstm_G = tanh_ps(_IFOG_4x4.val[3]); + + float32x4_t _cell2 = vaddq_f32(vmulq_f32(_lstm_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_lstm_I, _lstm_G)); + float32x4_t _lstm_H = vmulq_f32(_lstm_O, tanh_ps(_cell2)); + + vst1q_f32(cell_ptr + q, _cell2); + + if (num_output == hidden_size) + { + vst1q_f32(hidden_ptr + q, _lstm_H); + vst1_u16(output_data + q, float2bfloat(_lstm_H)); + } + else + { + vst1q_f32(tmp_hidden_ptr + q, _lstm_H); + } + } +#endif // __ARM_NEON + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_hidden_size_start; q < hidden_size; q++) + { + const float* gates_data = gates.row(q); + + float I = gates_data[0]; + float F = gates_data[1]; + float O = gates_data[2]; + float G = gates_data[3]; + + I = 1.f / (1.f + expf(-I)); + F = 1.f / (1.f + expf(-F)); + O = 1.f / (1.f + expf(-O)); + G = tanhf(G); + + float cell2 = F * cell_ptr[q] + I * G; + float H = O * tanhf(cell2); + + cell_ptr[q] = cell2; + if (num_output == hidden_size) + { + hidden_ptr[q] = H; + output_data[q] = float32_to_bfloat16(H); + } + else + { + tmp_hidden_ptr[q] = H; + } + } + + if (num_output != hidden_size) + { + // int nn_num_output = num_output >> 2; + // int remain_num_output_start = nn_num_output << 2; + // #pragma omp parallel for num_threads(opt.num_threads) + // for (int qq = 0; qq < nn_num_output; qq++) + // { + // int q = qq * 4; + // + // } + int remain_num_output_start = 0; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + const float* tmp_hidden_ptr = tmp_hidden_state; + + float H = 0; + for (int i = 0; i < hidden_size; i++) + { + H += tmp_hidden_ptr[i] * hr[i]; + } + + hidden_ptr[q] = H; + output_data[q] = float32_to_bfloat16(H); + } + } + } + + return 0; +} +#endif // NCNN_INT8 + +int LSTM_arm::create_pipeline_bf16s(const Option& opt) +{ +#if NCNN_INT8 + if (int8_scale_term) + { + create_pipeline_int8(opt); + + ncnn::Mat tmp; + cast_float32_to_bfloat16(bias_c_data_packed, tmp, opt); + bias_c_data_packed = tmp; + + return 0; + } +#endif + + // pack IFOG + int num_directions = direction == 2 ? 2 : 1; int size = weight_data_size / num_directions / hidden_size / 4; weight_xc_data_packed.create(size, hidden_size, num_directions, 8u, 4); @@ -1004,9 +1759,20 @@ int LSTM_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& // Uni directional if (direction == 0 || direction == 1) { - int ret = lstm_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_bf16s_int8(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -1019,16 +1785,38 @@ int LSTM_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& if (top_blob_reverse.empty()) return -100; - int ret0 = lstm_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_bf16s_int8(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } hidden.fill(0.f); cell.fill(0.f); - int ret1 = lstm_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_bf16s_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), weight_xc_data_int8_descales_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), weight_hc_data_int8_descales_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -1082,9 +1870,20 @@ int LSTM_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; protected: +#if NCNN_INT8 + int create_pipeline_int8(const Option& opt); +#endif #if NCNN_ARM82 int create_pipeline_fp16s(const Option& opt); int forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; @@ -46,6 +49,11 @@ class LSTM_arm : public LSTM Mat weight_xc_data_packed; Mat bias_c_data_packed; Mat weight_hc_data_packed; + +#if NCNN_INT8 + Mat weight_hc_data_int8_descales_packed; + Mat weight_xc_data_int8_descales_packed; +#endif }; } // namespace ncnn diff --git a/src/layer/arm/lstm_arm_asimdhp.cpp b/src/layer/arm/lstm_arm_asimdhp.cpp index 1d3fc71cdfc..5dcaffd6c69 100644 --- a/src/layer/arm/lstm_arm_asimdhp.cpp +++ b/src/layer/arm/lstm_arm_asimdhp.cpp @@ -643,11 +643,962 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const return 0; } +#if NCNN_INT8 +static int lstm_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +{ + int size = bottom_blob.w; + int T = bottom_blob.h; + + int num_output = top_blob.w; + int hidden_size = cell_state.w; + + // 4 x hidden_size + Mat gates(4, hidden_size, 4u, opt.workspace_allocator); + if (gates.empty()) + return -100; + + Mat tmp_hidden_state; + if (num_output != hidden_size) + { + tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); + if (tmp_hidden_state.empty()) + return -100; + } + + // unroll + for (int t = 0; t < T; t++) + { + // clip hidden by continuation indicator + // h_cont_{t-1} = cont_t * h_{t-1} + // h_cont_{t-1} = h_{t-1} if cont_t == 1 + // 0 otherwise + // calculate hidden + // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c + + int ti = reverse ? 
T - 1 - t : t; + + const __fp16* x = bottom_blob.row(ti); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < hidden_size; q++) + { + const __fp16* bias_c_IFOG = (const __fp16*)bias_c + q * 4; + + // gate I F O G + const signed char* weight_xc_int8_IFOG = weight_xc_int8.row(q); + const signed char* weight_hc_int8_IFOG = weight_hc_int8.row(q); + const float* weight_xc_int8_descales_IFOG = weight_xc_int8_descales.row(q); + const float* weight_hc_int8_descales_IFOG = weight_hc_int8_descales.row(q); + + float32x4_t _descale_xc = vld1q_f32(weight_xc_int8_descales_IFOG); + float32x4_t _descale_hc = vld1q_f32(weight_hc_int8_descales_IFOG); + + float32x4_t _IFOG = vcvt_f32_f16(vld1_f16(bias_c_IFOG)); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + + int i = 0; + for (; i + 3 < size; i += 4) + { + float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i)); + + int8x16_t _weight_xc_IFOG = vld1q_s8(weight_xc_int8_IFOG); + int16x8_t _weight_xc_IFOG01 = vmovl_s8(vget_low_s8(_weight_xc_IFOG)); + int16x8_t _weight_xc_IFOG23 = vmovl_s8(vget_high_s8(_weight_xc_IFOG)); + float32x4_t _weight_xc_IFOG_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_IFOG01))); + float32x4_t _weight_xc_IFOG_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_IFOG01))); + float32x4_t _weight_xc_IFOG_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_IFOG23))); + float32x4_t _weight_xc_IFOG_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_IFOG23))); + _weight_xc_IFOG_0 = vmulq_f32(_weight_xc_IFOG_0, _descale_xc); + _weight_xc_IFOG_1 = vmulq_f32(_weight_xc_IFOG_1, _descale_xc); + _weight_xc_IFOG_2 = vmulq_f32(_weight_xc_IFOG_2, _descale_xc); + _weight_xc_IFOG_3 = vmulq_f32(_weight_xc_IFOG_3, _descale_xc); + + _IFOG = vfmaq_laneq_f32(_IFOG, _weight_xc_IFOG_0, _xi, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_IFOG_1, _xi, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_IFOG_2, _xi, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_IFOG_3, _xi, 3); + + weight_xc_int8_IFOG += 16; + } + for (; i < size; i++) + { + __fp16 xi = x[i]; + + float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); + float32x4_t _weight_xc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_IFOG))))); + _weight_xc_IFOG = vmulq_f32(_weight_xc_IFOG, _descale_xc); + _IFOG = vfmaq_f32(_IFOG, _weight_xc_IFOG, _xi); + + weight_xc_int8_IFOG += 4; + } + + i = 0; + for (; i + 3 < num_output; i += 4) + { + float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); + + int8x16_t _weight_hc_IFOG = vld1q_s8(weight_hc_int8_IFOG); + int16x8_t _weight_hc_IFOG01 = vmovl_s8(vget_low_s8(_weight_hc_IFOG)); + int16x8_t _weight_hc_IFOG23 = vmovl_s8(vget_high_s8(_weight_hc_IFOG)); + float32x4_t _weight_hc_IFOG_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_IFOG01))); + float32x4_t _weight_hc_IFOG_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_IFOG01))); + float32x4_t _weight_hc_IFOG_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_IFOG23))); + float32x4_t _weight_hc_IFOG_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_IFOG23))); + _weight_hc_IFOG_0 = vmulq_f32(_weight_hc_IFOG_0, _descale_hc); + _weight_hc_IFOG_1 = vmulq_f32(_weight_hc_IFOG_1, _descale_hc); + _weight_hc_IFOG_2 = vmulq_f32(_weight_hc_IFOG_2, _descale_hc); + _weight_hc_IFOG_3 = vmulq_f32(_weight_hc_IFOG_3, _descale_hc); + + _IFOG = vfmaq_laneq_f32(_IFOG, _weight_hc_IFOG_0, _h_cont, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_IFOG_1, _h_cont, 1); + _sum2 = vfmaq_laneq_f32(_sum2, 
_weight_hc_IFOG_2, _h_cont, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_IFOG_3, _h_cont, 3); + + weight_hc_int8_IFOG += 16; + } + for (; i < num_output; i++) + { + float h_cont = hidden_state[i]; + + float32x4_t _h_cont = vdupq_n_f32(h_cont); + float32x4_t _weight_hc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_IFOG))))); + _weight_hc_IFOG = vmulq_f32(_weight_hc_IFOG, _descale_hc); + _IFOG = vfmaq_f32(_IFOG, _weight_hc_IFOG, _h_cont); + + weight_hc_int8_IFOG += 4; + } + + float* gates_data = gates.row(q); + + _IFOG = vaddq_f32(_IFOG, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + _IFOG = vaddq_f32(_IFOG, _sum2); + + vst1q_f32(gates_data, _IFOG); + } + + // lstm unit + // sigmoid(I) + // sigmoid(F) + // sigmoid(O) + // tanh(G) + // c_t := f_t .* c_{t-1} + i_t .* g_t + // h_t := o_t .* tanh[c_t] + __fp16* output_data = top_blob.row<__fp16>(ti); + + float* cell_ptr = cell_state; + float* hidden_ptr = hidden_state; + float* tmp_hidden_ptr = tmp_hidden_state; + + int nn_hidden_size = hidden_size >> 2; + int remain_hidden_size_start = nn_hidden_size << 2; + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_hidden_size; qq++) + { + int q = qq * 4; + + const float* gates_data = gates.row(q); + + float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data); + + float32x4_t _lstm_I = sigmoid_ps(_IFOG_4x4.val[0]); + float32x4_t _lstm_F = sigmoid_ps(_IFOG_4x4.val[1]); + float32x4_t _lstm_O = sigmoid_ps(_IFOG_4x4.val[2]); + float32x4_t _lstm_G = tanh_ps(_IFOG_4x4.val[3]); + + float32x4_t _cell2 = vaddq_f32(vmulq_f32(_lstm_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_lstm_I, _lstm_G)); + float32x4_t _lstm_H = vmulq_f32(_lstm_O, tanh_ps(_cell2)); + + vst1q_f32(cell_ptr + q, _cell2); + + if (num_output == hidden_size) + { + vst1q_f32(hidden_ptr + q, _lstm_H); + vst1_f16(output_data + q, vcvt_f16_f32(_lstm_H)); + } + else + { + vst1q_f32(tmp_hidden_ptr + q, _lstm_H); + } + } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_hidden_size_start; q < hidden_size; q++) + { + const float* gates_data = gates.row(q); + + float I = gates_data[0]; + float F = gates_data[1]; + float O = gates_data[2]; + float G = gates_data[3]; + + I = 1.f / (1.f + expf(-I)); + F = 1.f / (1.f + expf(-F)); + O = 1.f / (1.f + expf(-O)); + G = tanhf(G); + + float cell2 = F * cell_ptr[q] + I * G; + float H = O * tanhf(cell2); + + cell_ptr[q] = cell2; + if (num_output == hidden_size) + { + hidden_ptr[q] = H; + output_data[q] = (__fp16)H; + } + else + { + tmp_hidden_ptr[q] = H; + } + } + + if (num_output != hidden_size) + { + // int nn_num_output = num_output >> 2; + // int remain_num_output_start = nn_num_output << 2; + // #pragma omp parallel for num_threads(opt.num_threads) + // for (int qq = 0; qq < nn_num_output; qq++) + // { + // int q = qq * 4; + // + // } + int remain_num_output_start = 0; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + const float* tmp_hidden_ptr = tmp_hidden_state; + + float H = 0; + for (int i = 0; i < hidden_size; i++) + { + H += tmp_hidden_ptr[i] * hr[i]; + } + + hidden_ptr[q] = H; + output_data[q] = (__fp16)H; + } + } + } + + return 0; +} + +static int lstm_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const 
Option& opt)
+{
+    int size = bottom_blob.w;
+    int T = bottom_blob.h;
+
+    int num_output = top_blob.w;
+    int hidden_size = cell_state.w;
+
+    // 4 x hidden_size
+    Mat gates(4, hidden_size, 2u, opt.workspace_allocator);
+    if (gates.empty())
+        return -100;
+
+    Mat tmp_hidden_state;
+    if (num_output != hidden_size)
+    {
+        tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator);
+        if (tmp_hidden_state.empty())
+            return -100;
+    }
+
+    // unroll
+    for (int t = 0; t < T; t++)
+    {
+        // clip hidden by continuation indicator
+        // h_cont_{t-1} = cont_t * h_{t-1}
+        // h_cont_{t-1} = h_{t-1} if cont_t == 1
+        //                0 otherwise
+        // calculate hidden
+        // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c
+
+        int ti = reverse ? T - 1 - t : t;
+
+        int nn_hidden_size = hidden_size >> 1;
+        int remain_hidden_size_start = nn_hidden_size << 1;
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int qq = 0; qq < nn_hidden_size; qq++)
+        {
+            int q = qq * 2;
+
+            const __fp16* bias_c_IFOG = (const __fp16*)bias_c + q * 4;
+
+            // gate I F O G
+            const signed char* weight_xc_int8_IFOG = weight_xc_int8.row<const signed char>(q / 2);
+            const signed char* weight_hc_int8_IFOG = weight_hc_int8.row<const signed char>(q / 2);
+            const __fp16* weight_xc_int8_descales_IFOG = weight_xc_int8_descales.row<const __fp16>(q / 2);
+            const __fp16* weight_hc_int8_descales_IFOG = weight_hc_int8_descales.row<const __fp16>(q / 2);
+
+            float16x8_t _descale_xc = vld1q_f16(weight_xc_int8_descales_IFOG);
+            float16x8_t _descale_hc = vld1q_f16(weight_hc_int8_descales_IFOG);
+
+            float16x8_t _IFOG = vld1q_f16(bias_c_IFOG);
+            float16x8_t _sum1 = vdupq_n_f16((__fp16)0.f);
+            float16x8_t _sum2 = vdupq_n_f16((__fp16)0.f);
+            float16x8_t _sum3 = vdupq_n_f16((__fp16)0.f);
+
+            const __fp16* x = bottom_blob.row<const __fp16>(ti);
+
+            int i = 0;
+            for (; i + 3 < size; i += 4)
+            {
+#if NCNN_GNU_INLINE_ASM
+                asm volatile(
+                    "ld1 {v6.16b, v7.16b}, [%1], #32 \n"
+                    "ld1 {v4.4h}, [%0], #8 \n"
+                    "sxtl v0.8h, v6.8b \n"
+                    "sxtl2 v1.8h, v6.16b \n"
+                    "sxtl v2.8h, v7.8b \n"
+                    "sxtl2 v3.8h, v7.16b \n"
+                    "scvtf v0.8h, v0.8h \n"
+                    "scvtf v1.8h, v1.8h \n"
+                    "scvtf v2.8h, v2.8h \n"
+                    "scvtf v3.8h, v3.8h \n"
+                    "fmul v0.8h, v0.8h, %12.8h \n"
+                    "fmul v1.8h, v1.8h, %12.8h \n"
+                    "fmul v2.8h, v2.8h, %12.8h \n"
+                    "fmul v3.8h, v3.8h, %12.8h \n"
+                    "fmla %2.8h, v0.8h, v4.h[0] \n"
+                    "fmla %3.8h, v1.8h, v4.h[1] \n"
+                    "fmla %4.8h, v2.8h, v4.h[2] \n"
+                    "fmla %5.8h, v3.8h, v4.h[3] \n"
+                    : "=r"(x),
+                    "=r"(weight_xc_int8_IFOG),
+                    "=w"(_IFOG),
+                    "=w"(_sum1),
+                    "=w"(_sum2),
+                    "=w"(_sum3)
+                    : "0"(x),
+                    "1"(weight_xc_int8_IFOG),
+                    "2"(_IFOG),
+                    "3"(_sum1),
+                    "4"(_sum2),
+                    "5"(_sum3),
+                    "w"(_descale_xc)
+                    : "memory", "v0", "v1", "v2", "v3", "v4", "v6", "v7");
+#else // NCNN_GNU_INLINE_ASM
+                float16x4_t _x = vld1_f16(x);
+
+                int8x16_t _weight_xc_IFOG01 = vld1q_s8(weight_xc_int8_IFOG);
+                int8x16_t _weight_xc_IFOG23 = vld1q_s8(weight_xc_int8_IFOG + 16);
+                float16x8_t _w0 = vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_IFOG01)));
+                float16x8_t _w1 = vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_IFOG01)));
+                float16x8_t _w2 = vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_IFOG23)));
+                float16x8_t _w3 = vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_IFOG23)));
+                _w0 = vmulq_f16(_w0, _descale_xc);
+                _w1 = vmulq_f16(_w1, _descale_xc);
+                _w2 = vmulq_f16(_w2, _descale_xc);
+                _w3 = vmulq_f16(_w3, _descale_xc);
+
+                _IFOG = vfmaq_lane_f16(_IFOG, _w0, _x, 0);
+                _sum1 = vfmaq_lane_f16(_sum1, _w1, _x, 1);
+                _sum2 = vfmaq_lane_f16(_sum2, _w2, _x, 2);
+                _sum3 = vfmaq_lane_f16(_sum3, _w3, _x, 3);
+
+                x += 4;
+                weight_xc_int8_IFOG += 32;
+#endif // NCNN_GNU_INLINE_ASM
+            }
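+            // A scalar sketch of the tail loop below (names as in this function;
+            // each packed row interleaves two output channels as
+            // {I(q) F(q) O(q) G(q) I(q+1) F(q+1) O(q+1) G(q+1)} per input element):
+            //     for (; i < size; i++, weight_xc_int8_IFOG += 8)
+            //         for (int g = 0; g < 8; g++)
+            //             IFOG[g] += (__fp16)(weight_xc_int8_IFOG[g] * descale_xc[g]) * x[i];
+            // i.e. dequantize eight int8 weights with the per-gate descales and
+            // accumulate against a broadcast x, which one vfmaq_f16 performs at once.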
for (; i < size; i++) + { + __fp16 xi = *x++; + + float16x8_t _xi = vdupq_n_f16(xi); + + float16x8_t _weight_xc_IFOG = vcvtq_f16_s16(vmovl_s8(vld1_s8(weight_xc_int8_IFOG))); + _weight_xc_IFOG = vmulq_f16(_weight_xc_IFOG, _descale_xc); + + _IFOG = vfmaq_f16(_IFOG, _weight_xc_IFOG, _xi); + + weight_xc_int8_IFOG += 8; + } + + const float* hidden_ptr = hidden_state; + + i = 0; + for (; i + 3 < num_output; i += 4) + { +#if NCNN_GNU_INLINE_ASM + asm volatile( + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "ld1 {v4.4s}, [%0], #16 \n" + "sxtl v0.8h, v6.8b \n" + "sxtl2 v1.8h, v6.16b \n" + "sxtl v2.8h, v7.8b \n" + "sxtl2 v3.8h, v7.16b \n" + "scvtf v0.8h, v0.8h \n" + "scvtf v1.8h, v1.8h \n" + "scvtf v2.8h, v2.8h \n" + "scvtf v3.8h, v3.8h \n" + "fcvtn v4.4h, v4.4s \n" + "fmul v0.8h, v0.8h, %12.8h \n" + "fmul v1.8h, v1.8h, %12.8h \n" + "fmul v2.8h, v2.8h, %12.8h \n" + "fmul v3.8h, v3.8h, %12.8h \n" + "fmla %2.8h, v0.8h, v4.h[0] \n" + "fmla %3.8h, v1.8h, v4.h[1] \n" + "fmla %4.8h, v2.8h, v4.h[2] \n" + "fmla %5.8h, v3.8h, v4.h[3] \n" + : "=r"(hidden_ptr), + "=r"(weight_hc_int8_IFOG), + "=w"(_IFOG), + "=w"(_sum1), + "=w"(_sum2), + "=w"(_sum3) + : "0"(hidden_ptr), + "1"(weight_hc_int8_IFOG), + "2"(_IFOG), + "3"(_sum1), + "4"(_sum2), + "5"(_sum3), + "w"(_descale_hc) + : "memory", "v0", "v1", "v2", "v3", "v4", "v6", "v7"); +#else // NCNN_GNU_INLINE_ASM + float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr)); + + int8x16_t _weight_hc_IFOG01 = vld1q_s8(weight_hc_int8_IFOG); + int8x16_t _weight_hc_IFOG23 = vld1q_s8(weight_hc_int8_IFOG + 16); + float16x8_t _w0 = vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_IFOG01))); + float16x8_t _w1 = vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_IFOG01))); + float16x8_t _w2 = vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_IFOG23))); + float16x8_t _w3 = vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_IFOG23))); + _w0 = vmulq_f16(_w0, _descale_hc); + _w1 = vmulq_f16(_w1, _descale_hc); + _w2 = vmulq_f16(_w2, _descale_hc); + _w3 = vmulq_f16(_w3, _descale_hc); + + _IFOG = vfmaq_lane_f16(_IFOG, _w0, _h_cont, 0); + _sum1 = vfmaq_lane_f16(_sum1, _w1, _h_cont, 1); + _sum2 = vfmaq_lane_f16(_sum2, _w2, _h_cont, 2); + _sum3 = vfmaq_lane_f16(_sum3, _w3, _h_cont, 3); + + hidden_ptr += 4; + weight_hc_int8_IFOG += 32; +#endif // NCNN_GNU_INLINE_ASM + } + for (; i < num_output; i++) + { + float h_cont = *hidden_ptr++; + + float16x8_t _h_cont = vdupq_n_f16((__fp16)h_cont); + + float16x8_t _weight_hc_IFOG = vcvtq_f16_s16(vmovl_s8(vld1_s8(weight_hc_int8_IFOG))); + _weight_hc_IFOG = vmulq_f16(_weight_hc_IFOG, _descale_hc); + + _IFOG = vfmaq_f16(_IFOG, _weight_hc_IFOG, _h_cont); + + weight_hc_int8_IFOG += 8; + } + + __fp16* gates_data = gates.row<__fp16>(q); + + _IFOG = vaddq_f16(_IFOG, _sum1); + _sum2 = vaddq_f16(_sum2, _sum3); + _IFOG = vaddq_f16(_IFOG, _sum2); + + vst1q_f16(gates_data, _IFOG); + } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_hidden_size_start; q < hidden_size; q++) + { + const __fp16* bias_c_IFOG = (const __fp16*)bias_c + q * 4; + + // gate I F O G + const signed char* weight_xc_int8_IFOG = weight_xc_int8.row(q / 2 + q % 2); + const signed char* weight_hc_int8_IFOG = weight_hc_int8.row(q / 2 + q % 2); + const __fp16* weight_xc_int8_descales_IFOG = weight_xc_int8_descales.row(q / 2 + q % 2); + const __fp16* weight_hc_int8_descales_IFOG = weight_hc_int8_descales.row(q / 2 + q % 2); + + float16x4_t _descale_xc = vld1_f16(weight_xc_int8_descales_IFOG); + float16x4_t _descale_hc = vld1_f16(weight_hc_int8_descales_IFOG); + float16x8_t _descale_xcxc = 
vcombine_f16(_descale_xc, _descale_xc); + float16x8_t _descale_hchc = vcombine_f16(_descale_hc, _descale_hc); + + float16x4_t _IFOG = vld1_f16(bias_c_IFOG); + float16x4_t _sum1 = vdup_n_f16((__fp16)0.f); + float16x4_t _sum2 = vdup_n_f16((__fp16)0.f); + float16x4_t _sum3 = vdup_n_f16((__fp16)0.f); + + const __fp16* x = bottom_blob.row(ti); + + int i = 0; + for (; i + 3 < size; i += 4) + { +#if NCNN_GNU_INLINE_ASM + asm volatile( + "ld1 {v5.16b}, [%1], #16 \n" + "ld1 {v4.4h}, [%0], #8 \n" + "sxtl v0.8h, v5.8b \n" + "sxtl2 v2.8h, v5.16b \n" + "scvtf v0.8h, v0.8h \n" + "scvtf v2.8h, v2.8h \n" + "fmul v0.8h, v0.8h, %12.8h \n" + "fmul v2.8h, v2.8h, %12.8h \n" + "mov v1.d[0], v0.d[1] \n" + "mov v3.d[0], v2.d[1] \n" + "fmla %2.4h, v0.4h, v4.h[0] \n" + "fmla %3.4h, v1.4h, v4.h[1] \n" + "fmla %4.4h, v2.4h, v4.h[2] \n" + "fmla %5.4h, v3.4h, v4.h[3] \n" + : "=r"(x), + "=r"(weight_xc_int8_IFOG), + "=w"(_IFOG), + "=w"(_sum1), + "=w"(_sum2), + "=w"(_sum3) + : "0"(x), + "1"(weight_xc_int8_IFOG), + "2"(_IFOG), + "3"(_sum1), + "4"(_sum2), + "5"(_sum3), + "w"(_descale_xcxc) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5"); +#else // NCNN_GNU_INLINE_ASM + float16x4_t _x = vld1_f16(x); + + int8x16_t _weight_xc_IFOG = vld1q_s8(weight_xc_int8_IFOG); + float16x8_t _w01 = vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_IFOG))); + float16x8_t _w23 = vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_IFOG))); + _w01 = vmulq_f16(_w01, _descale_xcxc); + _w23 = vmulq_f16(_w23, _descale_xcxc); + + _IFOG = vfma_lane_f16(_IFOG, vget_low_f16(_w01), _x, 0); + _sum1 = vfma_lane_f16(_sum1, vget_high_f16(_w01), _x, 1); + _sum2 = vfma_lane_f16(_sum2, vget_low_f16(_w23), _x, 2); + _sum3 = vfma_lane_f16(_sum3, vget_high_f16(_w23), _x, 3); + + x += 4; + weight_xc_int8_IFOG += 16; +#endif // NCNN_GNU_INLINE_ASM + } + for (; i < size; i++) + { + __fp16 xi = *x++; + + float16x4_t _xi = vdup_n_f16(xi); + + float16x4_t _weight_xc_IFOG = vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_IFOG)))); + _weight_xc_IFOG = vmul_f16(_weight_xc_IFOG, _descale_xc); + + _IFOG = vfma_f16(_IFOG, _weight_xc_IFOG, _xi); + + weight_xc_int8_IFOG += 4; + } + + const float* hidden_ptr = hidden_state; + + i = 0; + for (; i + 3 < num_output; i += 4) + { +#if NCNN_GNU_INLINE_ASM + asm volatile( + "ld1 {v5.16b}, [%1], #16 \n" + "ld1 {v4.4s}, [%0], #16 \n" + "sxtl v0.8h, v5.8b \n" + "sxtl2 v2.8h, v5.16b \n" + "scvtf v0.8h, v0.8h \n" + "scvtf v2.8h, v2.8h \n" + "fcvtn v4.4h, v4.4s \n" + "fmul v0.8h, v0.8h, %12.8h \n" + "fmul v2.8h, v2.8h, %12.8h \n" + "mov v1.d[0], v0.d[1] \n" + "mov v3.d[0], v2.d[1] \n" + "fmla %2.4h, v0.4h, v4.h[0] \n" + "fmla %3.4h, v1.4h, v4.h[1] \n" + "fmla %4.4h, v2.4h, v4.h[2] \n" + "fmla %5.4h, v3.4h, v4.h[3] \n" + : "=r"(hidden_ptr), + "=r"(weight_hc_int8_IFOG), + "=w"(_IFOG), + "=w"(_sum1), + "=w"(_sum2), + "=w"(_sum3) + : "0"(hidden_ptr), + "1"(weight_hc_int8_IFOG), + "2"(_IFOG), + "3"(_sum1), + "4"(_sum2), + "5"(_sum3), + "w"(_descale_hchc) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5"); +#else // NCNN_GNU_INLINE_ASM + float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr)); + + int8x16_t _weight_hc_IFOG = vld1q_s8(weight_hc_int8_IFOG); + float16x8_t _w01 = vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_IFOG))); + float16x8_t _w23 = vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_IFOG))); + _w01 = vmulq_f16(_w01, _descale_hchc); + _w23 = vmulq_f16(_w23, _descale_hchc); + + _IFOG = vfma_lane_f16(_IFOG, vget_low_f16(_w01), _h_cont, 0); + _sum1 = vfma_lane_f16(_sum1, vget_high_f16(_w01), _h_cont, 1); + _sum2 = 
vfma_lane_f16(_sum2, vget_low_f16(_w23), _h_cont, 2); + _sum3 = vfma_lane_f16(_sum3, vget_high_f16(_w23), _h_cont, 3); + + hidden_ptr += 4; + weight_hc_int8_IFOG += 16; +#endif // NCNN_GNU_INLINE_ASM + } + for (; i < num_output; i++) + { + float h_cont = *hidden_ptr++; + + float16x4_t _h_cont = vdup_n_f16((__fp16)h_cont); + + float16x4_t _weight_hc_IFOG = vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_IFOG)))); + _weight_hc_IFOG = vmul_f16(_weight_hc_IFOG, _descale_hc); + + _IFOG = vfma_f16(_IFOG, _weight_hc_IFOG, _h_cont); + + weight_hc_int8_IFOG += 4; + } + + __fp16* gates_data = gates.row<__fp16>(q); + + _IFOG = vadd_f16(_IFOG, _sum1); + _sum2 = vadd_f16(_sum2, _sum3); + _IFOG = vadd_f16(_IFOG, _sum2); + + vst1_f16(gates_data, _IFOG); + } + + // lstm unit + // sigmoid(I) + // sigmoid(F) + // sigmoid(O) + // tanh(G) + // c_t := f_t .* c_{t-1} + i_t .* g_t + // h_t := o_t .* tanh[c_t] + __fp16* output_data = top_blob.row<__fp16>(ti); + + float* cell_ptr = cell_state; + float* hidden_ptr = hidden_state; + float* tmp_hidden_ptr = tmp_hidden_state; + + nn_hidden_size = hidden_size >> 2; + remain_hidden_size_start = nn_hidden_size << 2; + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_hidden_size; qq++) + { + int q = qq * 4; + + const __fp16* gates_data = gates.row(q); + + float16x4x4_t _IFOG_4x4 = vld4_f16(gates_data); + + float32x4_t _lstm_I = sigmoid_ps(vcvt_f32_f16(_IFOG_4x4.val[0])); + float32x4_t _lstm_F = sigmoid_ps(vcvt_f32_f16(_IFOG_4x4.val[1])); + float32x4_t _lstm_O = sigmoid_ps(vcvt_f32_f16(_IFOG_4x4.val[2])); + float32x4_t _lstm_G = tanh_ps(vcvt_f32_f16(_IFOG_4x4.val[3])); + + float32x4_t _cell2 = vaddq_f32(vmulq_f32(_lstm_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_lstm_I, _lstm_G)); + float32x4_t _lstm_H = vmulq_f32(_lstm_O, tanh_ps(_cell2)); + + vst1q_f32(cell_ptr + q, _cell2); + + if (num_output == hidden_size) + { + vst1q_f32(hidden_ptr + q, _lstm_H); + vst1_f16(output_data + q, vcvt_f16_f32(_lstm_H)); + } + else + { + vst1q_f32(tmp_hidden_ptr + q, _lstm_H); + } + } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_hidden_size_start; q < hidden_size; q++) + { + const __fp16* gates_data = gates.row(q); + + float I = (float)gates_data[0]; + float F = (float)gates_data[1]; + float O = (float)gates_data[2]; + float G = (float)gates_data[3]; + + I = 1.f / (1.f + expf(-I)); + F = 1.f / (1.f + expf(-F)); + O = 1.f / (1.f + expf(-O)); + G = tanhf(G); + + float cell2 = F * cell_ptr[q] + I * G; + float H = O * tanhf(cell2); + + cell_ptr[q] = cell2; + if (num_output == hidden_size) + { + hidden_ptr[q] = H; + output_data[q] = (__fp16)H; + } + else + { + tmp_hidden_ptr[q] = H; + } + } + + if (num_output != hidden_size) + { + // int nn_num_output = num_output >> 2; + // int remain_num_output_start = nn_num_output << 2; + // #pragma omp parallel for num_threads(opt.num_threads) + // for (int qq = 0; qq < nn_num_output; qq++) + // { + // int q = qq * 4; + // + // } + int remain_num_output_start = 0; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + const float* tmp_hidden_ptr = tmp_hidden_state; + + float H = 0; + for (int i = 0; i < hidden_size; i++) + { + H += tmp_hidden_ptr[i] * hr[i]; + } + + hidden_ptr[q] = H; + output_data[q] = (__fp16)H; + } + } + } + + return 0; +} +#endif // NCNN_INT8 + int LSTM_arm::create_pipeline_fp16s(const Option& opt) { +#if NCNN_INT8 + if (int8_scale_term) + { + // pack IFOG + 
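+    // Packed layout sketch (the descales below are reciprocals of the per-output-
+    // channel int8 scales, so the kernels dequantize on the fly instead of
+    // de-scaling whole weight matrices up front):
+    //   fp16 arithmetic: two output channels share one row, bytes interleaved as
+    //     {I(q) F(q) O(q) G(q) I(q+1) F(q+1) O(q+1) G(q+1)} per input element,
+    //     with the eight matching __fp16 descales stored per row;
+    //   fp16 storage only: one output channel per row, {I F O G} per input
+    //     element, with four float descales per row.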
const int num_directions = direction == 2 ? 2 : 1; + const int size = weight_data_size / num_directions / hidden_size / 4; + + if (opt.use_fp16_arithmetic) + { + weight_xc_data_packed.create(size, hidden_size / 2 + hidden_size % 2, num_directions, 8u, 8); + bias_c_data_packed.create(hidden_size, 1, num_directions, 8u, 4); + weight_hc_data_packed.create(num_output, hidden_size / 2 + hidden_size % 2, num_directions, 8u, 8); + weight_xc_data_int8_descales_packed.create(8, hidden_size / 2 + hidden_size % 2, num_directions, 2u, 1); + weight_hc_data_int8_descales_packed.create(8, hidden_size / 2 + hidden_size % 2, num_directions, 2u, 1); + } + else + { + weight_xc_data_packed.create(size, hidden_size, num_directions, 4u, 4); + bias_c_data_packed.create(hidden_size, 1, num_directions, 8u, 4); + weight_hc_data_packed.create(num_output, hidden_size, num_directions, 4u, 4); + weight_xc_data_int8_descales_packed.create(4, hidden_size, num_directions, 4u, 1); + weight_hc_data_int8_descales_packed.create(4, hidden_size, num_directions, 4u, 1); + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int dr = 0; dr < num_directions; dr++) + { + const Mat weight_xc = weight_xc_data.channel(dr); + const Mat bias_c = bias_c_data.channel(dr); + const Mat weight_hc = weight_hc_data.channel(dr); + const float* weight_xc_int8_scales = weight_xc_data_int8_scales.row(dr); + const float* weight_hc_int8_scales = weight_hc_data_int8_scales.row(dr); + + Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr); + Mat bias_c_data_packed_dr = bias_c_data_packed.channel(dr); + Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr); + Mat weight_xc_data_int8_descales_packed_dr = weight_xc_data_int8_descales_packed.channel(dr); + Mat weight_hc_data_int8_descales_packed_dr = weight_hc_data_int8_descales_packed.channel(dr); + + const float* bias_c_I = bias_c.row(0); + const float* bias_c_F = bias_c.row(1); + const float* bias_c_O = bias_c.row(2); + const float* bias_c_G = bias_c.row(3); + + __fp16* bias_c_IFOG = bias_c_data_packed_dr.row<__fp16>(0); + + int q = 0; + if (opt.use_fp16_arithmetic) + { + for (; q + 1 < hidden_size; q += 2) + { + bias_c_IFOG[0] = (__fp16)bias_c_I[q]; + bias_c_IFOG[1] = (__fp16)bias_c_F[q]; + bias_c_IFOG[2] = (__fp16)bias_c_O[q]; + bias_c_IFOG[3] = (__fp16)bias_c_G[q]; + bias_c_IFOG[4] = (__fp16)bias_c_I[q + 1]; + bias_c_IFOG[5] = (__fp16)bias_c_F[q + 1]; + bias_c_IFOG[6] = (__fp16)bias_c_O[q + 1]; + bias_c_IFOG[7] = (__fp16)bias_c_G[q + 1]; + + bias_c_IFOG += 8; + + const signed char* weight_xc_I = weight_xc.row(hidden_size * 0 + q); + const signed char* weight_xc_F = weight_xc.row(hidden_size * 1 + q); + const signed char* weight_xc_O = weight_xc.row(hidden_size * 2 + q); + const signed char* weight_xc_G = weight_xc.row(hidden_size * 3 + q); + const signed char* weight_xc_I_1 = weight_xc.row(hidden_size * 0 + q + 1); + const signed char* weight_xc_F_1 = weight_xc.row(hidden_size * 1 + q + 1); + const signed char* weight_xc_O_1 = weight_xc.row(hidden_size * 2 + q + 1); + const signed char* weight_xc_G_1 = weight_xc.row(hidden_size * 3 + q + 1); + + const signed char* weight_hc_I = weight_hc.row(hidden_size * 0 + q); + const signed char* weight_hc_F = weight_hc.row(hidden_size * 1 + q); + const signed char* weight_hc_O = weight_hc.row(hidden_size * 2 + q); + const signed char* weight_hc_G = weight_hc.row(hidden_size * 3 + q); + const signed char* weight_hc_I_1 = weight_hc.row(hidden_size * 0 + q + 1); + const signed char* weight_hc_F_1 = weight_hc.row(hidden_size * 1 
+ q + 1); + const signed char* weight_hc_O_1 = weight_hc.row(hidden_size * 2 + q + 1); + const signed char* weight_hc_G_1 = weight_hc.row(hidden_size * 3 + q + 1); + + signed char* weight_xc_IFOG = weight_xc_data_packed_dr.row(q / 2); + signed char* weight_hc_IFOG = weight_hc_data_packed_dr.row(q / 2); + __fp16* weight_xc_int8_descales_ptr = weight_xc_data_int8_descales_packed_dr.row<__fp16>(q / 2); + __fp16* weight_hc_int8_descales_ptr = weight_hc_data_int8_descales_packed_dr.row<__fp16>(q / 2); + + for (int i = 0; i < size; i++) + { + weight_xc_IFOG[0] = weight_xc_I[i]; + weight_xc_IFOG[1] = weight_xc_F[i]; + weight_xc_IFOG[2] = weight_xc_O[i]; + weight_xc_IFOG[3] = weight_xc_G[i]; + weight_xc_IFOG[4] = weight_xc_I_1[i]; + weight_xc_IFOG[5] = weight_xc_F_1[i]; + weight_xc_IFOG[6] = weight_xc_O_1[i]; + weight_xc_IFOG[7] = weight_xc_G_1[i]; + + weight_xc_IFOG += 8; + } + + for (int i = 0; i < num_output; i++) + { + weight_hc_IFOG[0] = weight_hc_I[i]; + weight_hc_IFOG[1] = weight_hc_F[i]; + weight_hc_IFOG[2] = weight_hc_O[i]; + weight_hc_IFOG[3] = weight_hc_G[i]; + weight_hc_IFOG[4] = weight_hc_I_1[i]; + weight_hc_IFOG[5] = weight_hc_F_1[i]; + weight_hc_IFOG[6] = weight_hc_O_1[i]; + weight_hc_IFOG[7] = weight_hc_G_1[i]; + + weight_hc_IFOG += 8; + } + + weight_xc_int8_descales_ptr[0] = (__fp16)(1.f / weight_xc_int8_scales[hidden_size * 0 + q]); + weight_xc_int8_descales_ptr[1] = (__fp16)(1.f / weight_xc_int8_scales[hidden_size * 1 + q]); + weight_xc_int8_descales_ptr[2] = (__fp16)(1.f / weight_xc_int8_scales[hidden_size * 2 + q]); + weight_xc_int8_descales_ptr[3] = (__fp16)(1.f / weight_xc_int8_scales[hidden_size * 3 + q]); + weight_xc_int8_descales_ptr[4] = (__fp16)(1.f / weight_xc_int8_scales[hidden_size * 0 + q + 1]); + weight_xc_int8_descales_ptr[5] = (__fp16)(1.f / weight_xc_int8_scales[hidden_size * 1 + q + 1]); + weight_xc_int8_descales_ptr[6] = (__fp16)(1.f / weight_xc_int8_scales[hidden_size * 2 + q + 1]); + weight_xc_int8_descales_ptr[7] = (__fp16)(1.f / weight_xc_int8_scales[hidden_size * 3 + q + 1]); + + weight_hc_int8_descales_ptr[0] = (__fp16)(1.f / weight_hc_int8_scales[hidden_size * 0 + q]); + weight_hc_int8_descales_ptr[1] = (__fp16)(1.f / weight_hc_int8_scales[hidden_size * 1 + q]); + weight_hc_int8_descales_ptr[2] = (__fp16)(1.f / weight_hc_int8_scales[hidden_size * 2 + q]); + weight_hc_int8_descales_ptr[3] = (__fp16)(1.f / weight_hc_int8_scales[hidden_size * 3 + q]); + weight_hc_int8_descales_ptr[4] = (__fp16)(1.f / weight_hc_int8_scales[hidden_size * 0 + q + 1]); + weight_hc_int8_descales_ptr[5] = (__fp16)(1.f / weight_hc_int8_scales[hidden_size * 1 + q + 1]); + weight_hc_int8_descales_ptr[6] = (__fp16)(1.f / weight_hc_int8_scales[hidden_size * 2 + q + 1]); + weight_hc_int8_descales_ptr[7] = (__fp16)(1.f / weight_hc_int8_scales[hidden_size * 3 + q + 1]); + } + } + for (; q < hidden_size; q++) + { + bias_c_IFOG[0] = (__fp16)bias_c_I[q]; + bias_c_IFOG[1] = (__fp16)bias_c_F[q]; + bias_c_IFOG[2] = (__fp16)bias_c_O[q]; + bias_c_IFOG[3] = (__fp16)bias_c_G[q]; + + bias_c_IFOG += 4; + + const signed char* weight_xc_I = weight_xc.row(hidden_size * 0 + q); + const signed char* weight_xc_F = weight_xc.row(hidden_size * 1 + q); + const signed char* weight_xc_O = weight_xc.row(hidden_size * 2 + q); + const signed char* weight_xc_G = weight_xc.row(hidden_size * 3 + q); + + const signed char* weight_hc_I = weight_hc.row(hidden_size * 0 + q); + const signed char* weight_hc_F = weight_hc.row(hidden_size * 1 + q); + const signed char* weight_hc_O = weight_hc.row(hidden_size * 2 + q); + 
const signed char* weight_hc_G = weight_hc.row(hidden_size * 3 + q); + + const int qq = opt.use_fp16_arithmetic ? q / 2 + q % 2 : q; + signed char* weight_xc_IFOG = weight_xc_data_packed_dr.row(qq); + signed char* weight_hc_IFOG = weight_hc_data_packed_dr.row(qq); + + for (int i = 0; i < size; i++) + { + weight_xc_IFOG[0] = weight_xc_I[i]; + weight_xc_IFOG[1] = weight_xc_F[i]; + weight_xc_IFOG[2] = weight_xc_O[i]; + weight_xc_IFOG[3] = weight_xc_G[i]; + + weight_xc_IFOG += 4; + } + + for (int i = 0; i < num_output; i++) + { + weight_hc_IFOG[0] = weight_hc_I[i]; + weight_hc_IFOG[1] = weight_hc_F[i]; + weight_hc_IFOG[2] = weight_hc_O[i]; + weight_hc_IFOG[3] = weight_hc_G[i]; + + weight_hc_IFOG += 4; + } + + if (opt.use_fp16_arithmetic) + { + __fp16* weight_xc_int8_descales_ptr = weight_xc_data_int8_descales_packed_dr.row<__fp16>(qq); + __fp16* weight_hc_int8_descales_ptr = weight_hc_data_int8_descales_packed_dr.row<__fp16>(qq); + + weight_xc_int8_descales_ptr[0] = (__fp16)(1.f / weight_xc_int8_scales[hidden_size * 0 + q]); + weight_xc_int8_descales_ptr[1] = (__fp16)(1.f / weight_xc_int8_scales[hidden_size * 1 + q]); + weight_xc_int8_descales_ptr[2] = (__fp16)(1.f / weight_xc_int8_scales[hidden_size * 2 + q]); + weight_xc_int8_descales_ptr[3] = (__fp16)(1.f / weight_xc_int8_scales[hidden_size * 3 + q]); + + weight_hc_int8_descales_ptr[0] = (__fp16)(1.f / weight_hc_int8_scales[hidden_size * 0 + q]); + weight_hc_int8_descales_ptr[1] = (__fp16)(1.f / weight_hc_int8_scales[hidden_size * 1 + q]); + weight_hc_int8_descales_ptr[2] = (__fp16)(1.f / weight_hc_int8_scales[hidden_size * 2 + q]); + weight_hc_int8_descales_ptr[3] = (__fp16)(1.f / weight_hc_int8_scales[hidden_size * 3 + q]); + } + else + { + float* weight_xc_int8_descales_ptr = weight_xc_data_int8_descales_packed_dr.row(qq); + float* weight_hc_int8_descales_ptr = weight_hc_data_int8_descales_packed_dr.row(qq); + + weight_xc_int8_descales_ptr[0] = 1.f / weight_xc_int8_scales[hidden_size * 0 + q]; + weight_xc_int8_descales_ptr[1] = 1.f / weight_xc_int8_scales[hidden_size * 1 + q]; + weight_xc_int8_descales_ptr[2] = 1.f / weight_xc_int8_scales[hidden_size * 2 + q]; + weight_xc_int8_descales_ptr[3] = 1.f / weight_xc_int8_scales[hidden_size * 3 + q]; + + weight_hc_int8_descales_ptr[0] = 1.f / weight_hc_int8_scales[hidden_size * 0 + q]; + weight_hc_int8_descales_ptr[1] = 1.f / weight_hc_int8_scales[hidden_size * 1 + q]; + weight_hc_int8_descales_ptr[2] = 1.f / weight_hc_int8_scales[hidden_size * 2 + q]; + weight_hc_int8_descales_ptr[3] = 1.f / weight_hc_int8_scales[hidden_size * 3 + q]; + } + } + } + + if (opt.lightmode) + { + weight_xc_data.release(); + bias_c_data.release(); + weight_hc_data.release(); + weight_xc_data_int8_scales.release(); + weight_hc_data_int8_scales.release(); + } + + return 0; + } +#endif + // pack IFOG - int num_directions = direction == 2 ? 2 : 1; - int size = weight_data_size / num_directions / hidden_size / 4; + const int num_directions = direction == 2 ? 2 : 1; + const int size = weight_data_size / num_directions / hidden_size / 4; if (opt.use_fp16_arithmetic) { @@ -869,9 +1820,20 @@ int LSTM_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& // Uni directional if (direction == 0 || direction == 1) { - int ret = lstm_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_fp16s_int8(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -884,16 +1846,38 @@ int LSTM_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& if (top_blob_reverse.empty()) return -100; - int ret0 = lstm_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_fp16s_int8(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } hidden.fill(0.f); cell.fill(0.f); - int ret1 = lstm_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_fp16s_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), weight_xc_data_int8_descales_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), weight_hc_data_int8_descales_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -947,9 +1931,20 @@ int LSTM_arm::forward_fp16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector