From 9275f170e5793338a8a829adc40ad183ba550947 Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 18 Apr 2024 19:54:55 +0800 Subject: [PATCH] rnn int8 kernel --- src/layer/arm/gru_arm.cpp | 2 +- src/layer/arm/gru_arm_asimdhp.cpp | 6 +- src/layer/arm/rnn_arm.cpp | 755 +++++++++++++++++++++++++-- src/layer/arm/rnn_arm.h | 8 + src/layer/arm/rnn_arm_asimdhp.cpp | 835 ++++++++++++++++++++++++++++-- src/layer/rnn.cpp | 189 +++++-- tests/test_rnn.cpp | 429 ++++++++++++--- 7 files changed, 2013 insertions(+), 211 deletions(-) diff --git a/src/layer/arm/gru_arm.cpp b/src/layer/arm/gru_arm.cpp index 8184b6095da..db26e9babfd 100644 --- a/src/layer/arm/gru_arm.cpp +++ b/src/layer/arm/gru_arm.cpp @@ -2376,7 +2376,7 @@ int GRU_arm::create_pipeline_bf16s(const Option& opt) create_pipeline_int8(opt); ncnn::Mat tmp; - ncnn::cast_float32_to_bfloat16(bias_c_data_packed, tmp, opt); + cast_float32_to_bfloat16(bias_c_data_packed, tmp, opt); bias_c_data_packed = tmp; return 0; diff --git a/src/layer/arm/gru_arm_asimdhp.cpp b/src/layer/arm/gru_arm_asimdhp.cpp index cbb69870ee4..0d9317daf2c 100644 --- a/src/layer/arm/gru_arm_asimdhp.cpp +++ b/src/layer/arm/gru_arm_asimdhp.cpp @@ -1594,21 +1594,21 @@ int GRU_arm::create_pipeline_fp16s(const Option& opt) { ncnn::Mat tmp; - ncnn::cast_float32_to_float16(bias_c_data_packed, tmp, opt); + cast_float32_to_float16(bias_c_data_packed, tmp, opt); bias_c_data_packed = tmp; } if (opt.use_fp16_arithmetic) { ncnn::Mat tmp; - ncnn::cast_float32_to_float16(weight_xc_data_int8_descales_packed, tmp, opt); + cast_float32_to_float16(weight_xc_data_int8_descales_packed, tmp, opt); weight_xc_data_int8_descales_packed = tmp; } if (opt.use_fp16_arithmetic) { ncnn::Mat tmp; - ncnn::cast_float32_to_float16(weight_hc_data_int8_descales_packed, tmp, opt); + cast_float32_to_float16(weight_hc_data_int8_descales_packed, tmp, opt); weight_hc_data_int8_descales_packed = tmp; } diff --git a/src/layer/arm/rnn_arm.cpp b/src/layer/arm/rnn_arm.cpp index 293322b8488..d2e3c3e8779 100644 --- a/src/layer/arm/rnn_arm.cpp +++ b/src/layer/arm/rnn_arm.cpp @@ -54,12 +54,23 @@ int RNN_arm::create_pipeline(const Option& opt) } #endif - int num_directions = direction == 2 ? 2 : 1; - int size = weight_data_size / num_directions / num_output; +#if NCNN_INT8 + if (int8_scale_term) + { + return create_pipeline_int8(opt); + } +#endif + + const int num_directions = direction == 2 ? 2 : 1; + const int size = weight_data_size / num_directions / num_output; #if __ARM_NEON weight_xc_data_packed.create(size * 4, num_output / 4 + num_output % 4, num_directions); weight_hc_data_packed.create(num_output * 4, num_output / 4 + num_output % 4, num_directions); +#else + weight_xc_data_packed.create(size, num_output, num_directions); + weight_hc_data_packed.create(num_output, num_output, num_directions); +#endif #pragma omp parallel for num_threads(opt.num_threads) for (int dr = 0; dr < num_directions; dr++) @@ -132,10 +143,6 @@ int RNN_arm::create_pipeline(const Option& opt) } } } -#else - weight_xc_data_packed = weight_xc_data; - weight_hc_data_packed = weight_hc_data; -#endif bias_c_data_packed = bias_c_data; @@ -317,6 +324,332 @@ static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we return 0; } +#if NCNN_INT8 +static int rnn_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, Mat& hidden_state, const Option& opt) +{ + int size = bottom_blob.w; + int T = bottom_blob.h; + + int num_output = top_blob.w; + + // num_output + Mat gates(num_output, 4u, opt.workspace_allocator); + if (gates.empty()) + return -100; + + // unroll + for (int t = 0; t < T; t++) + { + int ti = reverse ? T - 1 - t : t; + + const float* x = bottom_blob.row(ti); + + int remain_num_output_start = 0; +#if __ARM_NEON + int nn_num_output = num_output >> 2; + remain_num_output_start = nn_num_output << 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_num_output; qq++) + { + int q = qq * 4; + + const signed char* weight_xc_int8_ptr = weight_xc_int8.row(q / 4); + const signed char* weight_hc_int8_ptr = weight_hc_int8.row(q / 4); + + const float* weight_xc_int8_descales_ptr = weight_xc_int8_descales.row(q / 4); + const float* weight_hc_int8_descales_ptr = weight_hc_int8_descales.row(q / 4); + + float32x4_t _descale_xc = vld1q_f32(weight_xc_int8_descales_ptr); + float32x4_t _descale_hc = vld1q_f32(weight_hc_int8_descales_ptr); + + float32x4_t _rnn_H = vld1q_f32((const float*)bias_c + q); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + + int i = 0; + for (; i + 3 < size; i += 4) + { + float32x4_t _x = vld1q_f32(x + i); + + int8x16_t _weight_xc = vld1q_s8(weight_xc_int8_ptr); + int16x8_t _weight_xc_01 = vmovl_s8(vget_low_s8(_weight_xc)); + int16x8_t _weight_xc_23 = vmovl_s8(vget_high_s8(_weight_xc)); + float32x4_t _weight_xc_0 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_01))), _descale_xc); + float32x4_t _weight_xc_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_01))), _descale_xc); + float32x4_t _weight_xc_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_23))), _descale_xc); + float32x4_t _weight_xc_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_23))), _descale_xc); + +#if __aarch64__ + _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_xc_0, _x, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_1, _x, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_2, _x, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_3, _x, 3); +#else + _rnn_H = vmlaq_lane_f32(_rnn_H, _weight_xc_0, vget_low_f32(_x), 0); + _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_1, vget_low_f32(_x), 1); + _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_2, vget_high_f32(_x), 0); + _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_3, vget_high_f32(_x), 1); +#endif + + weight_xc_int8_ptr += 16; + } + for (; i < size; i++) + { + float32x4_t _x = vdupq_n_f32(x[i]); + float32x4_t _weight_xc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_ptr))))), _descale_xc); + _rnn_H = vmlaq_f32(_rnn_H, _weight_xc, _x); + + weight_xc_int8_ptr += 4; + } + + i = 0; + for (; i + 3 < num_output; i += 4) + { + float32x4_t _hidden_state = vld1q_f32((const float*)hidden_state + i); + + int8x16_t _weight_hc = vld1q_s8(weight_hc_int8_ptr); + int16x8_t _weight_hc_01 = vmovl_s8(vget_low_s8(_weight_hc)); + int16x8_t _weight_hc_23 = vmovl_s8(vget_high_s8(_weight_hc)); + float32x4_t _weight_hc_0 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_01))), _descale_hc); + float32x4_t _weight_hc_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_01))), _descale_hc); + float32x4_t _weight_hc_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_23))), _descale_hc); + float32x4_t _weight_hc_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_23))), _descale_hc); + +#if __aarch64__ + _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_hc_0, _hidden_state, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_1, _hidden_state, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_2, _hidden_state, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_3, _hidden_state, 3); +#else + _rnn_H = vmlaq_lane_f32(_rnn_H, _weight_hc_0, vget_low_f32(_hidden_state), 0); + _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_1, vget_low_f32(_hidden_state), 1); + _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_2, vget_high_f32(_hidden_state), 0); + _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_3, vget_high_f32(_hidden_state), 1); +#endif + + weight_hc_int8_ptr += 16; + } + for (; i < num_output; i++) + { + float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]); + float32x4_t _weight_hc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_ptr))))), _descale_hc); + _rnn_H = vmlaq_f32(_rnn_H, _weight_hc, _hidden_state); + + weight_hc_int8_ptr += 4; + } + + _rnn_H = vaddq_f32(_rnn_H, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + _rnn_H = vaddq_f32(_rnn_H, _sum2); + + _rnn_H = tanh_ps(_rnn_H); + + vst1q_f32((float*)gates + q, _rnn_H); + } +#endif // __ARM_NEON + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { +#if __ARM_NEON + const signed char* weight_xc_int8_ptr = weight_xc_int8.row(q / 4 + q % 4); + const signed char* weight_hc_int8_ptr = weight_hc_int8.row(q / 4 + q % 4); + const float* weight_xc_int8_descales_ptr = weight_xc_int8_descales.row(q / 4 + q % 4); + const float* weight_hc_int8_descales_ptr = weight_hc_int8_descales.row(q / 4 + q % 4); +#else + const signed char* weight_xc_int8_ptr = weight_xc_int8.row(q); + const signed char* weight_hc_int8_ptr = weight_hc_int8.row(q); + const float* weight_xc_int8_descales_ptr = weight_xc_int8_descales.row(q); + const float* weight_hc_int8_descales_ptr = weight_hc_int8_descales.row(q); +#endif // __ARM_NEON + + const float descale_xc = weight_xc_int8_descales_ptr[0]; + const float descale_hc = weight_hc_int8_descales_ptr[0]; + + float H = bias_c[q]; + + for (int i = 0; i < size; i++) + { + H += weight_xc_int8_ptr[i] * descale_xc * x[i]; + } + + for (int i = 0; i < num_output; i++) + { + H += weight_hc_int8_ptr[i] * descale_hc * hidden_state[i]; + } + + H = tanhf(H); + + gates[q] = H; + } + + float* output_data = top_blob.row(ti); + + float* hidden_ptr = hidden_state; + +#if __ARM_NEON + nn_num_output = num_output >> 2; + remain_num_output_start = nn_num_output << 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_num_output; qq++) + { + int q = qq * 4; + + float32x4_t _rnn_H = vld1q_f32((float*)gates + q); + + vst1q_f32(hidden_ptr + q, _rnn_H); + vst1q_f32(output_data + q, _rnn_H); + } +#endif // __ARM_NEON + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + float H = gates[q]; + + hidden_ptr[q] = H; + output_data[q] = H; + } + } + + return 0; +} + +int RNN_arm::create_pipeline_int8(const Option& opt) +{ + const int num_directions = direction == 2 ? 2 : 1; + const int size = weight_data_size / num_directions / num_output; + +#if __ARM_NEON + weight_xc_data_packed.create(size * 4, num_output / 4 + num_output % 4, num_directions, 1u, 1); + weight_hc_data_packed.create(num_output * 4, num_output / 4 + num_output % 4, num_directions, 1u, 1); + weight_xc_data_int8_descales_packed.create(4, num_output / 4 + num_output % 4, num_directions); + weight_hc_data_int8_descales_packed.create(4, num_output / 4 + num_output % 4, num_directions); +#else + weight_xc_data_packed.create(size, num_output, num_directions); + weight_hc_data_packed.create(num_output, num_output, num_directions); + weight_xc_data_int8_descales_packed.create(1, num_output, num_directions); + weight_hc_data_int8_descales_packed.create(1, num_output, num_directions); +#endif + + #pragma omp parallel for num_threads(opt.num_threads) + for (int dr = 0; dr < num_directions; dr++) + { + const Mat weight_xc = weight_xc_data.channel(dr); + const Mat weight_hc = weight_hc_data.channel(dr); + const float* weight_xc_int8_scales = weight_xc_data_int8_scales.row(dr); + const float* weight_hc_int8_scales = weight_hc_data_int8_scales.row(dr); + + Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr); + Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr); + Mat weight_xc_data_int8_descales_packed_dr = weight_xc_data_int8_descales_packed.channel(dr); + Mat weight_hc_data_int8_descales_packed_dr = weight_hc_data_int8_descales_packed.channel(dr); + + int q = 0; +#if __ARM_NEON + for (; q + 3 < num_output; q += 4) + { + const signed char* weight_xc_0 = weight_xc.row(q); + const signed char* weight_xc_1 = weight_xc.row(q + 1); + const signed char* weight_xc_2 = weight_xc.row(q + 2); + const signed char* weight_xc_3 = weight_xc.row(q + 3); + + const signed char* weight_hc_0 = weight_hc.row(q); + const signed char* weight_hc_1 = weight_hc.row(q + 1); + const signed char* weight_hc_2 = weight_hc.row(q + 2); + const signed char* weight_hc_3 = weight_hc.row(q + 3); + + signed char* weight_xc_ptr = weight_xc_data_packed_dr.row(q / 4); + signed char* weight_hc_ptr = weight_hc_data_packed_dr.row(q / 4); + float* weight_xc_int8_descales_ptr = weight_xc_data_int8_descales_packed_dr.row(q / 4); + float* weight_hc_int8_descales_ptr = weight_hc_data_int8_descales_packed_dr.row(q / 4); + + for (int i = 0; i < size; i++) + { + weight_xc_ptr[0] = weight_xc_0[i]; + weight_xc_ptr[1] = weight_xc_1[i]; + weight_xc_ptr[2] = weight_xc_2[i]; + weight_xc_ptr[3] = weight_xc_3[i]; + + weight_xc_ptr += 4; + } + + for (int i = 0; i < num_output; i++) + { + weight_hc_ptr[0] = weight_hc_0[i]; + weight_hc_ptr[1] = weight_hc_1[i]; + weight_hc_ptr[2] = weight_hc_2[i]; + weight_hc_ptr[3] = weight_hc_3[i]; + + weight_hc_ptr += 4; + } + + float32x4_t _xc = vld1q_f32(weight_xc_int8_scales + q); + float32x4_t _hc = vld1q_f32(weight_hc_int8_scales + q); + +#if __aarch64__ + float32x4_t _one = vdupq_n_f32(1.f); + float32x4_t _reciprocal_xc = vdivq_f32(_one, _xc); + float32x4_t _reciprocal_hc = vdivq_f32(_one, _hc); +#else + float32x4_t _reciprocal_xc = vrecpeq_f32(_xc); + _reciprocal_xc = vmulq_f32(vrecpsq_f32(_xc, _reciprocal_xc_R), _reciprocal_xc_R); + float32x4_t _reciprocal_hc = vrecpeq_f32(_hc); + _reciprocal_hc = vmulq_f32(vrecpsq_f32(_hc, _reciprocal_hc_R), _reciprocal_hc_R); +#endif + + vst1q_f32(weight_xc_int8_descales_ptr, _reciprocal_xc); + vst1q_f32(weight_hc_int8_descales_ptr, _reciprocal_hc); + } +#endif // __ARM_NEON + for (; q < num_output; q++) + { + const float* weight_xc_0 = weight_xc.row(q); + const float* weight_hc_0 = weight_hc.row(q); + +#if __ARM_NEON + float* weight_xc_ptr = weight_xc_data_packed_dr.row(q / 4 + q % 4); + float* weight_hc_ptr = weight_hc_data_packed_dr.row(q / 4 + q % 4); + float* weight_xc_int8_descales_ptr = weight_xc_data_int8_descales_packed_dr.row(q / 4 + q % 4); + float* weight_hc_int8_descales_ptr = weight_hc_data_int8_descales_packed_dr.row(q / 4 + q % 4); +#else + float* weight_xc_ptr = weight_xc_data_packed_dr.row(q); + float* weight_hc_ptr = weight_hc_data_packed_dr.row(q); + float* weight_xc_int8_descales_ptr = weight_xc_data_int8_descales_packed_dr.row(q); + float* weight_hc_int8_descales_ptr = weight_hc_data_int8_descales_packed_dr.row(q); +#endif // __ARM_NEON + + for (int i = 0; i < size; i++) + { + weight_xc_ptr[i] = weight_xc_0[i]; + } + + for (int i = 0; i < num_output; i++) + { + weight_hc_ptr[i] = weight_hc_0[i]; + } + + weight_xc_int8_descales_ptr[0] = 1.f / weight_xc_int8_scales[q]; + weight_hc_int8_descales_ptr[0] = 1.f / weight_hc_int8_scales[q]; + } + } + + bias_c_data_packed = bias_c_data; + + if (opt.lightmode) + { + weight_xc_data.release(); + bias_c_data.release(); + weight_hc_data.release(); + weight_xc_data_int8_scales.release(); + weight_hc_data_int8_scales.release(); + } + + return 0; +} +#endif // NCNN_INT8 + int RNN_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { int elembits = bottom_blob.elembits(); @@ -353,9 +686,20 @@ int RNN_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) c // Uni directional if (direction == 0 || direction == 1) { - int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -368,15 +712,37 @@ int RNN_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) c if (top_blob_reverse.empty()) return -100; - int ret0 = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } hidden.fill(0.0f); - int ret1 = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), weight_xc_data_int8_descales_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), weight_hc_data_int8_descales_packed.channel(1), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -438,9 +804,20 @@ int RNN_arm::forward(const std::vector& bottom_blobs, std::vector& top // Uni directional if (direction == 0 || direction == 1) { - int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -454,14 +831,36 @@ int RNN_arm::forward(const std::vector& bottom_blobs, std::vector& top return -100; Mat hidden0 = hidden.row_range(0, 1); - int ret0 = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), hidden0, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, opt); + if (ret != 0) + return ret; + } Mat hidden1 = hidden.row_range(1, 1); - int ret1 = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), weight_xc_data_int8_descales_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), weight_hc_data_int8_descales_packed.channel(1), hidden1, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -652,8 +1051,214 @@ static int rnn_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M return 0; } +#if NCNN_INT8 +static int rnn_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, Mat& hidden_state, const Option& opt) +{ + int size = bottom_blob.w; + int T = bottom_blob.h; + + int num_output = top_blob.w; + + // num_output + Mat gates(num_output, 4u, opt.workspace_allocator); + if (gates.empty()) + return -100; + + // unroll + for (int t = 0; t < T; t++) + { + int ti = reverse ? T - 1 - t : t; + + const unsigned short* x = bottom_blob.row(ti); + + int remain_num_output_start = 0; +#if __ARM_NEON + int nn_num_output = num_output >> 2; + remain_num_output_start = nn_num_output << 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_num_output; qq++) + { + int q = qq * 4; + + const signed char* weight_xc_int8_ptr = weight_xc_int8.row(q / 4); + const signed char* weight_hc_int8_ptr = weight_hc_int8.row(q / 4); + + const float* weight_xc_int8_descales_ptr = weight_xc_int8_descales.row(q / 4); + const float* weight_hc_int8_descales_ptr = weight_hc_int8_descales.row(q / 4); + + float32x4_t _descale_xc = vld1q_f32(weight_xc_int8_descales_ptr); + float32x4_t _descale_hc = vld1q_f32(weight_hc_int8_descales_ptr); + + float32x4_t _rnn_H = bfloat2float(vld1_u16((const unsigned short*)bias_c + q)); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + + int i = 0; + for (; i + 3 < size; i += 4) + { + float32x4_t _x = bfloat2float(vld1_u16(x + i)); + + int8x16_t _weight_xc = vld1q_s8(weight_xc_int8_ptr); + int16x8_t _weight_xc_01 = vmovl_s8(vget_low_s8(_weight_xc)); + int16x8_t _weight_xc_23 = vmovl_s8(vget_high_s8(_weight_xc)); + float32x4_t _weight_xc_0 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_01))), _descale_xc); + float32x4_t _weight_xc_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_01))), _descale_xc); + float32x4_t _weight_xc_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_23))), _descale_xc); + float32x4_t _weight_xc_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_23))), _descale_xc); + +#if __aarch64__ + _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_xc_0, _x, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_1, _x, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_2, _x, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_3, _x, 3); +#else + _rnn_H = vmlaq_lane_f32(_rnn_H, _weight_xc_0, vget_low_f32(_x), 0); + _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_1, vget_low_f32(_x), 1); + _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_2, vget_high_f32(_x), 0); + _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_3, vget_high_f32(_x), 1); +#endif + + weight_xc_int8_ptr += 16; + } + for (; i < size; i++) + { + float32x4_t _x = bfloat2float(vdup_n_u16(x[i])); + float32x4_t _weight_xc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_ptr))))), _descale_xc); + _rnn_H = vmlaq_f32(_rnn_H, _weight_xc, _x); + + weight_xc_int8_ptr += 4; + } + + i = 0; + for (; i + 3 < num_output; i += 4) + { + float32x4_t _hidden_state = vld1q_f32((const float*)hidden_state + i); + + int8x16_t _weight_hc = vld1q_s8(weight_hc_int8_ptr); + int16x8_t _weight_hc_01 = vmovl_s8(vget_low_s8(_weight_hc)); + int16x8_t _weight_hc_23 = vmovl_s8(vget_high_s8(_weight_hc)); + float32x4_t _weight_hc_0 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_01))), _descale_hc); + float32x4_t _weight_hc_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_01))), _descale_hc); + float32x4_t _weight_hc_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_23))), _descale_hc); + float32x4_t _weight_hc_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_23))), _descale_hc); + +#if __aarch64__ + _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_hc_0, _hidden_state, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_1, _hidden_state, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_2, _hidden_state, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_3, _hidden_state, 3); +#else + _rnn_H = vmlaq_lane_f32(_rnn_H, _weight_hc_0, vget_low_f32(_hidden_state), 0); + _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_1, vget_low_f32(_hidden_state), 1); + _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_2, vget_high_f32(_hidden_state), 0); + _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_3, vget_high_f32(_hidden_state), 1); +#endif + + weight_hc_int8_ptr += 16; + } + for (; i < num_output; i++) + { + float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]); + float32x4_t _weight_hc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_ptr))))), _descale_hc); + _rnn_H = vmlaq_f32(_rnn_H, _weight_hc, _hidden_state); + + weight_hc_int8_ptr += 4; + } + + _rnn_H = vaddq_f32(_rnn_H, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + _rnn_H = vaddq_f32(_rnn_H, _sum2); + + _rnn_H = tanh_ps(_rnn_H); + + vst1q_f32((float*)gates + q, _rnn_H); + } +#endif // __ARM_NEON + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { +#if __ARM_NEON + const signed char* weight_xc_int8_ptr = weight_xc_int8.row(q / 4 + q % 4); + const signed char* weight_hc_int8_ptr = weight_hc_int8.row(q / 4 + q % 4); + const float* weight_xc_int8_descales_ptr = weight_xc_int8_descales.row(q / 4 + q % 4); + const float* weight_hc_int8_descales_ptr = weight_hc_int8_descales.row(q / 4 + q % 4); +#else + const signed char* weight_xc_int8_ptr = weight_xc_int8.row(q); + const signed char* weight_hc_int8_ptr = weight_hc_int8.row(q); + const float* weight_xc_int8_descales_ptr = weight_xc_int8_descales.row(q); + const float* weight_hc_int8_descales_ptr = weight_hc_int8_descales.row(q); +#endif // __ARM_NEON + + const float descale_xc = weight_xc_int8_descales_ptr[0]; + const float descale_hc = weight_hc_int8_descales_ptr[0]; + + float H = bfloat16_to_float32(((const unsigned short*)bias_c)[q]); + + for (int i = 0; i < size; i++) + { + H += weight_xc_int8_ptr[i] * descale_xc * bfloat16_to_float32(x[i]); + } + + for (int i = 0; i < num_output; i++) + { + H += weight_hc_int8_ptr[i] * descale_hc * hidden_state[i]; + } + + H = tanhf(H); + + gates[q] = H; + } + + unsigned short* output_data = top_blob.row(ti); + + float* hidden_ptr = hidden_state; + +#if __ARM_NEON + nn_num_output = num_output >> 2; + remain_num_output_start = nn_num_output << 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_num_output; qq++) + { + int q = qq * 4; + + float32x4_t _rnn_H = vld1q_f32((float*)gates + q); + + vst1q_f32(hidden_ptr + q, _rnn_H); + vst1_u16(output_data + q, float2bfloat(_rnn_H)); + } +#endif // __ARM_NEON + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + float H = gates[q]; + + hidden_ptr[q] = H; + output_data[q] = float32_to_bfloat16(H); + } + } + + return 0; +} +#endif // NCNN_INT8 + int RNN_arm::create_pipeline_bf16s(const Option& opt) { +#if NCNN_INT8 + if (int8_scale_term) + { + create_pipeline_int8(opt); + + ncnn::Mat tmp; + cast_float32_to_bfloat16(bias_c_data_packed, tmp, opt); + bias_c_data_packed = tmp; + + return 0; + } +#endif + int num_directions = direction == 2 ? 2 : 1; int size = weight_data_size / num_directions / num_output; @@ -768,9 +1373,20 @@ int RNN_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& // Uni directional if (direction == 0 || direction == 1) { - int ret = rnn_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_bf16s_int8(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -783,15 +1399,37 @@ int RNN_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& if (top_blob_reverse.empty()) return -100; - int ret0 = rnn_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_bf16s_int8(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } hidden.fill(0.f); - int ret1 = rnn_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_bf16s_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), weight_xc_data_int8_descales_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), weight_hc_data_int8_descales_packed.channel(1), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -838,9 +1476,20 @@ int RNN_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; protected: +#if NCNN_INT8 + int create_pipeline_int8(const Option& opt); +#endif #if NCNN_ARM82 int create_pipeline_fp16s(const Option& opt); int forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; @@ -46,6 +49,11 @@ class RNN_arm : public RNN Mat weight_xc_data_packed; Mat bias_c_data_packed; Mat weight_hc_data_packed; + +#if NCNN_INT8 + Mat weight_hc_data_int8_descales_packed; + Mat weight_xc_data_int8_descales_packed; +#endif }; } // namespace ncnn diff --git a/src/layer/arm/rnn_arm_asimdhp.cpp b/src/layer/arm/rnn_arm_asimdhp.cpp index 93b009151c5..aef51ec608e 100644 --- a/src/layer/arm/rnn_arm_asimdhp.cpp +++ b/src/layer/arm/rnn_arm_asimdhp.cpp @@ -380,8 +380,639 @@ static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const return 0; } +#if NCNN_INT8 +static int rnn_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, Mat& hidden_state, const Option& opt) +{ + int size = bottom_blob.w; + int T = bottom_blob.h; + + int num_output = top_blob.w; + + // num_output + Mat gates(num_output, 4u, opt.workspace_allocator); + if (gates.empty()) + return -100; + + // unroll + for (int t = 0; t < T; t++) + { + int ti = reverse ? T - 1 - t : t; + + const __fp16* x = bottom_blob.row(ti); + + int nn_num_output = num_output >> 2; + int remain_num_output_start = nn_num_output << 2; + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_num_output; qq++) + { + int q = qq * 4; + + const signed char* weight_xc_int8_ptr = weight_xc_int8.row(q / 4); + const signed char* weight_hc_int8_ptr = weight_hc_int8.row(q / 4); + + const float* weight_xc_int8_descales_ptr = weight_xc_int8_descales.row(q / 4); + const float* weight_hc_int8_descales_ptr = weight_hc_int8_descales.row(q / 4); + + float32x4_t _descale_xc = vld1q_f32(weight_xc_int8_descales_ptr); + float32x4_t _descale_hc = vld1q_f32(weight_hc_int8_descales_ptr); + + float32x4_t _rnn_H = vcvt_f32_f16(vld1_f16((const __fp16*)bias_c + q)); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + + int i = 0; + for (; i + 3 < size; i += 4) + { + float32x4_t _x = vcvt_f32_f16(vld1_f16(x + i)); + + int8x16_t _weight_xc = vld1q_s8(weight_xc_int8_ptr); + int16x8_t _weight_xc_01 = vmovl_s8(vget_low_s8(_weight_xc)); + int16x8_t _weight_xc_23 = vmovl_s8(vget_high_s8(_weight_xc)); + float32x4_t _weight_xc_0 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_01))), _descale_xc); + float32x4_t _weight_xc_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_01))), _descale_xc); + float32x4_t _weight_xc_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_23))), _descale_xc); + float32x4_t _weight_xc_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_23))), _descale_xc); + + _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_xc_0, _x, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_1, _x, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_2, _x, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_3, _x, 3); + + weight_xc_int8_ptr += 16; + } + for (; i < size; i++) + { + float32x4_t _x = vcvt_f32_f16(vdup_n_f16(x[i])); + float32x4_t _weight_xc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_ptr))))), _descale_xc); + _rnn_H = vfmaq_f32(_rnn_H, _weight_xc, _x); + + weight_xc_int8_ptr += 4; + } + + i = 0; + for (; i + 3 < num_output; i += 4) + { + float32x4_t _hidden_state = vld1q_f32((const float*)hidden_state + i); + + int8x16_t _weight_hc = vld1q_s8(weight_hc_int8_ptr); + int16x8_t _weight_hc_01 = vmovl_s8(vget_low_s8(_weight_hc)); + int16x8_t _weight_hc_23 = vmovl_s8(vget_high_s8(_weight_hc)); + float32x4_t _weight_hc_0 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_01))), _descale_hc); + float32x4_t _weight_hc_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_01))), _descale_hc); + float32x4_t _weight_hc_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_23))), _descale_hc); + float32x4_t _weight_hc_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_23))), _descale_hc); + + _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_hc_0, _hidden_state, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_1, _hidden_state, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_2, _hidden_state, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_3, _hidden_state, 3); + + weight_hc_int8_ptr += 16; + } + for (; i < num_output; i++) + { + float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]); + float32x4_t _weight_hc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_ptr))))), _descale_hc); + _rnn_H = vfmaq_f32(_rnn_H, _weight_hc, _hidden_state); + + weight_hc_int8_ptr += 4; + } + + _rnn_H = vaddq_f32(_rnn_H, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + _rnn_H = vaddq_f32(_rnn_H, _sum2); + + _rnn_H = tanh_ps(_rnn_H); + + vst1q_f32((float*)gates + q, _rnn_H); + } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const signed char* weight_xc_int8_ptr = weight_xc_int8.row(q / 4 + q % 4); + const signed char* weight_hc_int8_ptr = weight_hc_int8.row(q / 4 + q % 4); + const float* weight_xc_int8_descales_ptr = weight_xc_int8_descales.row(q / 4 + q % 4); + const float* weight_hc_int8_descales_ptr = weight_hc_int8_descales.row(q / 4 + q % 4); + + const float descale_xc = weight_xc_int8_descales_ptr[0]; + const float descale_hc = weight_hc_int8_descales_ptr[0]; + + float H = (float)(((const __fp16*)bias_c)[q]); + + for (int i = 0; i < size; i++) + { + H += weight_xc_int8_ptr[i] * descale_xc * (float)x[i]; + } + + for (int i = 0; i < num_output; i++) + { + H += weight_hc_int8_ptr[i] * descale_hc * hidden_state[i]; + } + + H = tanhf(H); + + gates[q] = H; + } + + __fp16* output_data = top_blob.row<__fp16>(ti); + + float* hidden_ptr = hidden_state; + + nn_num_output = num_output >> 2; + remain_num_output_start = nn_num_output << 2; + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_num_output; qq++) + { + int q = qq * 4; + + float32x4_t _rnn_H = vld1q_f32((float*)gates + q); + + vst1q_f32(hidden_ptr + q, _rnn_H); + vst1_f16(output_data + q, vcvt_f16_f32(_rnn_H)); + } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + float H = gates[q]; + + hidden_ptr[q] = H; + output_data[q] = (__fp16)H; + } + } + + return 0; +} + +static int rnn_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, Mat& hidden_state, const Option& opt) +{ + int size = bottom_blob.w; + int T = bottom_blob.h; + + int num_output = top_blob.w; + + // num_output + Mat gates(num_output, 4u, opt.workspace_allocator); + if (gates.empty()) + return -100; + + // unroll + for (int t = 0; t < T; t++) + { + int ti = reverse ? T - 1 - t : t; + + const __fp16* x = bottom_blob.row(ti); + + int nn_num_output = num_output >> 3; + int remain_num_output_start = nn_num_output << 3; + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_num_output; qq++) + { + int q = qq * 8; + + // const __fp16* weight_xc_ptr = weight_xc.row(q / 8); + // const __fp16* weight_hc_ptr = weight_hc.row(q / 8); + + const signed char* weight_xc_int8_ptr = weight_xc_int8.row(q / 8); + const signed char* weight_hc_int8_ptr = weight_hc_int8.row(q / 8); + + const __fp16* weight_xc_int8_descales_ptr = weight_xc_int8_descales.row(q / 8); + const __fp16* weight_hc_int8_descales_ptr = weight_hc_int8_descales.row(q / 8); + + float16x8_t _descale_xc = vld1q_f16(weight_xc_int8_descales_ptr); + float16x8_t _descale_hc = vld1q_f16(weight_hc_int8_descales_ptr); + + float16x8_t _rnn_H = vld1q_f16((const __fp16*)bias_c + q); + float16x8_t _sum1 = vdupq_n_f16(0.f); + float16x8_t _sum2 = vdupq_n_f16(0.f); + float16x8_t _sum3 = vdupq_n_f16(0.f); + + int i = 0; + for (; i + 3 < size; i += 4) + { + float16x4_t _x = vld1_f16(x + i); + + int8x16_t _weight_xc_01 = vld1q_s8(weight_xc_int8_ptr); + int8x16_t _weight_xc_23 = vld1q_s8(weight_xc_int8_ptr + 16); + + float16x8_t _weight_xc_0 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_01))), _descale_xc); + float16x8_t _weight_xc_1 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_01))), _descale_xc); + float16x8_t _weight_xc_2 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_23))), _descale_xc); + float16x8_t _weight_xc_3 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_23))), _descale_xc); + + _rnn_H = vfmaq_lane_f16(_rnn_H, _weight_xc_0, _x, 0); + _sum1 = vfmaq_lane_f16(_sum1, _weight_xc_1, _x, 1); + _sum2 = vfmaq_lane_f16(_sum2, _weight_xc_2, _x, 2); + _sum3 = vfmaq_lane_f16(_sum3, _weight_xc_3, _x, 3); + + weight_xc_int8_ptr += 32; + } + for (; i < size; i++) + { + float16x8_t _x = vdupq_n_f16(x[i]); + float16x8_t _weight_xc = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vld1_s8(weight_xc_int8_ptr))), _descale_xc); + _rnn_H = vfmaq_f16(_rnn_H, _weight_xc, _x); + + weight_xc_int8_ptr += 8; + } + + i = 0; + for (; i + 3 < num_output; i += 4) + { + float16x4_t _hidden_state = vcvt_f16_f32(vld1q_f32((const float*)hidden_state + i)); + + int8x16_t _weight_hc_01 = vld1q_s8(weight_hc_int8_ptr); + int8x16_t _weight_hc_23 = vld1q_s8(weight_hc_int8_ptr + 16); + + float16x8_t _weight_hc_0 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_01))), _descale_hc); + float16x8_t _weight_hc_1 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_01))), _descale_hc); + float16x8_t _weight_hc_2 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_23))), _descale_hc); + float16x8_t _weight_hc_3 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_23))), _descale_hc); + + _rnn_H = vfmaq_lane_f16(_rnn_H, _weight_hc_0, _hidden_state, 0); + _sum1 = vfmaq_lane_f16(_sum1, _weight_hc_1, _hidden_state, 1); + _sum2 = vfmaq_lane_f16(_sum2, _weight_hc_2, _hidden_state, 2); + _sum3 = vfmaq_lane_f16(_sum3, _weight_hc_3, _hidden_state, 3); + + weight_hc_int8_ptr += 32; + } + for (; i < num_output; i++) + { + float16x8_t _hidden_state = vdupq_n_f16((__fp16)hidden_state[i]); + float16x8_t _weight_hc = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vld1_s8(weight_hc_int8_ptr))), _descale_hc); + _rnn_H = vfmaq_f16(_rnn_H, _weight_hc, _hidden_state); + + weight_hc_int8_ptr += 8; + } + + _rnn_H = vaddq_f16(_rnn_H, _sum1); + _sum2 = vaddq_f16(_sum2, _sum3); + _rnn_H = vaddq_f16(_rnn_H, _sum2); + + float32x4_t _H32low = tanh_ps(vcvt_f32_f16(vget_low_f16(_rnn_H))); + float32x4_t _H32high = tanh_ps(vcvt_f32_f16(vget_high_f16(_rnn_H))); + + vst1q_f32((float*)gates + q, _H32low); + vst1q_f32((float*)gates + q + 4, _H32high); + } + nn_num_output = (num_output - remain_num_output_start) >> 2; + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_num_output; qq++) + { + int q = remain_num_output_start + qq * 4; + + const signed char* weight_xc_int8_ptr = weight_xc_int8.row(q / 8 + (q % 8) / 4); + const signed char* weight_hc_int8_ptr = weight_hc_int8.row(q / 8 + (q % 8) / 4); + const __fp16* weight_xc_int8_descales_ptr = weight_xc_int8_descales.row(q / 8 + (q % 8) / 4); + const __fp16* weight_hc_int8_descales_ptr = weight_hc_int8_descales.row(q / 8 + (q % 8) / 4); + + float16x4_t _descale_xc = vld1_f16(weight_xc_int8_descales_ptr); + float16x4_t _descale_hc = vld1_f16(weight_hc_int8_descales_ptr); + float16x8_t _descale_xc_2 = vcombine_f16(_descale_xc, _descale_xc); + float16x8_t _descale_hc_2 = vcombine_f16(_descale_hc, _descale_hc); + + float16x4_t _rnn_H = vld1_f16((const __fp16*)bias_c + q); + float16x4_t _sum1 = vdup_n_f16(0.f); + float16x4_t _sum2 = vdup_n_f16(0.f); + float16x4_t _sum3 = vdup_n_f16(0.f); + + int i = 0; + for (; i + 3 < size; i += 4) + { + float16x4_t _x = vld1_f16(x + i); + + int8x16_t _weight_xc = vld1q_s8(weight_xc_int8_ptr); + float16x8_t _weight_xc_01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc))), _descale_xc_2); + float16x8_t _weight_xc_23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc))), _descale_xc_2); + float16x4_t _weight_xc_0 = vget_low_f16(_weight_xc_01); + float16x4_t _weight_xc_1 = vget_high_f16(_weight_xc_01); + float16x4_t _weight_xc_2 = vget_low_f16(_weight_xc_23); + float16x4_t _weight_xc_3 = vget_high_f16(_weight_xc_23); + + _rnn_H = vfma_lane_f16(_rnn_H, _weight_xc_0, _x, 0); + _sum1 = vfma_lane_f16(_sum1, _weight_xc_1, _x, 1); + _sum2 = vfma_lane_f16(_sum2, _weight_xc_2, _x, 2); + _sum3 = vfma_lane_f16(_sum3, _weight_xc_3, _x, 3); + + weight_xc_int8_ptr += 16; + } + for (; i < size; i++) + { + float16x4_t _x = vdup_n_f16(x[i]); + float16x4_t _weight_xc = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_ptr)))), _descale_xc); + _rnn_H = vfma_f16(_rnn_H, _weight_xc, _x); + + weight_xc_int8_ptr += 4; + } + + i = 0; + for (; i + 3 < num_output; i += 4) + { + float16x4_t _hidden_state = vcvt_f16_f32(vld1q_f32((const float*)hidden_state + i)); + + int8x16_t _weight_hc = vld1q_s8(weight_hc_int8_ptr); + float16x8_t _weight_hc_01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc))), _descale_hc_2); + float16x8_t _weight_hc_23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc))), _descale_hc_2); + float16x4_t _weight_hc_0 = vget_low_f16(_weight_hc_01); + float16x4_t _weight_hc_1 = vget_high_f16(_weight_hc_01); + float16x4_t _weight_hc_2 = vget_low_f16(_weight_hc_23); + float16x4_t _weight_hc_3 = vget_high_f16(_weight_hc_23); + + _rnn_H = vfma_lane_f16(_rnn_H, _weight_hc_0, _hidden_state, 0); + _sum1 = vfma_lane_f16(_sum1, _weight_hc_1, _hidden_state, 1); + _sum2 = vfma_lane_f16(_sum2, _weight_hc_2, _hidden_state, 2); + _sum3 = vfma_lane_f16(_sum3, _weight_hc_3, _hidden_state, 3); + + weight_hc_int8_ptr += 16; + } + for (; i < num_output; i++) + { + float16x4_t _hidden_state = vdup_n_f16((__fp16)hidden_state[i]); + float16x4_t _weight_hc = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_ptr)))), _descale_hc); + _rnn_H = vfma_f16(_rnn_H, _weight_hc, _hidden_state); + + weight_hc_int8_ptr += 4; + } + + _rnn_H = vadd_f16(_rnn_H, _sum1); + _sum2 = vadd_f16(_sum2, _sum3); + _rnn_H = vadd_f16(_rnn_H, _sum2); + + float32x4_t _H32 = tanh_ps(vcvt_f32_f16(_rnn_H)); + + vst1q_f32((float*)gates + q, _H32); + } + remain_num_output_start += nn_num_output << 2; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const signed char* weight_xc_int8_ptr = weight_xc_int8.row(q / 8 + (q % 8) / 4 + q % 4); + const signed char* weight_hc_int8_ptr = weight_hc_int8.row(q / 8 + (q % 8) / 4 + q % 4); + const __fp16* weight_xc_int8_descales_ptr = weight_xc_int8_descales.row(q / 8 + (q % 8) / 4 + q % 4); + const __fp16* weight_hc_int8_descales_ptr = weight_hc_int8_descales.row(q / 8 + (q % 8) / 4 + q % 4); + + const __fp16 descale_xc = weight_xc_int8_descales_ptr[0]; + const __fp16 descale_hc = weight_hc_int8_descales_ptr[0]; + + __fp16 H = ((const __fp16*)bias_c)[q]; + + for (int i = 0; i < size; i++) + { + H += (__fp16)weight_xc_int8_ptr[i] * descale_xc * x[i]; + } + + for (int i = 0; i < num_output; i++) + { + H += (__fp16)weight_hc_int8_ptr[i] * descale_hc * (__fp16)hidden_state[i]; + } + + float H32 = tanhf((float)H); + + gates[q] = H32; + } + + __fp16* output_data = top_blob.row<__fp16>(ti); + + float* hidden_ptr = hidden_state; + + nn_num_output = num_output >> 2; + remain_num_output_start = nn_num_output << 2; + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_num_output; qq++) + { + int q = qq * 4; + + float32x4_t _rnn_H = vld1q_f32((float*)gates + q); + + vst1q_f32(hidden_ptr + q, _rnn_H); + vst1_f16(output_data + q, vcvt_f16_f32(_rnn_H)); + } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + float H = gates[q]; + + hidden_ptr[q] = H; + output_data[q] = (__fp16)H; + } + } + + return 0; +} +#endif // NCNN_INT8 + int RNN_arm::create_pipeline_fp16s(const Option& opt) { +#if NCNN_INT8 + if (int8_scale_term) + { + const int num_directions = direction == 2 ? 2 : 1; + const int size = weight_data_size / num_directions / num_output; + + if (opt.use_fp16_arithmetic) + { + weight_xc_data_packed.create(size * 8, num_output / 8 + (num_output % 8) / 4 + num_output % 4, num_directions, 1u, 1); + weight_hc_data_packed.create(num_output * 8, num_output / 8 + (num_output % 8) / 4 + num_output % 4, num_directions, 1u, 1); + weight_xc_data_int8_descales_packed.create(8, num_output / 8 + (num_output % 8) / 4 + num_output % 4, num_directions, 2u, 1); + weight_hc_data_int8_descales_packed.create(8, num_output / 8 + (num_output % 8) / 4 + num_output % 4, num_directions, 2u, 1); + } + else + { + weight_xc_data_packed.create(size * 4, num_output / 4 + num_output % 4, num_directions, 1u, 1); + weight_hc_data_packed.create(num_output * 4, num_output / 4 + num_output % 4, num_directions, 1u, 1); + weight_xc_data_int8_descales_packed.create(4, num_output / 4 + num_output % 4, num_directions, 4u, 1); + weight_hc_data_int8_descales_packed.create(4, num_output / 4 + num_output % 4, num_directions, 4u, 1); + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int dr = 0; dr < num_directions; dr++) + { + const Mat weight_xc = weight_xc_data.channel(dr); + const Mat weight_hc = weight_hc_data.channel(dr); + const float* weight_xc_int8_scales = weight_xc_data_int8_scales.row(dr); + const float* weight_hc_int8_scales = weight_hc_data_int8_scales.row(dr); + + Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr); + Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr); + Mat weight_xc_data_int8_descales_packed_dr = weight_xc_data_int8_descales_packed.channel(dr); + Mat weight_hc_data_int8_descales_packed_dr = weight_hc_data_int8_descales_packed.channel(dr); + + int q = 0; + if (opt.use_fp16_arithmetic) + { + for (; q + 7 < num_output; q += 8) + { + const signed char* weight_xc_0 = weight_xc.row(q); + const signed char* weight_xc_1 = weight_xc.row(q + 1); + const signed char* weight_xc_2 = weight_xc.row(q + 2); + const signed char* weight_xc_3 = weight_xc.row(q + 3); + const signed char* weight_xc_4 = weight_xc.row(q + 4); + const signed char* weight_xc_5 = weight_xc.row(q + 5); + const signed char* weight_xc_6 = weight_xc.row(q + 6); + const signed char* weight_xc_7 = weight_xc.row(q + 7); + + const signed char* weight_hc_0 = weight_hc.row(q); + const signed char* weight_hc_1 = weight_hc.row(q + 1); + const signed char* weight_hc_2 = weight_hc.row(q + 2); + const signed char* weight_hc_3 = weight_hc.row(q + 3); + const signed char* weight_hc_4 = weight_hc.row(q + 4); + const signed char* weight_hc_5 = weight_hc.row(q + 5); + const signed char* weight_hc_6 = weight_hc.row(q + 6); + const signed char* weight_hc_7 = weight_hc.row(q + 7); + + signed char* weight_xc_ptr = weight_xc_data_packed_dr.row(q / 8); + signed char* weight_hc_ptr = weight_hc_data_packed_dr.row(q / 8); + __fp16* weight_xc_int8_descales_ptr = weight_xc_data_int8_descales_packed_dr.row<__fp16>(q / 8); + __fp16* weight_hc_int8_descales_ptr = weight_hc_data_int8_descales_packed_dr.row<__fp16>(q / 8); + + for (int i = 0; i < size; i++) + { + weight_xc_ptr[0] = weight_xc_0[i]; + weight_xc_ptr[1] = weight_xc_1[i]; + weight_xc_ptr[2] = weight_xc_2[i]; + weight_xc_ptr[3] = weight_xc_3[i]; + weight_xc_ptr[4] = weight_xc_4[i]; + weight_xc_ptr[5] = weight_xc_5[i]; + weight_xc_ptr[6] = weight_xc_6[i]; + weight_xc_ptr[7] = weight_xc_7[i]; + + weight_xc_ptr += 8; + } + + for (int i = 0; i < num_output; i++) + { + weight_hc_ptr[0] = weight_hc_0[i]; + weight_hc_ptr[1] = weight_hc_1[i]; + weight_hc_ptr[2] = weight_hc_2[i]; + weight_hc_ptr[3] = weight_hc_3[i]; + weight_hc_ptr[4] = weight_hc_4[i]; + weight_hc_ptr[5] = weight_hc_5[i]; + weight_hc_ptr[6] = weight_hc_6[i]; + weight_hc_ptr[7] = weight_hc_7[i]; + + weight_hc_ptr += 8; + } + + float32x4_t _xc0 = vld1q_f32(weight_xc_int8_scales + q); + float32x4_t _xc1 = vld1q_f32(weight_xc_int8_scales + q + 4); + float32x4_t _hc0 = vld1q_f32(weight_hc_int8_scales + q); + float32x4_t _hc1 = vld1q_f32(weight_hc_int8_scales + q + 4); + + float32x4_t _one = vdupq_n_f32(1.f); + float16x4_t _reciprocal_xc0 = vcvt_f16_f32(vdivq_f32(_one, _xc0)); + float16x4_t _reciprocal_xc1 = vcvt_f16_f32(vdivq_f32(_one, _xc1)); + float16x4_t _reciprocal_hc0 = vcvt_f16_f32(vdivq_f32(_one, _hc0)); + float16x4_t _reciprocal_hc1 = vcvt_f16_f32(vdivq_f32(_one, _hc1)); + + vst1q_f16(weight_xc_int8_descales_ptr, vcombine_f16(_reciprocal_xc0, _reciprocal_xc1)); + vst1q_f16(weight_hc_int8_descales_ptr, vcombine_f16(_reciprocal_hc0, _reciprocal_hc1)); + } + } + for (; q + 3 < num_output; q += 4) + { + const signed char* weight_xc_0 = weight_xc.row(q); + const signed char* weight_xc_1 = weight_xc.row(q + 1); + const signed char* weight_xc_2 = weight_xc.row(q + 2); + const signed char* weight_xc_3 = weight_xc.row(q + 3); + + const signed char* weight_hc_0 = weight_hc.row(q); + const signed char* weight_hc_1 = weight_hc.row(q + 1); + const signed char* weight_hc_2 = weight_hc.row(q + 2); + const signed char* weight_hc_3 = weight_hc.row(q + 3); + + int qq = opt.use_fp16_arithmetic ? q / 8 + (q % 8) / 4 : q / 4; + signed char* weight_xc_ptr = weight_xc_data_packed_dr.row(qq); + signed char* weight_hc_ptr = weight_hc_data_packed_dr.row(qq); + float* weight_xc_int8_descales_ptr = weight_xc_data_int8_descales_packed_dr.row(qq); + float* weight_hc_int8_descales_ptr = weight_hc_data_int8_descales_packed_dr.row(qq); + + for (int i = 0; i < size; i++) + { + weight_xc_ptr[0] = weight_xc_0[i]; + weight_xc_ptr[1] = weight_xc_1[i]; + weight_xc_ptr[2] = weight_xc_2[i]; + weight_xc_ptr[3] = weight_xc_3[i]; + + weight_xc_ptr += 4; + } + + for (int i = 0; i < num_output; i++) + { + weight_hc_ptr[0] = weight_hc_0[i]; + weight_hc_ptr[1] = weight_hc_1[i]; + weight_hc_ptr[2] = weight_hc_2[i]; + weight_hc_ptr[3] = weight_hc_3[i]; + + weight_hc_ptr += 4; + } + + float32x4_t _xc = vld1q_f32(weight_xc_int8_scales + q); + float32x4_t _hc = vld1q_f32(weight_hc_int8_scales + q); + + float32x4_t _one = vdupq_n_f32(1.f); + float32x4_t _reciprocal_xc = vdivq_f32(_one, _xc); + float32x4_t _reciprocal_hc = vdivq_f32(_one, _hc); + + if (opt.use_fp16_arithmetic) + { + vst1_f16((__fp16*)weight_xc_int8_descales_ptr, vcvt_f16_f32(_reciprocal_xc)); + vst1_f16((__fp16*)weight_hc_int8_descales_ptr, vcvt_f16_f32(_reciprocal_hc)); + } + else + { + vst1q_f32(weight_xc_int8_descales_ptr, _reciprocal_xc); + vst1q_f32(weight_hc_int8_descales_ptr, _reciprocal_hc); + } + } + for (; q < num_output; q++) + { + const signed char* weight_xc_0 = weight_xc.row(q); + const signed char* weight_hc_0 = weight_hc.row(q); + + int qq = opt.use_fp16_arithmetic ? q / 8 + (q % 8) / 4 + q % 4 : q / 4 + q % 4; + signed char* weight_xc_ptr = weight_xc_data_packed_dr.row(qq); + signed char* weight_hc_ptr = weight_hc_data_packed_dr.row(qq); + float* weight_xc_int8_descales_ptr = weight_xc_data_int8_descales_packed_dr.row(qq); + float* weight_hc_int8_descales_ptr = weight_hc_data_int8_descales_packed_dr.row(qq); + + for (int i = 0; i < size; i++) + { + weight_xc_ptr[i] = weight_xc_0[i]; + } + + for (int i = 0; i < num_output; i++) + { + weight_hc_ptr[i] = weight_hc_0[i]; + } + + if (opt.use_fp16_arithmetic) + { + ((__fp16*)weight_xc_int8_descales_ptr)[0] = (__fp16)(1.f / weight_xc_int8_scales[q]); + ((__fp16*)weight_hc_int8_descales_ptr)[0] = (__fp16)(1.f / weight_hc_int8_scales[q]); + } + else + { + weight_xc_int8_descales_ptr[0] = 1.f / weight_xc_int8_scales[q]; + weight_hc_int8_descales_ptr[0] = 1.f / weight_hc_int8_scales[q]; + } + } + } + + cast_float32_to_float16(bias_c_data, bias_c_data_packed, opt); + + if (opt.lightmode) + { + weight_xc_data.release(); + bias_c_data.release(); + weight_hc_data.release(); + weight_xc_data_int8_scales.release(); + weight_hc_data_int8_scales.release(); + } + + return 0; + } +#endif + int num_directions = direction == 2 ? 2 : 1; int size = weight_data_size / num_directions / num_output; @@ -546,9 +1177,20 @@ int RNN_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& // Uni directional if (direction == 0 || direction == 1) { - int ret = rnn_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_fp16s_int8(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -561,15 +1203,37 @@ int RNN_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& if (top_blob_reverse.empty()) return -100; - int ret0 = rnn_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_fp16s_int8(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } hidden.fill(0.f); - int ret1 = rnn_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_fp16s_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), weight_xc_data_int8_descales_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), weight_hc_data_int8_descales_packed.channel(1), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -616,9 +1280,20 @@ int RNN_arm::forward_fp16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector(q); - const signed char* weight_hc_ptr = weight_hc_data.channel(d).row(q); + return 0; +} + +static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt) +{ + int size = bottom_blob.w; + int T = bottom_blob.h; + + int num_output = top_blob.w; + + // num_output + Mat gates(num_output, 4u, opt.workspace_allocator); + if (gates.empty()) + return -100; - float* weight_xc_fp32_ptr = weight_xc_data_fp32.channel(d).row(q); - float* weight_hc_fp32_ptr = weight_hc_data_fp32.channel(d).row(q); + // unroll + for (int t = 0; t < T; t++) + { + int ti = reverse ? T - 1 - t : t; + + const float* x = bottom_blob.row(ti); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_output; q++) + { + const float* weight_xc_ptr = weight_xc.row(q); + const float* weight_hc_ptr = weight_hc.row(q); - const float descale_xc = 1.f / weight_xc_data_int8_scales.row(d)[q]; - const float descale_hc = 1.f / weight_hc_data_int8_scales.row(d)[q]; + float H = bias_c[q]; - for (int i = 0; i < size; i++) - { - weight_xc_fp32_ptr[i] = weight_xc_ptr[i] * descale_xc; - } + for (int i = 0; i < size; i++) + { + H += weight_xc_ptr[i] * x[i]; + } - for (int i = 0; i < num_output; i++) - { - weight_hc_fp32_ptr[i] = weight_hc_ptr[i] * descale_hc; - } + for (int i = 0; i < num_output; i++) + { + H += weight_hc_ptr[i] * hidden_state[i]; } + + H = tanhf(H); + + gates[q] = H; } - weight_xc_data = weight_xc_data_fp32; - weight_hc_data = weight_hc_data_fp32; + float* output_data = top_blob.row(ti); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_output; q++) + { + float H = gates[q]; + + hidden_state[q] = H; + output_data[q] = H; + } } -#endif // NCNN_INT8 return 0; } -static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt) +#if NCNN_INT8 +static int rnn_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const float* weight_xc_int8_scales, const Mat& bias_c, const Mat& weight_hc_int8, const float* weight_hc_int8_scales, Mat& hidden_state, const Option& opt) { int size = bottom_blob.w; int T = bottom_blob.h; @@ -122,19 +147,22 @@ static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < num_output; q++) { - const float* weight_xc_ptr = weight_xc.row(q); - const float* weight_hc_ptr = weight_hc.row(q); + const signed char* weight_xc_int8_ptr = weight_xc_int8.row(q); + const signed char* weight_hc_int8_ptr = weight_hc_int8.row(q); + + const float descale_xc = 1.f / weight_xc_int8_scales[q]; + const float descale_hc = 1.f / weight_hc_int8_scales[q]; float H = bias_c[q]; for (int i = 0; i < size; i++) { - H += weight_xc_ptr[i] * x[i]; + H += weight_xc_int8_ptr[i] * descale_xc * x[i]; } for (int i = 0; i < num_output; i++) { - H += weight_hc_ptr[i] * hidden_state[i]; + H += weight_hc_int8_ptr[i] * descale_hc * hidden_state[i]; } H = tanhf(H); @@ -155,6 +183,7 @@ static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we return 0; } +#endif // NCNN_INT8 int RNN::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { @@ -175,9 +204,20 @@ int RNN::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const // Uni directional if (direction == 0 || direction == 1) { - int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -190,15 +230,37 @@ int RNN::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const if (top_blob_reverse.empty()) return -100; - int ret0 = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); + if (ret != 0) + return ret; + } hidden.fill(0.0f); - int ret1 = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -243,9 +305,20 @@ int RNN::forward(const std::vector& bottom_blobs, std::vector& top_blo // Uni directional if (direction == 0 || direction == 1) { - int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -259,14 +332,36 @@ int RNN::forward(const std::vector& bottom_blobs, std::vector& top_blo return -100; Mat hidden0 = hidden.row_range(0, 1); - int ret0 = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden0, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, opt); + if (ret != 0) + return ret; + } Mat hidden1 = hidden.row_range(1, 1); - int ret1 = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), hidden1, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) diff --git a/tests/test_rnn.cpp b/tests/test_rnn.cpp index f9cb9a5d752..4073802402d 100644 --- a/tests/test_rnn.cpp +++ b/tests/test_rnn.cpp @@ -32,13 +32,13 @@ static int test_rnn(const ncnn::Mat& a, int outch, int direction) int ret = test_layer("RNN", pd, weights, a); if (ret != 0) { - fprintf(stderr, "test_rnn failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + fprintf(stderr, "test_rnn failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); } return ret; } -int test_rnn_layer_with_hidden(const ncnn::Mat& a, int outch, int direction) +int test_rnn_with_hidden(const ncnn::Mat& a, int outch, int direction) { int input_size = a.w; int num_directions = direction == 2 ? 2 : 1; @@ -63,13 +63,13 @@ int test_rnn_layer_with_hidden(const ncnn::Mat& a, int outch, int direction) int ret = test_layer("RNN", pd, weights, as, 2); if (ret != 0) { - fprintf(stderr, "test_rnn_layer_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + fprintf(stderr, "test_rnn_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); } return ret; } -int test_rnn_layer_with_hidden_input(const ncnn::Mat& a, int outch, int direction) +int test_rnn_with_hidden_input(const ncnn::Mat& a, int outch, int direction) { int input_size = a.w; int num_directions = direction == 2 ? 2 : 1; @@ -94,13 +94,13 @@ int test_rnn_layer_with_hidden_input(const ncnn::Mat& a, int outch, int directio int ret = test_layer("RNN", pd, weights, as, 1); if (ret != 0) { - fprintf(stderr, "test_rnn_layer_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + fprintf(stderr, "test_rnn_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); } return ret; } -int test_rnn_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direction) +int test_rnn_with_hidden_output(const ncnn::Mat& a, int outch, int direction) { int input_size = a.w; int num_directions = direction == 2 ? 2 : 1; @@ -121,7 +121,7 @@ int test_rnn_layer_with_hidden_output(const ncnn::Mat& a, int outch, int directi int ret = test_layer("RNN", pd, weights, as, 2); if (ret != 0) { - fprintf(stderr, "test_rnn_layer_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + fprintf(stderr, "test_rnn_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); } return ret; @@ -144,80 +144,80 @@ static int test_rnn_0() static int test_rnn_1() { return 0 - || test_rnn_layer_with_hidden(RandomMat(4, 4), 1, 2) - || test_rnn_layer_with_hidden(RandomMat(8, 2), 2, 2) - || test_rnn_layer_with_hidden(RandomMat(16, 8), 7, 2) - || test_rnn_layer_with_hidden(RandomMat(17, 8), 8, 2) - || test_rnn_layer_with_hidden(RandomMat(19, 15), 8, 2) - || test_rnn_layer_with_hidden(RandomMat(5, 16), 16, 2) - || test_rnn_layer_with_hidden(RandomMat(3, 16), 8, 2) - || test_rnn_layer_with_hidden(RandomMat(2, 5), 99, 2) - || test_rnn_layer_with_hidden(RandomMat(4, 4), 1, 1) - || test_rnn_layer_with_hidden(RandomMat(8, 2), 2, 1) - || test_rnn_layer_with_hidden(RandomMat(16, 8), 7, 1) - || test_rnn_layer_with_hidden(RandomMat(17, 8), 8, 1) - || test_rnn_layer_with_hidden(RandomMat(19, 15), 8, 1) - || test_rnn_layer_with_hidden(RandomMat(5, 16), 16, 1) - || test_rnn_layer_with_hidden(RandomMat(3, 16), 8, 1) - || test_rnn_layer_with_hidden(RandomMat(2, 5), 99, 1) - || test_rnn_layer_with_hidden(RandomMat(4, 2), 1, 0) - || test_rnn_layer_with_hidden(RandomMat(8, 2), 2, 0) - || test_rnn_layer_with_hidden(RandomMat(16, 8), 7, 0) - || test_rnn_layer_with_hidden(RandomMat(17, 8), 8, 0) - || test_rnn_layer_with_hidden(RandomMat(19, 15), 8, 0) - || test_rnn_layer_with_hidden(RandomMat(5, 16), 16, 0) - || test_rnn_layer_with_hidden(RandomMat(3, 16), 8, 0) - || test_rnn_layer_with_hidden(RandomMat(2, 5), 17, 0) - - || test_rnn_layer_with_hidden_input(RandomMat(4, 4), 1, 2) - || test_rnn_layer_with_hidden_input(RandomMat(8, 2), 2, 2) - || test_rnn_layer_with_hidden_input(RandomMat(16, 8), 7, 2) - || test_rnn_layer_with_hidden_input(RandomMat(17, 8), 8, 2) - || test_rnn_layer_with_hidden_input(RandomMat(19, 15), 8, 2) - || test_rnn_layer_with_hidden_input(RandomMat(5, 16), 16, 2) - || test_rnn_layer_with_hidden_input(RandomMat(3, 16), 8, 2) - || test_rnn_layer_with_hidden_input(RandomMat(2, 5), 99, 2) - || test_rnn_layer_with_hidden_input(RandomMat(4, 4), 1, 1) - || test_rnn_layer_with_hidden_input(RandomMat(8, 2), 2, 1) - || test_rnn_layer_with_hidden_input(RandomMat(16, 8), 7, 1) - || test_rnn_layer_with_hidden_input(RandomMat(17, 8), 8, 1) - || test_rnn_layer_with_hidden_input(RandomMat(19, 15), 8, 1) - || test_rnn_layer_with_hidden_input(RandomMat(5, 16), 16, 1) - || test_rnn_layer_with_hidden_input(RandomMat(3, 16), 8, 1) - || test_rnn_layer_with_hidden_input(RandomMat(2, 5), 99, 1) - || test_rnn_layer_with_hidden_input(RandomMat(4, 2), 1, 0) - || test_rnn_layer_with_hidden_input(RandomMat(8, 2), 2, 0) - || test_rnn_layer_with_hidden_input(RandomMat(16, 8), 7, 0) - || test_rnn_layer_with_hidden_input(RandomMat(17, 8), 8, 0) - || test_rnn_layer_with_hidden_input(RandomMat(19, 15), 8, 0) - || test_rnn_layer_with_hidden_input(RandomMat(5, 16), 16, 0) - || test_rnn_layer_with_hidden_input(RandomMat(3, 16), 8, 0) - || test_rnn_layer_with_hidden_input(RandomMat(2, 5), 17, 0) - - || test_rnn_layer_with_hidden_output(RandomMat(4, 4), 1, 2) - || test_rnn_layer_with_hidden_output(RandomMat(8, 2), 2, 2) - || test_rnn_layer_with_hidden_output(RandomMat(16, 8), 7, 2) - || test_rnn_layer_with_hidden_output(RandomMat(17, 8), 8, 2) - || test_rnn_layer_with_hidden_output(RandomMat(19, 15), 8, 2) - || test_rnn_layer_with_hidden_output(RandomMat(5, 16), 16, 2) - || test_rnn_layer_with_hidden_output(RandomMat(3, 16), 8, 2) - || test_rnn_layer_with_hidden_output(RandomMat(2, 5), 99, 2) - || test_rnn_layer_with_hidden_output(RandomMat(4, 4), 1, 1) - || test_rnn_layer_with_hidden_output(RandomMat(8, 2), 2, 1) - || test_rnn_layer_with_hidden_output(RandomMat(16, 8), 7, 1) - || test_rnn_layer_with_hidden_output(RandomMat(17, 8), 8, 1) - || test_rnn_layer_with_hidden_output(RandomMat(19, 15), 8, 1) - || test_rnn_layer_with_hidden_output(RandomMat(5, 16), 16, 1) - || test_rnn_layer_with_hidden_output(RandomMat(3, 16), 8, 1) - || test_rnn_layer_with_hidden_output(RandomMat(2, 5), 99, 1) - || test_rnn_layer_with_hidden_output(RandomMat(4, 2), 1, 0) - || test_rnn_layer_with_hidden_output(RandomMat(8, 2), 2, 0) - || test_rnn_layer_with_hidden_output(RandomMat(16, 8), 7, 0) - || test_rnn_layer_with_hidden_output(RandomMat(17, 8), 8, 0) - || test_rnn_layer_with_hidden_output(RandomMat(19, 15), 8, 0) - || test_rnn_layer_with_hidden_output(RandomMat(5, 16), 16, 0) - || test_rnn_layer_with_hidden_output(RandomMat(3, 16), 8, 0) - || test_rnn_layer_with_hidden_output(RandomMat(2, 5), 17, 0); + || test_rnn_with_hidden(RandomMat(4, 4), 1, 2) + || test_rnn_with_hidden(RandomMat(8, 2), 2, 2) + || test_rnn_with_hidden(RandomMat(16, 8), 7, 2) + || test_rnn_with_hidden(RandomMat(17, 8), 8, 2) + || test_rnn_with_hidden(RandomMat(19, 15), 8, 2) + || test_rnn_with_hidden(RandomMat(5, 16), 16, 2) + || test_rnn_with_hidden(RandomMat(3, 16), 8, 2) + || test_rnn_with_hidden(RandomMat(2, 5), 99, 2) + || test_rnn_with_hidden(RandomMat(4, 4), 1, 1) + || test_rnn_with_hidden(RandomMat(8, 2), 2, 1) + || test_rnn_with_hidden(RandomMat(16, 8), 7, 1) + || test_rnn_with_hidden(RandomMat(17, 8), 8, 1) + || test_rnn_with_hidden(RandomMat(19, 15), 8, 1) + || test_rnn_with_hidden(RandomMat(5, 16), 16, 1) + || test_rnn_with_hidden(RandomMat(3, 16), 8, 1) + || test_rnn_with_hidden(RandomMat(2, 5), 99, 1) + || test_rnn_with_hidden(RandomMat(4, 2), 1, 0) + || test_rnn_with_hidden(RandomMat(8, 2), 2, 0) + || test_rnn_with_hidden(RandomMat(16, 8), 7, 0) + || test_rnn_with_hidden(RandomMat(17, 8), 8, 0) + || test_rnn_with_hidden(RandomMat(19, 15), 8, 0) + || test_rnn_with_hidden(RandomMat(5, 16), 16, 0) + || test_rnn_with_hidden(RandomMat(3, 16), 8, 0) + || test_rnn_with_hidden(RandomMat(2, 5), 17, 0) + + || test_rnn_with_hidden_input(RandomMat(4, 4), 1, 2) + || test_rnn_with_hidden_input(RandomMat(8, 2), 2, 2) + || test_rnn_with_hidden_input(RandomMat(16, 8), 7, 2) + || test_rnn_with_hidden_input(RandomMat(17, 8), 8, 2) + || test_rnn_with_hidden_input(RandomMat(19, 15), 8, 2) + || test_rnn_with_hidden_input(RandomMat(5, 16), 16, 2) + || test_rnn_with_hidden_input(RandomMat(3, 16), 8, 2) + || test_rnn_with_hidden_input(RandomMat(2, 5), 99, 2) + || test_rnn_with_hidden_input(RandomMat(4, 4), 1, 1) + || test_rnn_with_hidden_input(RandomMat(8, 2), 2, 1) + || test_rnn_with_hidden_input(RandomMat(16, 8), 7, 1) + || test_rnn_with_hidden_input(RandomMat(17, 8), 8, 1) + || test_rnn_with_hidden_input(RandomMat(19, 15), 8, 1) + || test_rnn_with_hidden_input(RandomMat(5, 16), 16, 1) + || test_rnn_with_hidden_input(RandomMat(3, 16), 8, 1) + || test_rnn_with_hidden_input(RandomMat(2, 5), 99, 1) + || test_rnn_with_hidden_input(RandomMat(4, 2), 1, 0) + || test_rnn_with_hidden_input(RandomMat(8, 2), 2, 0) + || test_rnn_with_hidden_input(RandomMat(16, 8), 7, 0) + || test_rnn_with_hidden_input(RandomMat(17, 8), 8, 0) + || test_rnn_with_hidden_input(RandomMat(19, 15), 8, 0) + || test_rnn_with_hidden_input(RandomMat(5, 16), 16, 0) + || test_rnn_with_hidden_input(RandomMat(3, 16), 8, 0) + || test_rnn_with_hidden_input(RandomMat(2, 5), 17, 0) + + || test_rnn_with_hidden_output(RandomMat(4, 4), 1, 2) + || test_rnn_with_hidden_output(RandomMat(8, 2), 2, 2) + || test_rnn_with_hidden_output(RandomMat(16, 8), 7, 2) + || test_rnn_with_hidden_output(RandomMat(17, 8), 8, 2) + || test_rnn_with_hidden_output(RandomMat(19, 15), 8, 2) + || test_rnn_with_hidden_output(RandomMat(5, 16), 16, 2) + || test_rnn_with_hidden_output(RandomMat(3, 16), 8, 2) + || test_rnn_with_hidden_output(RandomMat(2, 5), 99, 2) + || test_rnn_with_hidden_output(RandomMat(4, 4), 1, 1) + || test_rnn_with_hidden_output(RandomMat(8, 2), 2, 1) + || test_rnn_with_hidden_output(RandomMat(16, 8), 7, 1) + || test_rnn_with_hidden_output(RandomMat(17, 8), 8, 1) + || test_rnn_with_hidden_output(RandomMat(19, 15), 8, 1) + || test_rnn_with_hidden_output(RandomMat(5, 16), 16, 1) + || test_rnn_with_hidden_output(RandomMat(3, 16), 8, 1) + || test_rnn_with_hidden_output(RandomMat(2, 5), 99, 1) + || test_rnn_with_hidden_output(RandomMat(4, 2), 1, 0) + || test_rnn_with_hidden_output(RandomMat(8, 2), 2, 0) + || test_rnn_with_hidden_output(RandomMat(16, 8), 7, 0) + || test_rnn_with_hidden_output(RandomMat(17, 8), 8, 0) + || test_rnn_with_hidden_output(RandomMat(19, 15), 8, 0) + || test_rnn_with_hidden_output(RandomMat(5, 16), 16, 0) + || test_rnn_with_hidden_output(RandomMat(3, 16), 8, 0) + || test_rnn_with_hidden_output(RandomMat(2, 5), 17, 0); } static int test_rnn_2() @@ -248,8 +248,273 @@ static int test_rnn_3() || test_rnn(RandomMat(2, 5), 17, 1); } +#if NCNN_INT8 +static int test_rnn_int8(const ncnn::Mat& a, int outch, int direction) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, outch * input_size * num_directions); + pd.set(2, direction); + pd.set(8, 2); // int8_scale_term + + std::vector weights(5); + weights[0] = RandomS8Mat(outch * input_size * num_directions); + weights[1] = RandomMat(outch * num_directions); + weights[2] = RandomS8Mat(outch * outch * num_directions); + weights[3] = RandomMat(outch * num_directions, 100.f, 200.f); + weights[4] = RandomMat(outch * num_directions, 100.f, 200.f); + + int ret = test_layer("RNN", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_rnn_int8 failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); + } + + return ret; +} + +int test_rnn_int8_with_hidden(const ncnn::Mat& a, int outch, int direction) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, outch * input_size * num_directions); + pd.set(2, direction); + pd.set(8, 2); // int8_scale_term + + std::vector weights(5); + weights[0] = RandomS8Mat(outch * input_size * num_directions); + weights[1] = RandomMat(outch * num_directions); + weights[2] = RandomS8Mat(outch * outch * num_directions); + weights[3] = RandomMat(outch * num_directions, 100.f, 200.f); + weights[4] = RandomMat(outch * num_directions, 100.f, 200.f); + + // initial hidden state + ncnn::Mat hidden = RandomMat(outch, num_directions); + + std::vector as(2); + as[0] = a; + as[1] = hidden; + + int ret = test_layer("RNN", pd, weights, as, 2); + if (ret != 0) + { + fprintf(stderr, "test_rnn_int8_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); + } + + return ret; +} + +int test_rnn_int8_with_hidden_input(const ncnn::Mat& a, int outch, int direction) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, outch * input_size * num_directions); + pd.set(2, direction); + pd.set(8, 2); // int8_scale_term + + std::vector weights(5); + weights[0] = RandomS8Mat(outch * input_size * num_directions); + weights[1] = RandomMat(outch * num_directions); + weights[2] = RandomS8Mat(outch * outch * num_directions); + weights[3] = RandomMat(outch * num_directions, 100.f, 200.f); + weights[4] = RandomMat(outch * num_directions, 100.f, 200.f); + + // initial hidden state + ncnn::Mat hidden = RandomMat(outch, num_directions); + + std::vector as(2); + as[0] = a; + as[1] = hidden; + + int ret = test_layer("RNN", pd, weights, as, 1); + if (ret != 0) + { + fprintf(stderr, "test_rnn_int8_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); + } + + return ret; +} + +int test_rnn_int8_with_hidden_output(const ncnn::Mat& a, int outch, int direction) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, outch * input_size * num_directions); + pd.set(2, direction); + pd.set(8, 2); // int8_scale_term + + std::vector weights(5); + weights[0] = RandomS8Mat(outch * input_size * num_directions); + weights[1] = RandomMat(outch * num_directions); + weights[2] = RandomS8Mat(outch * outch * num_directions); + weights[3] = RandomMat(outch * num_directions, 100.f, 200.f); + weights[4] = RandomMat(outch * num_directions, 100.f, 200.f); + + std::vector as(1); + as[0] = a; + + int ret = test_layer("RNN", pd, weights, as, 2); + if (ret != 0) + { + fprintf(stderr, "test_rnn_int8_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); + } + + return ret; +} + +static int test_rnn_4() +{ + return 0 + || test_rnn_int8(RandomMat(4, 1), 2, 2) + || test_rnn_int8(RandomMat(8, 2), 2, 2) + || test_rnn_int8(RandomMat(16, 8), 7, 2) + || test_rnn_int8(RandomMat(17, 8), 8, 2) + || test_rnn_int8(RandomMat(19, 15), 8, 2) + || test_rnn_int8(RandomMat(5, 16), 16, 2) + || test_rnn_int8(RandomMat(3, 16), 8, 2) + || test_rnn_int8(RandomMat(8, 16), 16, 2) + || test_rnn_int8(RandomMat(2, 5), 17, 2); +} + +static int test_rnn_5() +{ + return 0 + || test_rnn_int8_with_hidden(RandomMat(4, 4), 1, 2) + || test_rnn_int8_with_hidden(RandomMat(8, 2), 2, 2) + || test_rnn_int8_with_hidden(RandomMat(16, 8), 7, 2) + || test_rnn_int8_with_hidden(RandomMat(17, 8), 8, 2) + || test_rnn_int8_with_hidden(RandomMat(19, 15), 8, 2) + || test_rnn_int8_with_hidden(RandomMat(5, 16), 16, 2) + || test_rnn_int8_with_hidden(RandomMat(3, 16), 8, 2) + || test_rnn_int8_with_hidden(RandomMat(2, 5), 99, 2) + || test_rnn_int8_with_hidden(RandomMat(4, 4), 1, 1) + || test_rnn_int8_with_hidden(RandomMat(8, 2), 2, 1) + || test_rnn_int8_with_hidden(RandomMat(16, 8), 7, 1) + || test_rnn_int8_with_hidden(RandomMat(17, 8), 8, 1) + || test_rnn_int8_with_hidden(RandomMat(19, 15), 8, 1) + || test_rnn_int8_with_hidden(RandomMat(5, 16), 16, 1) + || test_rnn_int8_with_hidden(RandomMat(3, 16), 8, 1) + || test_rnn_int8_with_hidden(RandomMat(2, 5), 99, 1) + || test_rnn_int8_with_hidden(RandomMat(4, 2), 1, 0) + || test_rnn_int8_with_hidden(RandomMat(8, 2), 2, 0) + || test_rnn_int8_with_hidden(RandomMat(16, 8), 7, 0) + || test_rnn_int8_with_hidden(RandomMat(17, 8), 8, 0) + || test_rnn_int8_with_hidden(RandomMat(19, 15), 8, 0) + || test_rnn_int8_with_hidden(RandomMat(5, 16), 16, 0) + || test_rnn_int8_with_hidden(RandomMat(3, 16), 8, 0) + || test_rnn_int8_with_hidden(RandomMat(2, 5), 17, 0) + + || test_rnn_int8_with_hidden_input(RandomMat(4, 4), 1, 2) + || test_rnn_int8_with_hidden_input(RandomMat(8, 2), 2, 2) + || test_rnn_int8_with_hidden_input(RandomMat(16, 8), 7, 2) + || test_rnn_int8_with_hidden_input(RandomMat(17, 8), 8, 2) + || test_rnn_int8_with_hidden_input(RandomMat(19, 15), 8, 2) + || test_rnn_int8_with_hidden_input(RandomMat(5, 16), 16, 2) + || test_rnn_int8_with_hidden_input(RandomMat(3, 16), 8, 2) + || test_rnn_int8_with_hidden_input(RandomMat(2, 5), 99, 2) + || test_rnn_int8_with_hidden_input(RandomMat(4, 4), 1, 1) + || test_rnn_int8_with_hidden_input(RandomMat(8, 2), 2, 1) + || test_rnn_int8_with_hidden_input(RandomMat(16, 8), 7, 1) + || test_rnn_int8_with_hidden_input(RandomMat(17, 8), 8, 1) + || test_rnn_int8_with_hidden_input(RandomMat(19, 15), 8, 1) + || test_rnn_int8_with_hidden_input(RandomMat(5, 16), 16, 1) + || test_rnn_int8_with_hidden_input(RandomMat(3, 16), 8, 1) + || test_rnn_int8_with_hidden_input(RandomMat(2, 5), 99, 1) + || test_rnn_int8_with_hidden_input(RandomMat(4, 2), 1, 0) + || test_rnn_int8_with_hidden_input(RandomMat(8, 2), 2, 0) + || test_rnn_int8_with_hidden_input(RandomMat(16, 8), 7, 0) + || test_rnn_int8_with_hidden_input(RandomMat(17, 8), 8, 0) + || test_rnn_int8_with_hidden_input(RandomMat(19, 15), 8, 0) + || test_rnn_int8_with_hidden_input(RandomMat(5, 16), 16, 0) + || test_rnn_int8_with_hidden_input(RandomMat(3, 16), 8, 0) + || test_rnn_int8_with_hidden_input(RandomMat(2, 5), 17, 0) + + || test_rnn_int8_with_hidden_output(RandomMat(4, 4), 1, 2) + || test_rnn_int8_with_hidden_output(RandomMat(8, 2), 2, 2) + || test_rnn_int8_with_hidden_output(RandomMat(16, 8), 7, 2) + || test_rnn_int8_with_hidden_output(RandomMat(17, 8), 8, 2) + || test_rnn_int8_with_hidden_output(RandomMat(19, 15), 8, 2) + || test_rnn_int8_with_hidden_output(RandomMat(5, 16), 16, 2) + || test_rnn_int8_with_hidden_output(RandomMat(3, 16), 8, 2) + || test_rnn_int8_with_hidden_output(RandomMat(2, 5), 99, 2) + || test_rnn_int8_with_hidden_output(RandomMat(4, 4), 1, 1) + || test_rnn_int8_with_hidden_output(RandomMat(8, 2), 2, 1) + || test_rnn_int8_with_hidden_output(RandomMat(16, 8), 7, 1) + || test_rnn_int8_with_hidden_output(RandomMat(17, 8), 8, 1) + || test_rnn_int8_with_hidden_output(RandomMat(19, 15), 8, 1) + || test_rnn_int8_with_hidden_output(RandomMat(5, 16), 16, 1) + || test_rnn_int8_with_hidden_output(RandomMat(3, 16), 8, 1) + || test_rnn_int8_with_hidden_output(RandomMat(2, 5), 99, 1) + || test_rnn_int8_with_hidden_output(RandomMat(4, 2), 1, 0) + || test_rnn_int8_with_hidden_output(RandomMat(8, 2), 2, 0) + || test_rnn_int8_with_hidden_output(RandomMat(16, 8), 7, 0) + || test_rnn_int8_with_hidden_output(RandomMat(17, 8), 8, 0) + || test_rnn_int8_with_hidden_output(RandomMat(19, 15), 8, 0) + || test_rnn_int8_with_hidden_output(RandomMat(5, 16), 16, 0) + || test_rnn_int8_with_hidden_output(RandomMat(3, 16), 8, 0) + || test_rnn_int8_with_hidden_output(RandomMat(2, 5), 17, 0); +} + +static int test_rnn_6() +{ + return 0 + || test_rnn_int8(RandomMat(4, 1), 1, 0) + || test_rnn_int8(RandomMat(8, 2), 2, 0) + || test_rnn_int8(RandomMat(16, 8), 7, 0) + || test_rnn_int8(RandomMat(17, 8), 8, 0) + || test_rnn_int8(RandomMat(19, 15), 8, 0) + || test_rnn_int8(RandomMat(5, 16), 16, 0) + || test_rnn_int8(RandomMat(3, 16), 8, 0) + || test_rnn_int8(RandomMat(8, 16), 16, 0) + || test_rnn_int8(RandomMat(2, 5), 17, 0); +} + +static int test_rnn_7() +{ + return 0 + || test_rnn_int8(RandomMat(4, 1), 1, 1) + || test_rnn_int8(RandomMat(8, 2), 2, 1) + || test_rnn_int8(RandomMat(16, 8), 7, 1) + || test_rnn_int8(RandomMat(17, 8), 8, 1) + || test_rnn_int8(RandomMat(19, 15), 8, 1) + || test_rnn_int8(RandomMat(5, 16), 16, 1) + || test_rnn_int8(RandomMat(3, 16), 8, 1) + || test_rnn_int8(RandomMat(8, 16), 16, 1) + || test_rnn_int8(RandomMat(2, 5), 17, 1); +} +#endif + int main() { SRAND(7767517); - return test_rnn_0() || test_rnn_1() || test_rnn_2() || test_rnn_3(); + +#if NCNN_INT8 + return 0 + || test_rnn_0() + || test_rnn_1() + || test_rnn_2() + || test_rnn_3() + || test_rnn_4() + || test_rnn_5() + || test_rnn_6() + || test_rnn_7(); +#else + return 0 + || test_rnn_0() + || test_rnn_1() + || test_rnn_2() + || test_rnn_3(); +#endif }