diff --git a/src/layer/arm/gru_arm.cpp b/src/layer/arm/gru_arm.cpp
index a1a53903887..8ef7c455734 100644
--- a/src/layer/arm/gru_arm.cpp
+++ b/src/layer/arm/gru_arm.cpp
@@ -25,6 +25,10 @@ namespace ncnn {
 
+#if NCNN_INT8
+#include "gru_int8.h"
+#endif
+
 GRU_arm::GRU_arm()
 {
 #if __ARM_NEON
@@ -40,6 +44,13 @@ GRU_arm::GRU_arm()
 
 int GRU_arm::create_pipeline(const Option& opt)
 {
+#if NCNN_INT8
+    if (int8_scale_term)
+    {
+        return create_pipeline_int8(opt);
+    }
+#endif
+
 #if NCNN_ARM82
     if (support_fp16_storage && opt.use_fp16_storage)
     {
@@ -54,13 +65,6 @@ int GRU_arm::create_pipeline(const Option& opt)
     }
 #endif
 
-#if NCNN_INT8
-    if (int8_scale_term)
-    {
-        return create_pipeline_int8(opt);
-    }
-#endif
-
     // pack RUN
     const int num_directions = direction == 2 ? 2 : 1;
     const int size = weight_data_size / num_directions / num_output / 3;
@@ -630,8 +634,186 @@ static int gru(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we
     return 0;
 }
 
+int GRU_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
+{
+#if NCNN_INT8
+    if (int8_scale_term)
+    {
+        return forward_int8(bottom_blob, top_blob, opt);
+    }
+#endif
+
+    int elembits = bottom_blob.elembits();
+
+#if NCNN_ARM82
+    if (support_fp16_storage && opt.use_fp16_storage && elembits == 16)
+        return forward_fp16s(bottom_blob, top_blob, opt);
+#endif
+
+#if NCNN_BF16
+    if (opt.use_bf16_storage && elembits == 16)
+        return forward_bf16s(bottom_blob, top_blob, opt);
+#endif
+
+    int T = bottom_blob.h;
+
+    int num_directions = direction == 2 ? 2 : 1;
+
+    // initial hidden state
+    Mat hidden(num_output, 4u, opt.workspace_allocator);
+    if (hidden.empty())
+        return -100;
+    hidden.fill(0.f);
+
+    top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
+    if (top_blob.empty())
+        return -100;
+
+    // Uni directional
+    if (direction == 0 || direction == 1)
+    {
+        int ret = gru(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
+        if (ret != 0)
+            return ret;
+    }
+
+    if (direction == 2)
+    {
+        Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
+        if (top_blob_forward.empty())
+            return -100;
+
+        Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
+        if (top_blob_reverse.empty())
+            return -100;
+
+        {
+            int ret = gru(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
+            if (ret != 0)
+                return ret;
+        }
+
+        hidden.fill(0.0f);
+
+        {
+            int ret = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt);
+            if (ret != 0)
+                return ret;
+        }
+
+        // concat w
+        for (int i = 0; i < T; i++)
+        {
+            const float* pf = top_blob_forward.row(i);
+            const float* pr = top_blob_reverse.row(i);
+            float* ptr = top_blob.row(i);
+
+            memcpy(ptr, pf, num_output * sizeof(float));
+            memcpy(ptr + num_output, pr, num_output * sizeof(float));
+        }
+    }
+
+    return 0;
+}
+
+int GRU_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
+{
 #if NCNN_INT8
-static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, Mat& hidden_state, const Option& opt)
+    if (int8_scale_term)
+    {
+        return forward_int8(bottom_blobs, top_blobs, opt);
+    }
+#endif
+
+    const Mat& bottom_blob =
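// Note (illustrative sketch, not part of the patch): every forward variant in this file
// implements the same per-output gate math; a minimal fp32 reference of one GRU step for
// output element o, with biases in the packed R, U, BN, WN order used above, looks like:
#include <cmath>

static float gru_cell_ref(const float* x, int size, const float* h, int num_output, int o,
                          const float* W_xr, const float* W_xu, const float* W_xn,
                          const float* W_hr, const float* W_hu, const float* W_hn,
                          const float* bias_rubnwn)
{
    float R = bias_rubnwn[0];
    float U = bias_rubnwn[1];
    for (int i = 0; i < size; i++)
    {
        R += W_xr[i] * x[i];
        U += W_xu[i] * x[i];
    }
    for (int i = 0; i < num_output; i++)
    {
        R += W_hr[i] * h[i];
        U += W_hu[i] * h[i];
    }
    R = 1.f / (1.f + expf(-R)); // sigmoid
    U = 1.f / (1.f + expf(-U)); // sigmoid

    float N = bias_rubnwn[2]; // hidden-side bias
    for (int i = 0; i < num_output; i++)
        N += W_hn[i] * h[i];
    N = bias_rubnwn[3] + R * N; // reset gate scales the hidden contribution
    for (int i = 0; i < size; i++)
        N += W_xn[i] * x[i];
    N = tanhf(N);

    return (1.f - U) * N + U * h[o]; // h_t = (1 - update) .* new + update .* h_{t-1}
}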
bottom_blobs[0]; + int elembits = bottom_blob.elembits(); + +#if NCNN_ARM82 + if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) + return forward_fp16s(bottom_blobs, top_blobs, opt); +#endif + +#if NCNN_BF16 + if (opt.use_bf16_storage && elembits == 16) + return forward_bf16s(bottom_blobs, top_blobs, opt); +#endif + + int T = bottom_blob.h; + int num_directions = direction == 2 ? 2 : 1; + + Mat hidden; + Allocator* hidden_allocator = top_blobs.size() == 2 ? opt.blob_allocator : opt.workspace_allocator; + if (bottom_blobs.size() == 2) + { + hidden = bottom_blobs[1].clone(hidden_allocator); + } + else + { + hidden.create(num_output, num_directions, 4u, hidden_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + } + + Mat& top_blob = top_blobs[0]; + top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + // Uni directional + if (direction == 0 || direction == 1) + { + int ret = gru(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } + + if (direction == 2) + { + Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + Mat hidden0 = hidden.row_range(0, 1); + { + int ret = gru(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, opt); + if (ret != 0) + return ret; + } + + Mat hidden1 = hidden.row_range(1, 1); + { + int ret = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, opt); + if (ret != 0) + return ret; + } + + // concat w + for (int i = 0; i < T; i++) + { + const float* pf = top_blob_forward.row(i); + const float* pr = top_blob_reverse.row(i); + float* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(float)); + memcpy(ptr + num_output, pr, num_output * sizeof(float)); + } + } + + if (top_blobs.size() == 2) + { + top_blobs[1] = hidden; + } + + return 0; +} + +#if NCNN_BF16 +static int gru_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt) { int size = bottom_blob.w; int T = bottom_blob.h; @@ -662,24 +844,16 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma { int q = qq * 4; - const float* x = bottom_blob.row(ti); + const unsigned short* x = bottom_blob.row(ti); // gate reset update - const float* bias_c_RUBNWN = (const float*)bias_c + q * 4; - - const signed char* weight_xc_int8_RUN = weight_xc_int8.row(q / 4); - const signed char* weight_hc_int8_RUN = weight_hc_int8.row(q / 4); - - const float* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q / 4); - const float* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q / 4); + const unsigned short* bias_c_RUBNWN = (const unsigned short*)bias_c + q * 4; - float32x4_t _descale_xc_R = vld1q_f32(weight_xc_int8_descales_RUN); - float32x4_t _descale_xc_U = vld1q_f32(weight_xc_int8_descales_RUN + 4); - float32x4_t _descale_hc_R = vld1q_f32(weight_hc_int8_descales_RUN); - float32x4_t _descale_hc_U = vld1q_f32(weight_hc_int8_descales_RUN + 4); + const unsigned short* weight_xc_RUN = weight_xc.row(q / 4); 
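// Note (illustrative sketch, not part of the patch): the bf16 kernel keeps activations and
// weights in bfloat16 storage and widens to fp32 with bfloat2float before each MAC.
// Conceptually bf16 is the top 16 bits of an IEEE-754 float; ncnn's own helpers may round
// and use NEON, so this truncating scalar version is only a sketch of the idea:
#include <cstring>

static inline unsigned short f32_to_bf16_trunc(float v)
{
    unsigned int bits;
    memcpy(&bits, &v, 4);
    return (unsigned short)(bits >> 16); // keep sign, exponent, top 7 mantissa bits
}

static inline float bf16_to_f32(unsigned short v)
{
    unsigned int bits = (unsigned int)v << 16;
    float out;
    memcpy(&out, &bits, 4);
    return out;
}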
+ const unsigned short* weight_hc_RUN = weight_hc.row(q / 4); - float32x4_t _gru_R = vld1q_f32(bias_c_RUBNWN); - float32x4_t _gru_U = vld1q_f32(bias_c_RUBNWN + 4); + float32x4_t _gru_R = bfloat2float(vld1_u16(bias_c_RUBNWN)); + float32x4_t _gru_U = bfloat2float(vld1_u16(bias_c_RUBNWN + 4)); float32x4_t _sum1 = vdupq_n_f32(0.f); float32x4_t _sum2 = vdupq_n_f32(0.f); float32x4_t _sum3 = vdupq_n_f32(0.f); @@ -690,25 +864,15 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma int i = 0; for (; i + 3 < size; i += 4) { - float32x4_t _xi = vld1q_f32(x + i); - - int8x16_t _weight_xc_RU01 = vld1q_s8(weight_xc_int8_RUN); - int8x16_t _weight_xc_RU23 = vld1q_s8(weight_xc_int8_RUN + 16); - - int16x8_t _weight_xc_RU0 = vmovl_s8(vget_low_s8(_weight_xc_RU01)); - int16x8_t _weight_xc_RU1 = vmovl_s8(vget_high_s8(_weight_xc_RU01)); - int16x8_t _weight_xc_RU2 = vmovl_s8(vget_low_s8(_weight_xc_RU23)); - int16x8_t _weight_xc_RU3 = vmovl_s8(vget_high_s8(_weight_xc_RU23)); - - float32x4_t _weight_xc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU0))), _descale_xc_R); - float32x4_t _weight_xc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU0))), _descale_xc_U); - float32x4_t _weight_xc_R_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU1))), _descale_xc_R); - float32x4_t _weight_xc_U_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU1))), _descale_xc_U); - float32x4_t _weight_xc_R_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU2))), _descale_xc_R); - float32x4_t _weight_xc_U_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU2))), _descale_xc_U); - float32x4_t _weight_xc_R_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU3))), _descale_xc_R); - float32x4_t _weight_xc_U_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU3))), _descale_xc_U); - + float32x4_t _xi = bfloat2float(vld1_u16(x + i)); + float32x4_t _weight_xc_R = bfloat2float(vld1_u16(weight_xc_RUN)); + float32x4_t _weight_xc_U = bfloat2float(vld1_u16(weight_xc_RUN + 4)); + float32x4_t _weight_xc_R_1 = bfloat2float(vld1_u16(weight_xc_RUN + 8)); + float32x4_t _weight_xc_U_1 = bfloat2float(vld1_u16(weight_xc_RUN + 12)); + float32x4_t _weight_xc_R_2 = bfloat2float(vld1_u16(weight_xc_RUN + 16)); + float32x4_t _weight_xc_U_2 = bfloat2float(vld1_u16(weight_xc_RUN + 20)); + float32x4_t _weight_xc_R_3 = bfloat2float(vld1_u16(weight_xc_RUN + 24)); + float32x4_t _weight_xc_U_3 = bfloat2float(vld1_u16(weight_xc_RUN + 28)); #if __aarch64__ _gru_R = vfmaq_laneq_f32(_gru_R, _weight_xc_R, _xi, 0); _gru_U = vfmaq_laneq_f32(_gru_U, _weight_xc_U, _xi, 0); @@ -729,46 +893,33 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma _sum6 = vmlaq_lane_f32(_sum6, _weight_xc_U_3, vget_high_f32(_xi), 1); #endif - weight_xc_int8_RUN += 32; + weight_xc_RUN += 32; } for (; i < size; i++) { - float xi = x[i]; - - float32x4_t _xi = vdupq_n_f32(xi); - - int16x8_t _weight_xc_RU = vmovl_s8(vld1_s8(weight_xc_int8_RUN)); - float32x4_t _weight_xc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU))), _descale_xc_R); - float32x4_t _weight_xc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU))), _descale_xc_U); + unsigned short xi = x[i]; + float32x4_t _xi = bfloat2float(vdup_n_u16(xi)); + float32x4_t _weight_xc_R = bfloat2float(vld1_u16(weight_xc_RUN)); + float32x4_t _weight_xc_U = bfloat2float(vld1_u16(weight_xc_RUN + 4)); _gru_R = vmlaq_f32(_gru_R, _weight_xc_R, _xi); _gru_U = 
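// Note (illustrative sketch, not part of the patch): the unrolled loops above use the
// lane-broadcast multiply-accumulate idiom: load four inputs once, then fan each lane out
// across a packed weight column. vfmaq_laneq_f32 is AArch64-only, hence the vmlaq_lane_f32
// fallback. The bare pattern for one 4x4 tile:
#if __ARM_NEON
#include <arm_neon.h>

// acc += W * x for a 4x4 tile stored as 4 consecutive columns of 4 floats.
static inline float32x4_t tile4x4_mla(float32x4_t acc, const float* w, float32x4_t x)
{
    float32x4_t c0 = vld1q_f32(w);
    float32x4_t c1 = vld1q_f32(w + 4);
    float32x4_t c2 = vld1q_f32(w + 8);
    float32x4_t c3 = vld1q_f32(w + 12);
#if __aarch64__
    acc = vfmaq_laneq_f32(acc, c0, x, 0);
    acc = vfmaq_laneq_f32(acc, c1, x, 1);
    acc = vfmaq_laneq_f32(acc, c2, x, 2);
    acc = vfmaq_laneq_f32(acc, c3, x, 3);
#else
    acc = vmlaq_lane_f32(acc, c0, vget_low_f32(x), 0);
    acc = vmlaq_lane_f32(acc, c1, vget_low_f32(x), 1);
    acc = vmlaq_lane_f32(acc, c2, vget_high_f32(x), 0);
    acc = vmlaq_lane_f32(acc, c3, vget_high_f32(x), 1);
#endif
    return acc;
}
#endif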
vmlaq_f32(_gru_U, _weight_xc_U, _xi); - weight_xc_int8_RUN += 8; + weight_xc_RUN += 8; } i = 0; for (; i + 3 < num_output; i += 4) { float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); - - int8x16_t _weight_hc_RU01 = vld1q_s8(weight_hc_int8_RUN); - int8x16_t _weight_hc_RU23 = vld1q_s8(weight_hc_int8_RUN + 16); - - int16x8_t _weight_hc_RU0 = vmovl_s8(vget_low_s8(_weight_hc_RU01)); - int16x8_t _weight_hc_RU1 = vmovl_s8(vget_high_s8(_weight_hc_RU01)); - int16x8_t _weight_hc_RU2 = vmovl_s8(vget_low_s8(_weight_hc_RU23)); - int16x8_t _weight_hc_RU3 = vmovl_s8(vget_high_s8(_weight_hc_RU23)); - - float32x4_t _weight_hc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU0))), _descale_hc_R); - float32x4_t _weight_hc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU0))), _descale_hc_U); - float32x4_t _weight_hc_R_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU1))), _descale_hc_R); - float32x4_t _weight_hc_U_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU1))), _descale_hc_U); - float32x4_t _weight_hc_R_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU2))), _descale_hc_R); - float32x4_t _weight_hc_U_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU2))), _descale_hc_U); - float32x4_t _weight_hc_R_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU3))), _descale_hc_R); - float32x4_t _weight_hc_U_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU3))), _descale_hc_U); - + float32x4_t _weight_hc_R = bfloat2float(vld1_u16(weight_hc_RUN)); + float32x4_t _weight_hc_U = bfloat2float(vld1_u16(weight_hc_RUN + 4)); + float32x4_t _weight_hc_R_1 = bfloat2float(vld1_u16(weight_hc_RUN + 8)); + float32x4_t _weight_hc_U_1 = bfloat2float(vld1_u16(weight_hc_RUN + 12)); + float32x4_t _weight_hc_R_2 = bfloat2float(vld1_u16(weight_hc_RUN + 16)); + float32x4_t _weight_hc_U_2 = bfloat2float(vld1_u16(weight_hc_RUN + 20)); + float32x4_t _weight_hc_R_3 = bfloat2float(vld1_u16(weight_hc_RUN + 24)); + float32x4_t _weight_hc_U_3 = bfloat2float(vld1_u16(weight_hc_RUN + 28)); #if __aarch64__ _gru_R = vfmaq_laneq_f32(_gru_R, _weight_hc_R, _h_cont, 0); _gru_U = vfmaq_laneq_f32(_gru_U, _weight_hc_U, _h_cont, 0); @@ -789,22 +940,19 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma _sum6 = vmlaq_lane_f32(_sum6, _weight_hc_U_3, vget_high_f32(_h_cont), 1); #endif - weight_hc_int8_RUN += 32; + weight_hc_RUN += 32; } for (; i < num_output; i++) { float h_cont = hidden_state[i]; float32x4_t _h_cont = vdupq_n_f32(h_cont); - - int16x8_t _weight_hc_RU = vmovl_s8(vld1_s8(weight_hc_int8_RUN)); - float32x4_t _weight_hc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU))), _descale_hc_R); - float32x4_t _weight_hc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU))), _descale_hc_U); - + float32x4_t _weight_hc_R = bfloat2float(vld1_u16(weight_hc_RUN)); + float32x4_t _weight_hc_U = bfloat2float(vld1_u16(weight_hc_RUN + 4)); _gru_R = vmlaq_f32(_gru_R, _weight_hc_R, _h_cont); _gru_U = vmlaq_f32(_gru_U, _weight_hc_U, _h_cont); - weight_hc_int8_RUN += 8; + weight_hc_RUN += 8; } _gru_R = vaddq_f32(_gru_R, _sum1); @@ -820,27 +968,19 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma _gru_U = sigmoid_ps(_gru_U); // gate new - float32x4_t _gru_N = vld1q_f32(bias_c_RUBNWN + 8); + float32x4_t _gru_N = bfloat2float(vld1_u16(bias_c_RUBNWN + 8)); _sum1 = vdupq_n_f32(0.f); _sum2 = vdupq_n_f32(0.f); _sum3 = vdupq_n_f32(0.f); - float32x4_t 
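// Note (illustrative sketch, not part of the patch): the extra _sum1.._sum6 registers are
// independent accumulators; they break the dependency chain on _gru_R/_gru_U so the FMAs
// can overlap, and are folded back in only after the loop. The same idea in scalar form:
static float dot_ilp4(const float* a, const float* b, int n)
{
    float s0 = 0.f, s1 = 0.f, s2 = 0.f, s3 = 0.f;
    int i = 0;
    for (; i + 3 < n; i += 4)
    {
        s0 += a[i + 0] * b[i + 0];
        s1 += a[i + 1] * b[i + 1];
        s2 += a[i + 2] * b[i + 2];
        s3 += a[i + 3] * b[i + 3];
    }
    for (; i < n; i++)
        s0 += a[i] * b[i];
    return (s0 + s1) + (s2 + s3); // reduce the partial sums once at the end
}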
_descale_xc_N = vld1q_f32(weight_xc_int8_descales_RUN + 8); - float32x4_t _descale_hc_N = vld1q_f32(weight_hc_int8_descales_RUN + 8); - i = 0; for (; i + 3 < num_output; i += 4) { float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); - - int8x16_t _weight_hc_N0123 = vld1q_s8(weight_hc_int8_RUN); - int16x8_t _weight_hc_N01 = vmovl_s8(vget_low_s8(_weight_hc_N0123)); - int16x8_t _weight_hc_N23 = vmovl_s8(vget_high_s8(_weight_hc_N0123)); - float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_N01))), _descale_hc_N); - float32x4_t _weight_hc_N_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_N01))), _descale_hc_N); - float32x4_t _weight_hc_N_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_N23))), _descale_hc_N); - float32x4_t _weight_hc_N_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_N23))), _descale_hc_N); - + float32x4_t _weight_hc_N = bfloat2float(vld1_u16(weight_hc_RUN)); + float32x4_t _weight_hc_N_1 = bfloat2float(vld1_u16(weight_hc_RUN + 4)); + float32x4_t _weight_hc_N_2 = bfloat2float(vld1_u16(weight_hc_RUN + 8)); + float32x4_t _weight_hc_N_3 = bfloat2float(vld1_u16(weight_hc_RUN + 12)); #if __aarch64__ _gru_N = vfmaq_laneq_f32(_gru_N, _weight_hc_N, _h_cont, 0); _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_N_1, _h_cont, 1); @@ -853,25 +993,24 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_N_3, vget_high_f32(_h_cont), 1); #endif - weight_hc_int8_RUN += 16; + weight_hc_RUN += 16; } for (; i < num_output; i++) { float h_cont = hidden_state[i]; float32x4_t _h_cont = vdupq_n_f32(h_cont); - int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_RUN)[0])); - float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_hc_N); + float32x4_t _weight_hc_N = bfloat2float(vld1_u16(weight_hc_RUN)); _gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont); - weight_hc_int8_RUN += 4; + weight_hc_RUN += 4; } _gru_N = vaddq_f32(_gru_N, _sum1); _sum2 = vaddq_f32(_sum2, _sum3); _gru_N = vaddq_f32(_gru_N, _sum2); - _gru_N = vmlaq_f32(vld1q_f32(bias_c_RUBNWN + 12), _gru_R, _gru_N); + _gru_N = vmlaq_f32(bfloat2float(vld1_u16(bias_c_RUBNWN + 12)), _gru_R, _gru_N); _sum1 = vdupq_n_f32(0.f); _sum2 = vdupq_n_f32(0.f); _sum3 = vdupq_n_f32(0.f); @@ -879,16 +1018,11 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma i = 0; for (; i + 3 < size; i += 4) { - float32x4_t _xi = vld1q_f32(x + i); - - int8x16_t _weight_xc_N0123 = vld1q_s8(weight_xc_int8_RUN); - int16x8_t _weight_xc_N01 = vmovl_s8(vget_low_s8(_weight_xc_N0123)); - int16x8_t _weight_xc_N23 = vmovl_s8(vget_high_s8(_weight_xc_N0123)); - float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_N01))), _descale_xc_N); - float32x4_t _weight_xc_N_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_N01))), _descale_xc_N); - float32x4_t _weight_xc_N_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_N23))), _descale_xc_N); - float32x4_t _weight_xc_N_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_N23))), _descale_xc_N); - + float32x4_t _xi = bfloat2float(vld1_u16(x + i)); + float32x4_t _weight_xc_N = bfloat2float(vld1_u16(weight_xc_RUN)); + float32x4_t _weight_xc_N_1 = bfloat2float(vld1_u16(weight_xc_RUN + 4)); + float32x4_t _weight_xc_N_2 = bfloat2float(vld1_u16(weight_xc_RUN + 8)); + float32x4_t _weight_xc_N_3 = bfloat2float(vld1_u16(weight_xc_RUN + 12)); #if 
__aarch64__ _gru_N = vfmaq_laneq_f32(_gru_N, _weight_xc_N, _xi, 0); _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_N_1, _xi, 1); @@ -901,18 +1035,17 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_N_3, vget_high_f32(_xi), 1); #endif - weight_xc_int8_RUN += 16; + weight_xc_RUN += 16; } for (; i < size; i++) { - float xi = x[i]; + unsigned short xi = x[i]; - float32x4_t _xi = vdupq_n_f32(xi); - int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_RUN)[0])); - float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_xc_N); + float32x4_t _xi = bfloat2float(vdup_n_u16(xi)); + float32x4_t _weight_xc_N = bfloat2float(vld1_u16(weight_xc_RUN)); _gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi); - weight_xc_int8_RUN += 4; + weight_xc_RUN += 4; } _gru_N = vaddq_f32(_gru_N, _sum1); @@ -931,52 +1064,40 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma #pragma omp parallel for num_threads(opt.num_threads) for (int q = remain_num_output_start; q < num_output; q++) { - const float* x = bottom_blob.row(ti); + const unsigned short* x = bottom_blob.row(ti); // gate reset update - const float* bias_c_RUBNWN = (const float*)bias_c + q * 4; + const unsigned short* bias_c_RUBNWN = (const unsigned short*)bias_c + q * 4; #if __ARM_NEON - const signed char* weight_xc_int8_RUN = weight_xc_int8.row(q / 4 + q % 4); - const signed char* weight_hc_int8_RUN = weight_hc_int8.row(q / 4 + q % 4); - const float* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q / 4 + q % 4); - const float* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q / 4 + q % 4); + const unsigned short* weight_xc_RUN = weight_xc.row(q / 4 + q % 4); + const unsigned short* weight_hc_RUN = weight_hc.row(q / 4 + q % 4); #else - const signed char* weight_xc_int8_RUN = weight_xc_int8.row(q); - const signed char* weight_hc_int8_RUN = weight_hc_int8.row(q); - const float* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q); - const float* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q); + const unsigned short* weight_xc_RUN = weight_xc.row(q); + const unsigned short* weight_hc_RUN = weight_hc.row(q); #endif - const float descale_xc_R = weight_xc_int8_descales_RUN[0]; - const float descale_xc_U = weight_xc_int8_descales_RUN[1]; - const float descale_xc_N = weight_xc_int8_descales_RUN[2]; - - const float descale_hc_R = weight_hc_int8_descales_RUN[0]; - const float descale_hc_U = weight_hc_int8_descales_RUN[1]; - const float descale_hc_N = weight_hc_int8_descales_RUN[2]; - - float R = bias_c_RUBNWN[0]; - float U = bias_c_RUBNWN[1]; + float R = bfloat16_to_float32(bias_c_RUBNWN[0]); + float U = bfloat16_to_float32(bias_c_RUBNWN[1]); for (int i = 0; i < size; i++) { - float xi = x[i]; + float xi = bfloat16_to_float32(x[i]); - R += weight_xc_int8_RUN[0] * descale_xc_R * xi; - U += weight_xc_int8_RUN[1] * descale_xc_U * xi; + R += bfloat16_to_float32(weight_xc_RUN[0]) * xi; + U += bfloat16_to_float32(weight_xc_RUN[1]) * xi; - weight_xc_int8_RUN += 2; + weight_xc_RUN += 2; } for (int i = 0; i < num_output; i++) { float h_cont = hidden_state[i]; - R += weight_hc_int8_RUN[0] * descale_hc_R * h_cont; - U += weight_hc_int8_RUN[1] * descale_hc_U * h_cont; + R += bfloat16_to_float32(weight_hc_RUN[0]) * h_cont; + U += bfloat16_to_float32(weight_hc_RUN[1]) * h_cont; - weight_hc_int8_RUN += 2; + weight_hc_RUN += 2; } // sigmoid(R) @@ -985,26 +1106,26 @@ static int 
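// Note (illustrative helper, not part of the patch): in the leftover-output path above,
// `q / 4 + q % 4` picks the packed row for outputs that did not fit a 4-wide block: the
// first num_output / 4 rows hold the blocks, and each remaining output gets one row after them.
static int packed_run_row_for_tail(int q)
{
    // valid for q >= (num_output / 4) * 4:
    // q / 4 == number of full 4-wide blocks, q % 4 == index within the tail rows
    return q / 4 + q % 4;
}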
gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma U = 1.f / (1.f + expf(-U)); // gate new - float N = bias_c_RUBNWN[2]; + float N = bfloat16_to_float32(bias_c_RUBNWN[2]); for (int i = 0; i < num_output; i++) { float h_cont = hidden_state[i]; - N += weight_hc_int8_RUN[0] * descale_hc_N * h_cont; + N += bfloat16_to_float32(weight_hc_RUN[0]) * h_cont; - weight_hc_int8_RUN += 1; + weight_hc_RUN += 1; } - N = bias_c_RUBNWN[3] + R * N; + N = bfloat16_to_float32(bias_c_RUBNWN[3]) + R * N; for (int i = 0; i < size; i++) { - float xi = x[i]; + float xi = bfloat16_to_float32(x[i]); - N += weight_xc_int8_RUN[0] * descale_xc_N * xi; + N += bfloat16_to_float32(weight_xc_RUN[0]) * xi; - weight_xc_int8_RUN += 1; + weight_xc_RUN += 1; } // tanh(N) @@ -1021,7 +1142,7 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma } // h_t := (1 - update) .* new + update .* h_{t-1} - float* output_data = top_blob.row(ti); + unsigned short* output_data = top_blob.row(ti); float* hidden_ptr = hidden_state; @@ -1042,7 +1163,7 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma float32x4_t _gru_H = vaddq_f32(vmulq_f32(vsubq_f32(vdupq_n_f32(1.f), _gru_U), _gru_N), vmulq_f32(_gru_U, vld1q_f32(hidden_ptr + q))); vst1q_f32(hidden_ptr + q, _gru_H); - vst1q_f32(output_data + q, _gru_H); + vst1_u16(output_data + q, float2bfloat(_gru_H)); } #endif // __ARM_NEON #pragma omp parallel for num_threads(opt.num_threads) @@ -1060,31 +1181,27 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma float H = (1 - U) * N + U * hidden_ptr[q]; hidden_ptr[q] = H; - output_data[q] = H; + output_data[q] = float32_to_bfloat16(H); } } return 0; } -int GRU_arm::create_pipeline_int8(const Option& opt) +int GRU_arm::create_pipeline_bf16s(const Option& opt) { // pack RUN - const int num_directions = direction == 2 ? 2 : 1; - const int size = weight_data_size / num_directions / num_output / 3; + int num_directions = direction == 2 ? 
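// Note (rough sketch, not part of the patch): create_pipeline_bf16s shrinks the packed
// buffers to bf16 storage: weight elemsize drops from 4 to 2 bytes, and the bias Mat uses
// elemsize 8u with elempack 4, i.e. one element is the four bf16 biases (R, U, BN, WN) of
// one output. Ignoring ncnn's per-channel alignment, the bias payload is roughly:
#include <cstddef>

static size_t packed_bias_payload_bytes(int num_output, int num_directions)
{
    const size_t elemsize = 8; // elempack 4 * sizeof(bf16)
    return (size_t)num_output * (size_t)num_directions * elemsize;
}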
2 : 1; + int size = weight_data_size / num_directions / num_output / 3; #if __ARM_NEON - weight_xc_data_packed.create(size * 12, num_output / 4 + num_output % 4, num_directions, 1u, 1); - bias_c_data_packed.create(num_output, 1, num_directions, 16u, 4); - weight_hc_data_packed.create(num_output * 12, num_output / 4 + num_output % 4, num_directions, 1u, 1); - weight_xc_data_int8_descales_packed.create(12, num_output / 4 + num_output % 4, num_directions); - weight_hc_data_int8_descales_packed.create(12, num_output / 4 + num_output % 4, num_directions); + weight_xc_data_packed.create(size * 12, num_output / 4 + num_output % 4, num_directions, 2u, 1); + bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4); + weight_hc_data_packed.create(num_output * 12, num_output / 4 + num_output % 4, num_directions, 2u, 1); #else - weight_xc_data_packed.create(size * 3, num_output, num_directions, 1u, 1); - bias_c_data_packed.create(num_output, 1, num_directions, 16u, 4); - weight_hc_data_packed.create(num_output * 3, num_output, num_directions, 1u, 1); - weight_xc_data_int8_descales_packed.create(3, num_output, num_directions); - weight_hc_data_int8_descales_packed.create(3, num_output, num_directions); + weight_xc_data_packed.create(size * 3, num_output, num_directions, 2u, 1); + bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4); + weight_hc_data_packed.create(num_output * 3, num_output, num_directions, 2u, 1); #endif #pragma omp parallel for num_threads(opt.num_threads) @@ -1093,223 +1210,167 @@ int GRU_arm::create_pipeline_int8(const Option& opt) const Mat weight_xc = weight_xc_data.channel(dr); const Mat bias_c = bias_c_data.channel(dr); const Mat weight_hc = weight_hc_data.channel(dr); - const float* weight_xc_int8_scales = weight_xc_data_int8_scales.row(dr); - const float* weight_hc_int8_scales = weight_hc_data_int8_scales.row(dr); Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr); Mat bias_c_data_packed_dr = bias_c_data_packed.channel(dr); Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr); - Mat weight_xc_data_int8_descales_packed_dr = weight_xc_data_int8_descales_packed.channel(dr); - Mat weight_hc_data_int8_descales_packed_dr = weight_hc_data_int8_descales_packed.channel(dr); const float* bias_c_R = bias_c.row(0); const float* bias_c_U = bias_c.row(1); const float* bias_c_WN = bias_c.row(2); const float* bias_c_BN = bias_c.row(3); - float* bias_c_RUBNWN = bias_c_data_packed_dr.row(0); + unsigned short* bias_c_RUBNWN = bias_c_data_packed_dr.row(0); int q = 0; #if __ARM_NEON for (; q + 3 < num_output; q += 4) { - vst1q_f32(bias_c_RUBNWN, vld1q_f32(bias_c_R + q)); - vst1q_f32(bias_c_RUBNWN + 4, vld1q_f32(bias_c_U + q)); - vst1q_f32(bias_c_RUBNWN + 8, vld1q_f32(bias_c_BN + q)); - vst1q_f32(bias_c_RUBNWN + 12, vld1q_f32(bias_c_WN + q)); + vst1_u16(bias_c_RUBNWN, float2bfloat(vld1q_f32(bias_c_R + q))); + vst1_u16(bias_c_RUBNWN + 4, float2bfloat(vld1q_f32(bias_c_U + q))); + vst1_u16(bias_c_RUBNWN + 8, float2bfloat(vld1q_f32(bias_c_BN + q))); + vst1_u16(bias_c_RUBNWN + 12, float2bfloat(vld1q_f32(bias_c_WN + q))); bias_c_RUBNWN += 16; - const signed char* weight_xc_R = weight_xc.row(num_output * 0 + q); - const signed char* weight_xc_U = weight_xc.row(num_output * 1 + q); - const signed char* weight_xc_N = weight_xc.row(num_output * 2 + q); + const float* weight_xc_R = weight_xc.row(num_output * 0 + q); + const float* weight_xc_U = weight_xc.row(num_output * 1 + q); + const float* weight_xc_N = weight_xc.row(num_output * 2 + q); - const signed char* 
weight_xc_R_1 = weight_xc.row(num_output * 0 + q + 1); - const signed char* weight_xc_U_1 = weight_xc.row(num_output * 1 + q + 1); - const signed char* weight_xc_N_1 = weight_xc.row(num_output * 2 + q + 1); + const float* weight_xc_R_1 = weight_xc.row(num_output * 0 + q + 1); + const float* weight_xc_U_1 = weight_xc.row(num_output * 1 + q + 1); + const float* weight_xc_N_1 = weight_xc.row(num_output * 2 + q + 1); - const signed char* weight_xc_R_2 = weight_xc.row(num_output * 0 + q + 2); - const signed char* weight_xc_U_2 = weight_xc.row(num_output * 1 + q + 2); - const signed char* weight_xc_N_2 = weight_xc.row(num_output * 2 + q + 2); + const float* weight_xc_R_2 = weight_xc.row(num_output * 0 + q + 2); + const float* weight_xc_U_2 = weight_xc.row(num_output * 1 + q + 2); + const float* weight_xc_N_2 = weight_xc.row(num_output * 2 + q + 2); - const signed char* weight_xc_R_3 = weight_xc.row(num_output * 0 + q + 3); - const signed char* weight_xc_U_3 = weight_xc.row(num_output * 1 + q + 3); - const signed char* weight_xc_N_3 = weight_xc.row(num_output * 2 + q + 3); + const float* weight_xc_R_3 = weight_xc.row(num_output * 0 + q + 3); + const float* weight_xc_U_3 = weight_xc.row(num_output * 1 + q + 3); + const float* weight_xc_N_3 = weight_xc.row(num_output * 2 + q + 3); - const signed char* weight_hc_R = weight_hc.row(num_output * 0 + q); - const signed char* weight_hc_U = weight_hc.row(num_output * 1 + q); - const signed char* weight_hc_N = weight_hc.row(num_output * 2 + q); + const float* weight_hc_R = weight_hc.row(num_output * 0 + q); + const float* weight_hc_U = weight_hc.row(num_output * 1 + q); + const float* weight_hc_N = weight_hc.row(num_output * 2 + q); - const signed char* weight_hc_R_1 = weight_hc.row(num_output * 0 + q + 1); - const signed char* weight_hc_U_1 = weight_hc.row(num_output * 1 + q + 1); - const signed char* weight_hc_N_1 = weight_hc.row(num_output * 2 + q + 1); + const float* weight_hc_R_1 = weight_hc.row(num_output * 0 + q + 1); + const float* weight_hc_U_1 = weight_hc.row(num_output * 1 + q + 1); + const float* weight_hc_N_1 = weight_hc.row(num_output * 2 + q + 1); - const signed char* weight_hc_R_2 = weight_hc.row(num_output * 0 + q + 2); - const signed char* weight_hc_U_2 = weight_hc.row(num_output * 1 + q + 2); - const signed char* weight_hc_N_2 = weight_hc.row(num_output * 2 + q + 2); + const float* weight_hc_R_2 = weight_hc.row(num_output * 0 + q + 2); + const float* weight_hc_U_2 = weight_hc.row(num_output * 1 + q + 2); + const float* weight_hc_N_2 = weight_hc.row(num_output * 2 + q + 2); - const signed char* weight_hc_R_3 = weight_hc.row(num_output * 0 + q + 3); - const signed char* weight_hc_U_3 = weight_hc.row(num_output * 1 + q + 3); - const signed char* weight_hc_N_3 = weight_hc.row(num_output * 2 + q + 3); + const float* weight_hc_R_3 = weight_hc.row(num_output * 0 + q + 3); + const float* weight_hc_U_3 = weight_hc.row(num_output * 1 + q + 3); + const float* weight_hc_N_3 = weight_hc.row(num_output * 2 + q + 3); - signed char* weight_xc_RUN = weight_xc_data_packed_dr.row(q / 4); - signed char* weight_hc_RUN = weight_hc_data_packed_dr.row(q / 4); - float* weight_xc_int8_descales_RUN = weight_xc_data_int8_descales_packed_dr.row(q / 4); - float* weight_hc_int8_descales_RUN = weight_hc_data_int8_descales_packed_dr.row(q / 4); + unsigned short* weight_xc_RUN = weight_xc_data_packed_dr.row(q / 4); + unsigned short* weight_hc_RUN = weight_hc_data_packed_dr.row(q / 4); for (int i = 0; i < size; i++) { - weight_xc_RUN[0] = weight_xc_R[i]; - 
weight_xc_RUN[1] = weight_xc_R_1[i]; - weight_xc_RUN[2] = weight_xc_R_2[i]; - weight_xc_RUN[3] = weight_xc_R_3[i]; - weight_xc_RUN[4] = weight_xc_U[i]; - weight_xc_RUN[5] = weight_xc_U_1[i]; - weight_xc_RUN[6] = weight_xc_U_2[i]; - weight_xc_RUN[7] = weight_xc_U_3[i]; + weight_xc_RUN[0] = float32_to_bfloat16(weight_xc_R[i]); + weight_xc_RUN[1] = float32_to_bfloat16(weight_xc_R_1[i]); + weight_xc_RUN[2] = float32_to_bfloat16(weight_xc_R_2[i]); + weight_xc_RUN[3] = float32_to_bfloat16(weight_xc_R_3[i]); + weight_xc_RUN[4] = float32_to_bfloat16(weight_xc_U[i]); + weight_xc_RUN[5] = float32_to_bfloat16(weight_xc_U_1[i]); + weight_xc_RUN[6] = float32_to_bfloat16(weight_xc_U_2[i]); + weight_xc_RUN[7] = float32_to_bfloat16(weight_xc_U_3[i]); weight_xc_RUN += 8; } for (int i = 0; i < num_output; i++) { - weight_hc_RUN[0] = weight_hc_R[i]; - weight_hc_RUN[1] = weight_hc_R_1[i]; - weight_hc_RUN[2] = weight_hc_R_2[i]; - weight_hc_RUN[3] = weight_hc_R_3[i]; - weight_hc_RUN[4] = weight_hc_U[i]; - weight_hc_RUN[5] = weight_hc_U_1[i]; - weight_hc_RUN[6] = weight_hc_U_2[i]; - weight_hc_RUN[7] = weight_hc_U_3[i]; + weight_hc_RUN[0] = float32_to_bfloat16(weight_hc_R[i]); + weight_hc_RUN[1] = float32_to_bfloat16(weight_hc_R_1[i]); + weight_hc_RUN[2] = float32_to_bfloat16(weight_hc_R_2[i]); + weight_hc_RUN[3] = float32_to_bfloat16(weight_hc_R_3[i]); + weight_hc_RUN[4] = float32_to_bfloat16(weight_hc_U[i]); + weight_hc_RUN[5] = float32_to_bfloat16(weight_hc_U_1[i]); + weight_hc_RUN[6] = float32_to_bfloat16(weight_hc_U_2[i]); + weight_hc_RUN[7] = float32_to_bfloat16(weight_hc_U_3[i]); weight_hc_RUN += 8; } for (int i = 0; i < size; i++) { - weight_xc_RUN[0] = weight_xc_N[i]; - weight_xc_RUN[1] = weight_xc_N_1[i]; - weight_xc_RUN[2] = weight_xc_N_2[i]; - weight_xc_RUN[3] = weight_xc_N_3[i]; + weight_xc_RUN[0] = float32_to_bfloat16(weight_xc_N[i]); + weight_xc_RUN[1] = float32_to_bfloat16(weight_xc_N_1[i]); + weight_xc_RUN[2] = float32_to_bfloat16(weight_xc_N_2[i]); + weight_xc_RUN[3] = float32_to_bfloat16(weight_xc_N_3[i]); weight_xc_RUN += 4; } for (int i = 0; i < num_output; i++) { - weight_hc_RUN[0] = weight_hc_N[i]; - weight_hc_RUN[1] = weight_hc_N_1[i]; - weight_hc_RUN[2] = weight_hc_N_2[i]; - weight_hc_RUN[3] = weight_hc_N_3[i]; + weight_hc_RUN[0] = float32_to_bfloat16(weight_hc_N[i]); + weight_hc_RUN[1] = float32_to_bfloat16(weight_hc_N_1[i]); + weight_hc_RUN[2] = float32_to_bfloat16(weight_hc_N_2[i]); + weight_hc_RUN[3] = float32_to_bfloat16(weight_hc_N_3[i]); weight_hc_RUN += 4; } - - float32x4_t _xc_R = vld1q_f32(weight_xc_int8_scales + q); - float32x4_t _xc_U = vld1q_f32(weight_xc_int8_scales + num_output + q); - float32x4_t _xc_N = vld1q_f32(weight_xc_int8_scales + num_output * 2 + q); - float32x4_t _hc_R = vld1q_f32(weight_hc_int8_scales + q); - float32x4_t _hc_U = vld1q_f32(weight_hc_int8_scales + num_output + q); - float32x4_t _hc_N = vld1q_f32(weight_hc_int8_scales + num_output * 2 + q); - -#if __aarch64__ - float32x4_t _one = vdupq_n_f32(1.f); - float32x4_t _reciprocal_xc_R = vdivq_f32(_one, _xc_R); - float32x4_t _reciprocal_xc_U = vdivq_f32(_one, _xc_U); - float32x4_t _reciprocal_xc_N = vdivq_f32(_one, _xc_N); - float32x4_t _reciprocal_hc_R = vdivq_f32(_one, _hc_R); - float32x4_t _reciprocal_hc_U = vdivq_f32(_one, _hc_U); - float32x4_t _reciprocal_hc_N = vdivq_f32(_one, _hc_N); -#else - float32x4_t _reciprocal_xc_R = vrecpeq_f32(_xc_R); - float32x4_t _reciprocal_xc_U = vrecpeq_f32(_xc_U); - float32x4_t _reciprocal_xc_N = vrecpeq_f32(_xc_N); - _reciprocal_xc_R = vmulq_f32(vrecpsq_f32(_xc_R, 
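// Note (illustrative sketch, not part of the patch): the packing loops above interleave four
// consecutive output rows so the kernel can stream them with a single pointer: for each input
// column i it stores R for outputs q..q+3, then U for q..q+3 (8 values), with the N weights
// packed in a separate block afterwards. The R/U interleave in isolation, fp32 for brevity:
static void pack_ru_block4(float* dst, const float* const R[4], const float* const U[4], int size)
{
    for (int i = 0; i < size; i++)
    {
        for (int k = 0; k < 4; k++)
            *dst++ = R[k][i]; // R[q..q+3][i]
        for (int k = 0; k < 4; k++)
            *dst++ = U[k][i]; // U[q..q+3][i]
    }
}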
_reciprocal_xc_R), _reciprocal_xc_R); - _reciprocal_xc_U = vmulq_f32(vrecpsq_f32(_xc_U, _reciprocal_xc_U), _reciprocal_xc_U); - _reciprocal_xc_N = vmulq_f32(vrecpsq_f32(_xc_N, _reciprocal_xc_N), _reciprocal_xc_N); - float32x4_t _reciprocal_hc_R = vrecpeq_f32(_hc_R); - float32x4_t _reciprocal_hc_U = vrecpeq_f32(_hc_U); - float32x4_t _reciprocal_hc_N = vrecpeq_f32(_hc_N); - _reciprocal_hc_R = vmulq_f32(vrecpsq_f32(_hc_R, _reciprocal_hc_R), _reciprocal_hc_R); - _reciprocal_hc_U = vmulq_f32(vrecpsq_f32(_hc_U, _reciprocal_hc_U), _reciprocal_hc_U); - _reciprocal_hc_N = vmulq_f32(vrecpsq_f32(_hc_N, _reciprocal_hc_N), _reciprocal_hc_N); -#endif - - vst1q_f32(weight_xc_int8_descales_RUN, _reciprocal_xc_R); - vst1q_f32(weight_xc_int8_descales_RUN + 4, _reciprocal_xc_U); - vst1q_f32(weight_xc_int8_descales_RUN + 8, _reciprocal_xc_N); - - vst1q_f32(weight_hc_int8_descales_RUN, _reciprocal_hc_R); - vst1q_f32(weight_hc_int8_descales_RUN + 4, _reciprocal_hc_U); - vst1q_f32(weight_hc_int8_descales_RUN + 8, _reciprocal_hc_N); } #endif // __ARM_NEON for (; q < num_output; q++) { - bias_c_RUBNWN[0] = bias_c_R[q]; - bias_c_RUBNWN[1] = bias_c_U[q]; - bias_c_RUBNWN[2] = bias_c_BN[q]; - bias_c_RUBNWN[3] = bias_c_WN[q]; + bias_c_RUBNWN[0] = float32_to_bfloat16(bias_c_R[q]); + bias_c_RUBNWN[1] = float32_to_bfloat16(bias_c_U[q]); + bias_c_RUBNWN[2] = float32_to_bfloat16(bias_c_BN[q]); + bias_c_RUBNWN[3] = float32_to_bfloat16(bias_c_WN[q]); bias_c_RUBNWN += 4; - const signed char* weight_xc_R = weight_xc.row(num_output * 0 + q); - const signed char* weight_xc_U = weight_xc.row(num_output * 1 + q); - const signed char* weight_xc_N = weight_xc.row(num_output * 2 + q); + const float* weight_xc_R = weight_xc.row(num_output * 0 + q); + const float* weight_xc_U = weight_xc.row(num_output * 1 + q); + const float* weight_xc_N = weight_xc.row(num_output * 2 + q); - const signed char* weight_hc_R = weight_hc.row(num_output * 0 + q); - const signed char* weight_hc_U = weight_hc.row(num_output * 1 + q); - const signed char* weight_hc_N = weight_hc.row(num_output * 2 + q); + const float* weight_hc_R = weight_hc.row(num_output * 0 + q); + const float* weight_hc_U = weight_hc.row(num_output * 1 + q); + const float* weight_hc_N = weight_hc.row(num_output * 2 + q); #if __ARM_NEON - signed char* weight_xc_RUN = weight_xc_data_packed_dr.row(q / 4 + q % 4); - signed char* weight_hc_RUN = weight_hc_data_packed_dr.row(q / 4 + q % 4); - float* weight_xc_int8_descales_RUN = weight_xc_data_int8_descales_packed_dr.row(q / 4 + q % 4); - float* weight_hc_int8_descales_RUN = weight_hc_data_int8_descales_packed_dr.row(q / 4 + q % 4); + unsigned short* weight_xc_RUN = weight_xc_data_packed_dr.row(q / 4 + q % 4); + unsigned short* weight_hc_RUN = weight_hc_data_packed_dr.row(q / 4 + q % 4); #else - signed char* weight_xc_RUN = weight_xc_data_packed_dr.row(q); - signed char* weight_hc_RUN = weight_hc_data_packed_dr.row(q); - float* weight_xc_int8_descales_RUN = weight_xc_data_int8_descales_packed_dr.row(q); - float* weight_hc_int8_descales_RUN = weight_hc_data_int8_descales_packed_dr.row(q); + unsigned short* weight_xc_RUN = weight_xc_data_packed_dr.row(q); + unsigned short* weight_hc_RUN = weight_hc_data_packed_dr.row(q); #endif // __ARM_NEON for (int i = 0; i < size; i++) { - weight_xc_RUN[0] = weight_xc_R[i]; - weight_xc_RUN[1] = weight_xc_U[i]; + weight_xc_RUN[0] = float32_to_bfloat16(weight_xc_R[i]); + weight_xc_RUN[1] = float32_to_bfloat16(weight_xc_U[i]); weight_xc_RUN += 2; } for (int i = 0; i < num_output; i++) { - weight_hc_RUN[0] = 
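// Note (illustrative sketch, not part of the patch): the armv7 half of the removed descale
// setup builds 1/scale without vdivq_f32: vrecpeq_f32 gives a coarse estimate and each
// vrecpsq_f32 step performs one Newton-Raphson refinement. Isolated, the pattern is:
#if __ARM_NEON
#include <arm_neon.h>

static inline float32x4_t reciprocal_nr1(float32x4_t x)
{
    float32x4_t r = vrecpeq_f32(x);          // ~8-bit estimate of 1/x
    r = vmulq_f32(vrecpsq_f32(x, r), r);     // r *= (2 - x * r), one refinement step
    return r;
}
#endif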
weight_hc_R[i]; - weight_hc_RUN[1] = weight_hc_U[i]; + weight_hc_RUN[0] = float32_to_bfloat16(weight_hc_R[i]); + weight_hc_RUN[1] = float32_to_bfloat16(weight_hc_U[i]); weight_hc_RUN += 2; } for (int i = 0; i < size; i++) { - weight_xc_RUN[0] = weight_xc_N[i]; + weight_xc_RUN[0] = float32_to_bfloat16(weight_xc_N[i]); weight_xc_RUN += 1; } for (int i = 0; i < num_output; i++) { - weight_hc_RUN[0] = weight_hc_N[i]; + weight_hc_RUN[0] = float32_to_bfloat16(weight_hc_N[i]); weight_hc_RUN += 1; } - - weight_xc_int8_descales_RUN[0] = 1.f / weight_xc_int8_scales[num_output * 0 + q]; - weight_xc_int8_descales_RUN[1] = 1.f / weight_xc_int8_scales[num_output * 1 + q]; - weight_xc_int8_descales_RUN[2] = 1.f / weight_xc_int8_scales[num_output * 2 + q]; - - weight_hc_int8_descales_RUN[0] = 1.f / weight_hc_int8_scales[num_output * 0 + q]; - weight_hc_int8_descales_RUN[1] = 1.f / weight_hc_int8_scales[num_output * 1 + q]; - weight_hc_int8_descales_RUN[2] = 1.f / weight_hc_int8_scales[num_output * 2 + q]; } } @@ -1318,28 +1379,13 @@ int GRU_arm::create_pipeline_int8(const Option& opt) weight_xc_data.release(); bias_c_data.release(); weight_hc_data.release(); - weight_xc_data_int8_scales.release(); - weight_hc_data_int8_scales.release(); } return 0; } -#endif // NCNN_INT8 -int GRU_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +int GRU_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { - int elembits = bottom_blob.elembits(); - -#if NCNN_ARM82 - if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) - return forward_fp16s(bottom_blob, top_blob, opt); -#endif - -#if NCNN_BF16 - if (opt.use_bf16_storage && elembits == 16) - return forward_bf16s(bottom_blob, top_blob, opt); -#endif - int T = bottom_blob.h; int num_directions = direction == 2 ? 
2 : 1; @@ -1350,67 +1396,38 @@ int GRU_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) c return -100; hidden.fill(0.f); - top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); + top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator); if (top_blob.empty()) return -100; // Uni directional if (direction == 0 || direction == 1) { -#if NCNN_INT8 - if (int8_scale_term) - { - int ret = gru_int8(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; - } - else -#endif - { - int ret = gru(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; - } + int ret = gru_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; } if (direction == 2) { - Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator); if (top_blob_forward.empty()) return -100; - Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator); if (top_blob_reverse.empty()) return -100; -#if NCNN_INT8 - if (int8_scale_term) - { - int ret = gru_int8(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; - } - else -#endif { - int ret = gru(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + int ret = gru_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); if (ret != 0) return ret; } - hidden.fill(0.0f); + hidden.fill(0.f); -#if NCNN_INT8 - if (int8_scale_term) - { - int ret = gru_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), weight_xc_data_int8_descales_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), weight_hc_data_int8_descales_packed.channel(1), hidden, opt); - if (ret != 0) - return ret; - } - else -#endif { - int ret = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); + int ret = gru_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); if (ret != 0) return ret; } @@ -1418,33 +1435,21 @@ int GRU_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) c // concat w for (int i = 0; i < T; i++) { - const float* pf = top_blob_forward.row(i); - const float* pr = top_blob_reverse.row(i); - float* ptr = top_blob.row(i); + const unsigned short* pf = top_blob_forward.row(i); + const unsigned short* pr = top_blob_reverse.row(i); + unsigned short* ptr = top_blob.row(i); - memcpy(ptr, pf, num_output * sizeof(float)); - memcpy(ptr + num_output, pr, num_output * sizeof(float)); + memcpy(ptr, pf, num_output * 
sizeof(unsigned short)); + memcpy(ptr + num_output, pr, num_output * sizeof(unsigned short)); } } return 0; } -int GRU_arm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +int GRU_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { const Mat& bottom_blob = bottom_blobs[0]; - int elembits = bottom_blob.elembits(); - -#if NCNN_ARM82 - if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) - return forward_fp16s(bottom_blobs, top_blobs, opt); -#endif - -#if NCNN_BF16 - if (opt.use_bf16_storage && elembits == 16) - return forward_bf16s(bottom_blobs, top_blobs, opt); -#endif - int T = bottom_blob.h; int num_directions = direction == 2 ? 2 : 1; @@ -1452,7 +1457,9 @@ int GRU_arm::forward(const std::vector& bottom_blobs, std::vector& top Allocator* hidden_allocator = top_blobs.size() == 2 ? opt.blob_allocator : opt.workspace_allocator; if (bottom_blobs.size() == 2) { - hidden = bottom_blobs[1].clone(hidden_allocator); + Option opt_cast = opt; + opt_cast.blob_allocator = hidden_allocator; + cast_bfloat16_to_float32(bottom_blobs[1], hidden, opt_cast); } else { @@ -1463,67 +1470,38 @@ int GRU_arm::forward(const std::vector& bottom_blobs, std::vector& top } Mat& top_blob = top_blobs[0]; - top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); + top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator); if (top_blob.empty()) return -100; // Uni directional if (direction == 0 || direction == 1) { -#if NCNN_INT8 - if (int8_scale_term) - { - int ret = gru_int8(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; - } - else -#endif - { - int ret = gru(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; - } + int ret = gru_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; } if (direction == 2) { - Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator); if (top_blob_forward.empty()) return -100; - Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator); if (top_blob_reverse.empty()) return -100; Mat hidden0 = hidden.row_range(0, 1); -#if NCNN_INT8 - if (int8_scale_term) - { - int ret = gru_int8(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), hidden0, opt); - if (ret != 0) - return ret; - } - else -#endif { - int ret = gru(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, opt); + int ret = gru_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, opt); if (ret != 0) return ret; } Mat hidden1 = hidden.row_range(1, 1); -#if NCNN_INT8 - if (int8_scale_term) - { - int ret = gru_int8(bottom_blob, 
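// Note (illustrative sketch, not part of the patch): in the two-blob forward_bf16s the working
// hidden state stays fp32 with one row per direction, and row_range hands each pass a view of
// its own row; the view aliases the same memory, so the updates made inside gru_bf16s are what
// gets cast back out as top_blobs[1]. Per-direction slicing under those assumptions:
static void make_hidden_views(const Option& opt, int num_output, Mat& hidden, Mat& h_fwd, Mat& h_rev)
{
    hidden.create(num_output, 2, 4u, opt.workspace_allocator); // w = num_output, h = num_directions
    hidden.fill(0.f);
    h_fwd = hidden.row_range(0, 1); // row 0, no copy
    h_rev = hidden.row_range(1, 1); // row 1, no copy
}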
top_blob_reverse, 1, weight_xc_data_packed.channel(1), weight_xc_data_int8_descales_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), weight_hc_data_int8_descales_packed.channel(1), hidden1, opt); - if (ret != 0) - return ret; - } - else -#endif { - int ret = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, opt); + int ret = gru_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, opt); if (ret != 0) return ret; } @@ -1531,1049 +1509,142 @@ int GRU_arm::forward(const std::vector& bottom_blobs, std::vector& top // concat w for (int i = 0; i < T; i++) { - const float* pf = top_blob_forward.row(i); - const float* pr = top_blob_reverse.row(i); - float* ptr = top_blob.row(i); + const unsigned short* pf = top_blob_forward.row(i); + const unsigned short* pr = top_blob_reverse.row(i); + unsigned short* ptr = top_blob.row(i); - memcpy(ptr, pf, num_output * sizeof(float)); - memcpy(ptr + num_output, pr, num_output * sizeof(float)); + memcpy(ptr, pf, num_output * sizeof(unsigned short)); + memcpy(ptr + num_output, pr, num_output * sizeof(unsigned short)); } } if (top_blobs.size() == 2) { - top_blobs[1] = hidden; + cast_float32_to_bfloat16(hidden, top_blobs[1], opt); } return 0; } +#endif // NCNN_BF16 -#if NCNN_BF16 -static int gru_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt) +#if NCNN_INT8 +int GRU_arm::create_pipeline_int8(const Option& opt) { - int size = bottom_blob.w; - int T = bottom_blob.h; - - int num_output = top_blob.w; + const int num_directions = direction == 2 ? 2 : 1; + const int size = weight_data_size / num_directions / num_output / 3; - // 2 x num_output -#if __ARM_NEON - Mat gates(4 * 2, num_output / 4 + num_output % 4, 4u, opt.workspace_allocator); -#else - Mat gates(2, num_output, 4u, opt.workspace_allocator); -#endif - if (gates.empty()) - return -100; + gru_transform_weight_int8(weight_xc_data, weight_xc_data_int8_scales, weight_hc_data, weight_hc_data_int8_scales, bias_c_data, weight_data_tm, weight_data_tm_int8_descales, bias_c_data_packed, size, num_output, num_directions, opt); - // unroll - for (int t = 0; t < T; t++) + if (opt.lightmode) { - int ti = reverse ? 
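// Note (illustrative sketch, not part of the patch): create_pipeline_int8 now delegates packing
// to gru_transform_weight_int8 from gru_int8.h, feeding it the per-row weight scales shipped with
// the model (weight_xc_data_int8_scales / weight_hc_data_int8_scales). Those scales encode the
// usual symmetric int8 mapping; for one weight row the scale and its descale relate like this:
#include <cmath>

static float row_int8_scale(const float* w, int n, float* descale)
{
    float absmax = 0.f;
    for (int i = 0; i < n; i++)
    {
        float a = fabsf(w[i]);
        if (a > absmax)
            absmax = a;
    }
    float scale = absmax == 0.f ? 1.f : 127.f / absmax; // maps the row into [-127, 127]
    *descale = 1.f / scale;                             // applied after the int8 MACs
    return scale;
}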
T - 1 - t : t; - - int remain_num_output_start = 0; -#if __ARM_NEON - int nn_num_output = num_output >> 2; - remain_num_output_start = nn_num_output << 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_num_output; qq++) - { - int q = qq * 4; - - const unsigned short* x = bottom_blob.row(ti); - - // gate reset update - const unsigned short* bias_c_RUBNWN = (const unsigned short*)bias_c + q * 4; - - const unsigned short* weight_xc_RUN = weight_xc.row(q / 4); - const unsigned short* weight_hc_RUN = weight_hc.row(q / 4); - - float32x4_t _gru_R = bfloat2float(vld1_u16(bias_c_RUBNWN)); - float32x4_t _gru_U = bfloat2float(vld1_u16(bias_c_RUBNWN + 4)); - float32x4_t _sum1 = vdupq_n_f32(0.f); - float32x4_t _sum2 = vdupq_n_f32(0.f); - float32x4_t _sum3 = vdupq_n_f32(0.f); - float32x4_t _sum4 = vdupq_n_f32(0.f); - float32x4_t _sum5 = vdupq_n_f32(0.f); - float32x4_t _sum6 = vdupq_n_f32(0.f); - - int i = 0; - for (; i + 3 < size; i += 4) - { - float32x4_t _xi = bfloat2float(vld1_u16(x + i)); - float32x4_t _weight_xc_R = bfloat2float(vld1_u16(weight_xc_RUN)); - float32x4_t _weight_xc_U = bfloat2float(vld1_u16(weight_xc_RUN + 4)); - float32x4_t _weight_xc_R_1 = bfloat2float(vld1_u16(weight_xc_RUN + 8)); - float32x4_t _weight_xc_U_1 = bfloat2float(vld1_u16(weight_xc_RUN + 12)); - float32x4_t _weight_xc_R_2 = bfloat2float(vld1_u16(weight_xc_RUN + 16)); - float32x4_t _weight_xc_U_2 = bfloat2float(vld1_u16(weight_xc_RUN + 20)); - float32x4_t _weight_xc_R_3 = bfloat2float(vld1_u16(weight_xc_RUN + 24)); - float32x4_t _weight_xc_U_3 = bfloat2float(vld1_u16(weight_xc_RUN + 28)); -#if __aarch64__ - _gru_R = vfmaq_laneq_f32(_gru_R, _weight_xc_R, _xi, 0); - _gru_U = vfmaq_laneq_f32(_gru_U, _weight_xc_U, _xi, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_R_1, _xi, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_U_1, _xi, 1); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_R_2, _xi, 2); - _sum4 = vfmaq_laneq_f32(_sum4, _weight_xc_U_2, _xi, 2); - _sum5 = vfmaq_laneq_f32(_sum5, _weight_xc_R_3, _xi, 3); - _sum6 = vfmaq_laneq_f32(_sum6, _weight_xc_U_3, _xi, 3); -#else - _gru_R = vmlaq_lane_f32(_gru_R, _weight_xc_R, vget_low_f32(_xi), 0); - _gru_U = vmlaq_lane_f32(_gru_U, _weight_xc_U, vget_low_f32(_xi), 0); - _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_R_1, vget_low_f32(_xi), 1); - _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_U_1, vget_low_f32(_xi), 1); - _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_R_2, vget_high_f32(_xi), 0); - _sum4 = vmlaq_lane_f32(_sum4, _weight_xc_U_2, vget_high_f32(_xi), 0); - _sum5 = vmlaq_lane_f32(_sum5, _weight_xc_R_3, vget_high_f32(_xi), 1); - _sum6 = vmlaq_lane_f32(_sum6, _weight_xc_U_3, vget_high_f32(_xi), 1); -#endif - - weight_xc_RUN += 32; - } - for (; i < size; i++) - { - unsigned short xi = x[i]; - - float32x4_t _xi = bfloat2float(vdup_n_u16(xi)); - float32x4_t _weight_xc_R = bfloat2float(vld1_u16(weight_xc_RUN)); - float32x4_t _weight_xc_U = bfloat2float(vld1_u16(weight_xc_RUN + 4)); - _gru_R = vmlaq_f32(_gru_R, _weight_xc_R, _xi); - _gru_U = vmlaq_f32(_gru_U, _weight_xc_U, _xi); - - weight_xc_RUN += 8; - } - - i = 0; - for (; i + 3 < num_output; i += 4) - { - float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); - float32x4_t _weight_hc_R = bfloat2float(vld1_u16(weight_hc_RUN)); - float32x4_t _weight_hc_U = bfloat2float(vld1_u16(weight_hc_RUN + 4)); - float32x4_t _weight_hc_R_1 = bfloat2float(vld1_u16(weight_hc_RUN + 8)); - float32x4_t _weight_hc_U_1 = bfloat2float(vld1_u16(weight_hc_RUN + 12)); - float32x4_t _weight_hc_R_2 = 
bfloat2float(vld1_u16(weight_hc_RUN + 16)); - float32x4_t _weight_hc_U_2 = bfloat2float(vld1_u16(weight_hc_RUN + 20)); - float32x4_t _weight_hc_R_3 = bfloat2float(vld1_u16(weight_hc_RUN + 24)); - float32x4_t _weight_hc_U_3 = bfloat2float(vld1_u16(weight_hc_RUN + 28)); -#if __aarch64__ - _gru_R = vfmaq_laneq_f32(_gru_R, _weight_hc_R, _h_cont, 0); - _gru_U = vfmaq_laneq_f32(_gru_U, _weight_hc_U, _h_cont, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_R_1, _h_cont, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_U_1, _h_cont, 1); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_R_2, _h_cont, 2); - _sum4 = vfmaq_laneq_f32(_sum4, _weight_hc_U_2, _h_cont, 2); - _sum5 = vfmaq_laneq_f32(_sum5, _weight_hc_R_3, _h_cont, 3); - _sum6 = vfmaq_laneq_f32(_sum6, _weight_hc_U_3, _h_cont, 3); -#else - _gru_R = vmlaq_lane_f32(_gru_R, _weight_hc_R, vget_low_f32(_h_cont), 0); - _gru_U = vmlaq_lane_f32(_gru_U, _weight_hc_U, vget_low_f32(_h_cont), 0); - _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_R_1, vget_low_f32(_h_cont), 1); - _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_U_1, vget_low_f32(_h_cont), 1); - _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_R_2, vget_high_f32(_h_cont), 0); - _sum4 = vmlaq_lane_f32(_sum4, _weight_hc_U_2, vget_high_f32(_h_cont), 0); - _sum5 = vmlaq_lane_f32(_sum5, _weight_hc_R_3, vget_high_f32(_h_cont), 1); - _sum6 = vmlaq_lane_f32(_sum6, _weight_hc_U_3, vget_high_f32(_h_cont), 1); -#endif - - weight_hc_RUN += 32; - } - for (; i < num_output; i++) - { - float h_cont = hidden_state[i]; - - float32x4_t _h_cont = vdupq_n_f32(h_cont); - float32x4_t _weight_hc_R = bfloat2float(vld1_u16(weight_hc_RUN)); - float32x4_t _weight_hc_U = bfloat2float(vld1_u16(weight_hc_RUN + 4)); - _gru_R = vmlaq_f32(_gru_R, _weight_hc_R, _h_cont); - _gru_U = vmlaq_f32(_gru_U, _weight_hc_U, _h_cont); - - weight_hc_RUN += 8; - } - - _gru_R = vaddq_f32(_gru_R, _sum1); - _gru_U = vaddq_f32(_gru_U, _sum2); - _sum3 = vaddq_f32(_sum3, _sum5); - _sum4 = vaddq_f32(_sum4, _sum6); - _gru_R = vaddq_f32(_gru_R, _sum3); - _gru_U = vaddq_f32(_gru_U, _sum4); - - // sigmoid(R) - // sigmoid(U) - _gru_R = sigmoid_ps(_gru_R); - _gru_U = sigmoid_ps(_gru_U); - - // gate new - float32x4_t _gru_N = bfloat2float(vld1_u16(bias_c_RUBNWN + 8)); - _sum1 = vdupq_n_f32(0.f); - _sum2 = vdupq_n_f32(0.f); - _sum3 = vdupq_n_f32(0.f); - - i = 0; - for (; i + 3 < num_output; i += 4) - { - float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); - float32x4_t _weight_hc_N = bfloat2float(vld1_u16(weight_hc_RUN)); - float32x4_t _weight_hc_N_1 = bfloat2float(vld1_u16(weight_hc_RUN + 4)); - float32x4_t _weight_hc_N_2 = bfloat2float(vld1_u16(weight_hc_RUN + 8)); - float32x4_t _weight_hc_N_3 = bfloat2float(vld1_u16(weight_hc_RUN + 12)); -#if __aarch64__ - _gru_N = vfmaq_laneq_f32(_gru_N, _weight_hc_N, _h_cont, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_N_1, _h_cont, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_N_2, _h_cont, 2); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_N_3, _h_cont, 3); -#else - _gru_N = vmlaq_lane_f32(_gru_N, _weight_hc_N, vget_low_f32(_h_cont), 0); - _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_N_1, vget_low_f32(_h_cont), 1); - _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_N_2, vget_high_f32(_h_cont), 0); - _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_N_3, vget_high_f32(_h_cont), 1); -#endif + weight_xc_data.release(); + weight_hc_data.release(); + bias_c_data.release(); + weight_xc_data_int8_scales.release(); + weight_hc_data_int8_scales.release(); + } - weight_hc_RUN += 16; - } - for (; i < num_output; i++) - { - float h_cont = hidden_state[i]; - 
- float32x4_t _h_cont = vdupq_n_f32(h_cont); - float32x4_t _weight_hc_N = bfloat2float(vld1_u16(weight_hc_RUN)); - _gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont); - - weight_hc_RUN += 4; - } - - _gru_N = vaddq_f32(_gru_N, _sum1); - _sum2 = vaddq_f32(_sum2, _sum3); - _gru_N = vaddq_f32(_gru_N, _sum2); - - _gru_N = vmlaq_f32(bfloat2float(vld1_u16(bias_c_RUBNWN + 12)), _gru_R, _gru_N); - _sum1 = vdupq_n_f32(0.f); - _sum2 = vdupq_n_f32(0.f); - _sum3 = vdupq_n_f32(0.f); - - i = 0; - for (; i + 3 < size; i += 4) - { - float32x4_t _xi = bfloat2float(vld1_u16(x + i)); - float32x4_t _weight_xc_N = bfloat2float(vld1_u16(weight_xc_RUN)); - float32x4_t _weight_xc_N_1 = bfloat2float(vld1_u16(weight_xc_RUN + 4)); - float32x4_t _weight_xc_N_2 = bfloat2float(vld1_u16(weight_xc_RUN + 8)); - float32x4_t _weight_xc_N_3 = bfloat2float(vld1_u16(weight_xc_RUN + 12)); -#if __aarch64__ - _gru_N = vfmaq_laneq_f32(_gru_N, _weight_xc_N, _xi, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_N_1, _xi, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_N_2, _xi, 2); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_N_3, _xi, 3); -#else - _gru_N = vmlaq_lane_f32(_gru_N, _weight_xc_N, vget_low_f32(_xi), 0); - _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_N_1, vget_low_f32(_xi), 1); - _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_N_2, vget_high_f32(_xi), 0); - _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_N_3, vget_high_f32(_xi), 1); -#endif - - weight_xc_RUN += 16; - } - for (; i < size; i++) - { - unsigned short xi = x[i]; - - float32x4_t _xi = bfloat2float(vdup_n_u16(xi)); - float32x4_t _weight_xc_N = bfloat2float(vld1_u16(weight_xc_RUN)); - _gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi); - - weight_xc_RUN += 4; - } - - _gru_N = vaddq_f32(_gru_N, _sum1); - _sum2 = vaddq_f32(_sum2, _sum3); - _gru_N = vaddq_f32(_gru_N, _sum2); - - // tanh(N) - _gru_N = tanh_ps(_gru_N); - - float* gates_data = gates.row(q / 4); - - vst1q_f32(gates_data, _gru_U); - vst1q_f32(gates_data + 4, _gru_N); - } -#endif // __ARM_NEON - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) - { - const unsigned short* x = bottom_blob.row(ti); - - // gate reset update - const unsigned short* bias_c_RUBNWN = (const unsigned short*)bias_c + q * 4; - -#if __ARM_NEON - const unsigned short* weight_xc_RUN = weight_xc.row(q / 4 + q % 4); - const unsigned short* weight_hc_RUN = weight_hc.row(q / 4 + q % 4); -#else - const unsigned short* weight_xc_RUN = weight_xc.row(q); - const unsigned short* weight_hc_RUN = weight_hc.row(q); -#endif - - float R = bfloat16_to_float32(bias_c_RUBNWN[0]); - float U = bfloat16_to_float32(bias_c_RUBNWN[1]); - - for (int i = 0; i < size; i++) - { - float xi = bfloat16_to_float32(x[i]); - - R += bfloat16_to_float32(weight_xc_RUN[0]) * xi; - U += bfloat16_to_float32(weight_xc_RUN[1]) * xi; - - weight_xc_RUN += 2; - } - - for (int i = 0; i < num_output; i++) - { - float h_cont = hidden_state[i]; - - R += bfloat16_to_float32(weight_hc_RUN[0]) * h_cont; - U += bfloat16_to_float32(weight_hc_RUN[1]) * h_cont; - - weight_hc_RUN += 2; - } - - // sigmoid(R) - // sigmoid(U) - R = 1.f / (1.f + expf(-R)); - U = 1.f / (1.f + expf(-U)); - - // gate new - float N = bfloat16_to_float32(bias_c_RUBNWN[2]); - - for (int i = 0; i < num_output; i++) - { - float h_cont = hidden_state[i]; - - N += bfloat16_to_float32(weight_hc_RUN[0]) * h_cont; - - weight_hc_RUN += 1; - } - - N = bfloat16_to_float32(bias_c_RUBNWN[3]) + R * N; - - for (int i = 0; i < size; i++) - { - float xi = bfloat16_to_float32(x[i]); - - N += 
bfloat16_to_float32(weight_xc_RUN[0]) * xi; - - weight_xc_RUN += 1; - } - - // tanh(N) - N = tanhf(N); - -#if __ARM_NEON - float* gates_data = gates.row(q / 4 + q % 4); -#else - float* gates_data = gates.row(q); -#endif - - gates_data[0] = U; - gates_data[1] = N; - } - - // h_t := (1 - update) .* new + update .* h_{t-1} - unsigned short* output_data = top_blob.row(ti); - - float* hidden_ptr = hidden_state; - -#if __ARM_NEON - nn_num_output = num_output >> 2; - remain_num_output_start = nn_num_output << 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_num_output; qq++) - { - int q = qq * 4; - - const float* gates_data = gates.row(q / 4); - - float32x4_t _gru_U = vld1q_f32(gates_data); - float32x4_t _gru_N = vld1q_f32(gates_data + 4); - - float32x4_t _gru_H = vaddq_f32(vmulq_f32(vsubq_f32(vdupq_n_f32(1.f), _gru_U), _gru_N), vmulq_f32(_gru_U, vld1q_f32(hidden_ptr + q))); - - vst1q_f32(hidden_ptr + q, _gru_H); - vst1_u16(output_data + q, float2bfloat(_gru_H)); - } -#endif // __ARM_NEON - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) - { -#if __ARM_NEON - const float* gates_data = gates.row(q / 4 + q % 4); -#else - const float* gates_data = gates.row(q); -#endif - - float U = gates_data[0]; - float N = gates_data[1]; - - float H = (1 - U) * N + U * hidden_ptr[q]; - - hidden_ptr[q] = H; - output_data[q] = float32_to_bfloat16(H); - } - } - - return 0; -} - -#if NCNN_INT8 -static int gru_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, Mat& hidden_state, const Option& opt) -{ - int size = bottom_blob.w; - int T = bottom_blob.h; - - int num_output = top_blob.w; - - // 2 x num_output -#if __ARM_NEON - Mat gates(4 * 2, num_output / 4 + num_output % 4, 4u, opt.workspace_allocator); -#else - Mat gates(2, num_output, 4u, opt.workspace_allocator); -#endif - if (gates.empty()) - return -100; - - // unroll - for (int t = 0; t < T; t++) - { - int ti = reverse ? 
T - 1 - t : t; - - int remain_num_output_start = 0; -#if __ARM_NEON - int nn_num_output = num_output >> 2; - remain_num_output_start = nn_num_output << 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_num_output; qq++) - { - int q = qq * 4; - - const unsigned short* x = bottom_blob.row(ti); - - // gate reset update - const unsigned short* bias_c_RUBNWN = (const unsigned short*)bias_c + q * 4; - - const signed char* weight_xc_int8_RUN = weight_xc_int8.row(q / 4); - const signed char* weight_hc_int8_RUN = weight_hc_int8.row(q / 4); - - const float* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q / 4); - const float* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q / 4); - - float32x4_t _descale_xc_R = vld1q_f32(weight_xc_int8_descales_RUN); - float32x4_t _descale_xc_U = vld1q_f32(weight_xc_int8_descales_RUN + 4); - float32x4_t _descale_hc_R = vld1q_f32(weight_hc_int8_descales_RUN); - float32x4_t _descale_hc_U = vld1q_f32(weight_hc_int8_descales_RUN + 4); - - float32x4_t _gru_R = bfloat2float(vld1_u16(bias_c_RUBNWN)); - float32x4_t _gru_U = bfloat2float(vld1_u16(bias_c_RUBNWN + 4)); - float32x4_t _sum1 = vdupq_n_f32(0.f); - float32x4_t _sum2 = vdupq_n_f32(0.f); - float32x4_t _sum3 = vdupq_n_f32(0.f); - float32x4_t _sum4 = vdupq_n_f32(0.f); - float32x4_t _sum5 = vdupq_n_f32(0.f); - float32x4_t _sum6 = vdupq_n_f32(0.f); - - int i = 0; - for (; i + 3 < size; i += 4) - { - float32x4_t _xi = bfloat2float(vld1_u16(x + i)); - - int8x16_t _weight_xc_RU01 = vld1q_s8(weight_xc_int8_RUN); - int8x16_t _weight_xc_RU23 = vld1q_s8(weight_xc_int8_RUN + 16); - - int16x8_t _weight_xc_RU0 = vmovl_s8(vget_low_s8(_weight_xc_RU01)); - int16x8_t _weight_xc_RU1 = vmovl_s8(vget_high_s8(_weight_xc_RU01)); - int16x8_t _weight_xc_RU2 = vmovl_s8(vget_low_s8(_weight_xc_RU23)); - int16x8_t _weight_xc_RU3 = vmovl_s8(vget_high_s8(_weight_xc_RU23)); - - float32x4_t _weight_xc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU0))), _descale_xc_R); - float32x4_t _weight_xc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU0))), _descale_xc_U); - float32x4_t _weight_xc_R_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU1))), _descale_xc_R); - float32x4_t _weight_xc_U_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU1))), _descale_xc_U); - float32x4_t _weight_xc_R_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU2))), _descale_xc_R); - float32x4_t _weight_xc_U_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU2))), _descale_xc_U); - float32x4_t _weight_xc_R_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU3))), _descale_xc_R); - float32x4_t _weight_xc_U_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU3))), _descale_xc_U); - -#if __aarch64__ - _gru_R = vfmaq_laneq_f32(_gru_R, _weight_xc_R, _xi, 0); - _gru_U = vfmaq_laneq_f32(_gru_U, _weight_xc_U, _xi, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_R_1, _xi, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_U_1, _xi, 1); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_R_2, _xi, 2); - _sum4 = vfmaq_laneq_f32(_sum4, _weight_xc_U_2, _xi, 2); - _sum5 = vfmaq_laneq_f32(_sum5, _weight_xc_R_3, _xi, 3); - _sum6 = vfmaq_laneq_f32(_sum6, _weight_xc_U_3, _xi, 3); -#else - _gru_R = vmlaq_lane_f32(_gru_R, _weight_xc_R, vget_low_f32(_xi), 0); - _gru_U = vmlaq_lane_f32(_gru_U, _weight_xc_U, vget_low_f32(_xi), 0); - _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_R_1, vget_low_f32(_xi), 1); - _sum2 = vmlaq_lane_f32(_sum2, 
_weight_xc_U_1, vget_low_f32(_xi), 1); - _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_R_2, vget_high_f32(_xi), 0); - _sum4 = vmlaq_lane_f32(_sum4, _weight_xc_U_2, vget_high_f32(_xi), 0); - _sum5 = vmlaq_lane_f32(_sum5, _weight_xc_R_3, vget_high_f32(_xi), 1); - _sum6 = vmlaq_lane_f32(_sum6, _weight_xc_U_3, vget_high_f32(_xi), 1); -#endif - - weight_xc_int8_RUN += 32; - } - for (; i < size; i++) - { - unsigned short xi = x[i]; - - float32x4_t _xi = bfloat2float(vdup_n_u16(xi)); - - int16x8_t _weight_xc_RU = vmovl_s8(vld1_s8(weight_xc_int8_RUN)); - float32x4_t _weight_xc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU))), _descale_xc_R); - float32x4_t _weight_xc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU))), _descale_xc_U); - - _gru_R = vmlaq_f32(_gru_R, _weight_xc_R, _xi); - _gru_U = vmlaq_f32(_gru_U, _weight_xc_U, _xi); - - weight_xc_int8_RUN += 8; - } - - i = 0; - for (; i + 3 < num_output; i += 4) - { - float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); - - int8x16_t _weight_hc_RU01 = vld1q_s8(weight_hc_int8_RUN); - int8x16_t _weight_hc_RU23 = vld1q_s8(weight_hc_int8_RUN + 16); - - int16x8_t _weight_hc_RU0 = vmovl_s8(vget_low_s8(_weight_hc_RU01)); - int16x8_t _weight_hc_RU1 = vmovl_s8(vget_high_s8(_weight_hc_RU01)); - int16x8_t _weight_hc_RU2 = vmovl_s8(vget_low_s8(_weight_hc_RU23)); - int16x8_t _weight_hc_RU3 = vmovl_s8(vget_high_s8(_weight_hc_RU23)); - - float32x4_t _weight_hc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU0))), _descale_hc_R); - float32x4_t _weight_hc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU0))), _descale_hc_U); - float32x4_t _weight_hc_R_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU1))), _descale_hc_R); - float32x4_t _weight_hc_U_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU1))), _descale_hc_U); - float32x4_t _weight_hc_R_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU2))), _descale_hc_R); - float32x4_t _weight_hc_U_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU2))), _descale_hc_U); - float32x4_t _weight_hc_R_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU3))), _descale_hc_R); - float32x4_t _weight_hc_U_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU3))), _descale_hc_U); - -#if __aarch64__ - _gru_R = vfmaq_laneq_f32(_gru_R, _weight_hc_R, _h_cont, 0); - _gru_U = vfmaq_laneq_f32(_gru_U, _weight_hc_U, _h_cont, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_R_1, _h_cont, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_U_1, _h_cont, 1); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_R_2, _h_cont, 2); - _sum4 = vfmaq_laneq_f32(_sum4, _weight_hc_U_2, _h_cont, 2); - _sum5 = vfmaq_laneq_f32(_sum5, _weight_hc_R_3, _h_cont, 3); - _sum6 = vfmaq_laneq_f32(_sum6, _weight_hc_U_3, _h_cont, 3); -#else - _gru_R = vmlaq_lane_f32(_gru_R, _weight_hc_R, vget_low_f32(_h_cont), 0); - _gru_U = vmlaq_lane_f32(_gru_U, _weight_hc_U, vget_low_f32(_h_cont), 0); - _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_R_1, vget_low_f32(_h_cont), 1); - _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_U_1, vget_low_f32(_h_cont), 1); - _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_R_2, vget_high_f32(_h_cont), 0); - _sum4 = vmlaq_lane_f32(_sum4, _weight_hc_U_2, vget_high_f32(_h_cont), 0); - _sum5 = vmlaq_lane_f32(_sum5, _weight_hc_R_3, vget_high_f32(_h_cont), 1); - _sum6 = vmlaq_lane_f32(_sum6, _weight_hc_U_3, vget_high_f32(_h_cont), 1); -#endif - - weight_hc_int8_RUN += 32; - } - for (; i < num_output; i++) - { - float h_cont = 
hidden_state[i]; - - float32x4_t _h_cont = vdupq_n_f32(h_cont); - - int16x8_t _weight_hc_RU = vmovl_s8(vld1_s8(weight_hc_int8_RUN)); - float32x4_t _weight_hc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU))), _descale_hc_R); - float32x4_t _weight_hc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU))), _descale_hc_U); - - _gru_R = vmlaq_f32(_gru_R, _weight_hc_R, _h_cont); - _gru_U = vmlaq_f32(_gru_U, _weight_hc_U, _h_cont); - - weight_hc_int8_RUN += 8; - } - - _gru_R = vaddq_f32(_gru_R, _sum1); - _gru_U = vaddq_f32(_gru_U, _sum2); - _sum3 = vaddq_f32(_sum3, _sum5); - _sum4 = vaddq_f32(_sum4, _sum6); - _gru_R = vaddq_f32(_gru_R, _sum3); - _gru_U = vaddq_f32(_gru_U, _sum4); - - // sigmoid(R) - // sigmoid(U) - _gru_R = sigmoid_ps(_gru_R); - _gru_U = sigmoid_ps(_gru_U); - - // gate new - float32x4_t _gru_N = bfloat2float(vld1_u16(bias_c_RUBNWN + 8)); - _sum1 = vdupq_n_f32(0.f); - _sum2 = vdupq_n_f32(0.f); - _sum3 = vdupq_n_f32(0.f); - - float32x4_t _descale_xc_N = vld1q_f32(weight_xc_int8_descales_RUN + 8); - float32x4_t _descale_hc_N = vld1q_f32(weight_hc_int8_descales_RUN + 8); - - i = 0; - for (; i + 3 < num_output; i += 4) - { - float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); - - int8x16_t _weight_hc_N0123 = vld1q_s8(weight_hc_int8_RUN); - int16x8_t _weight_hc_N01 = vmovl_s8(vget_low_s8(_weight_hc_N0123)); - int16x8_t _weight_hc_N23 = vmovl_s8(vget_high_s8(_weight_hc_N0123)); - float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_N01))), _descale_hc_N); - float32x4_t _weight_hc_N_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_N01))), _descale_hc_N); - float32x4_t _weight_hc_N_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_N23))), _descale_hc_N); - float32x4_t _weight_hc_N_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_N23))), _descale_hc_N); - -#if __aarch64__ - _gru_N = vfmaq_laneq_f32(_gru_N, _weight_hc_N, _h_cont, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_N_1, _h_cont, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_N_2, _h_cont, 2); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_N_3, _h_cont, 3); -#else - _gru_N = vmlaq_lane_f32(_gru_N, _weight_hc_N, vget_low_f32(_h_cont), 0); - _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_N_1, vget_low_f32(_h_cont), 1); - _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_N_2, vget_high_f32(_h_cont), 0); - _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_N_3, vget_high_f32(_h_cont), 1); -#endif - - weight_hc_int8_RUN += 16; - } - for (; i < num_output; i++) - { - float h_cont = hidden_state[i]; - - float32x4_t _h_cont = vdupq_n_f32(h_cont); - int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_RUN)[0])); - float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_hc_N); - _gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont); - - weight_hc_int8_RUN += 4; - } - - _gru_N = vaddq_f32(_gru_N, _sum1); - _sum2 = vaddq_f32(_sum2, _sum3); - _gru_N = vaddq_f32(_gru_N, _sum2); - - _gru_N = vmlaq_f32(bfloat2float(vld1_u16(bias_c_RUBNWN + 12)), _gru_R, _gru_N); - _sum1 = vdupq_n_f32(0.f); - _sum2 = vdupq_n_f32(0.f); - _sum3 = vdupq_n_f32(0.f); - - i = 0; - for (; i + 3 < size; i += 4) - { - float32x4_t _xi = bfloat2float(vld1_u16(x + i)); - - int8x16_t _weight_xc_N0123 = vld1q_s8(weight_xc_int8_RUN); - int16x8_t _weight_xc_N01 = vmovl_s8(vget_low_s8(_weight_xc_N0123)); - int16x8_t _weight_xc_N23 = vmovl_s8(vget_high_s8(_weight_xc_N0123)); - float32x4_t _weight_xc_N = 
vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_N01))), _descale_xc_N); - float32x4_t _weight_xc_N_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_N01))), _descale_xc_N); - float32x4_t _weight_xc_N_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_N23))), _descale_xc_N); - float32x4_t _weight_xc_N_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_N23))), _descale_xc_N); - -#if __aarch64__ - _gru_N = vfmaq_laneq_f32(_gru_N, _weight_xc_N, _xi, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_N_1, _xi, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_N_2, _xi, 2); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_N_3, _xi, 3); -#else - _gru_N = vmlaq_lane_f32(_gru_N, _weight_xc_N, vget_low_f32(_xi), 0); - _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_N_1, vget_low_f32(_xi), 1); - _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_N_2, vget_high_f32(_xi), 0); - _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_N_3, vget_high_f32(_xi), 1); -#endif - - weight_xc_int8_RUN += 16; - } - for (; i < size; i++) - { - unsigned short xi = x[i]; - - float32x4_t _xi = bfloat2float(vdup_n_u16(xi)); - int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_RUN)[0])); - float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_xc_N); - _gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi); - - weight_xc_int8_RUN += 4; - } - - _gru_N = vaddq_f32(_gru_N, _sum1); - _sum2 = vaddq_f32(_sum2, _sum3); - _gru_N = vaddq_f32(_gru_N, _sum2); - - // tanh(N) - _gru_N = tanh_ps(_gru_N); - - float* gates_data = gates.row(q / 4); - - vst1q_f32(gates_data, _gru_U); - vst1q_f32(gates_data + 4, _gru_N); - } -#endif // __ARM_NEON - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) - { - const unsigned short* x = bottom_blob.row(ti); - - // gate reset update - const unsigned short* bias_c_RUBNWN = (const unsigned short*)bias_c + q * 4; - -#if __ARM_NEON - const signed char* weight_xc_int8_RUN = weight_xc_int8.row(q / 4 + q % 4); - const signed char* weight_hc_int8_RUN = weight_hc_int8.row(q / 4 + q % 4); - const float* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q / 4 + q % 4); - const float* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q / 4 + q % 4); -#else - const signed char* weight_xc_int8_RUN = weight_xc_int8.row(q); - const signed char* weight_hc_int8_RUN = weight_hc_int8.row(q); - const float* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q); - const float* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q); -#endif - - const float descale_xc_R = weight_xc_int8_descales_RUN[0]; - const float descale_xc_U = weight_xc_int8_descales_RUN[1]; - const float descale_xc_N = weight_xc_int8_descales_RUN[2]; - - const float descale_hc_R = weight_hc_int8_descales_RUN[0]; - const float descale_hc_U = weight_hc_int8_descales_RUN[1]; - const float descale_hc_N = weight_hc_int8_descales_RUN[2]; - - float R = bfloat16_to_float32(bias_c_RUBNWN[0]); - float U = bfloat16_to_float32(bias_c_RUBNWN[1]); - - for (int i = 0; i < size; i++) - { - float xi = bfloat16_to_float32(x[i]); - - R += weight_xc_int8_RUN[0] * descale_xc_R * xi; - U += weight_xc_int8_RUN[1] * descale_xc_U * xi; - - weight_xc_int8_RUN += 2; - } - - for (int i = 0; i < num_output; i++) - { - float h_cont = hidden_state[i]; - - R += weight_hc_int8_RUN[0] * descale_hc_R * h_cont; - U += weight_hc_int8_RUN[1] * descale_hc_U * h_cont; - - weight_hc_int8_RUN += 2; - } - - // sigmoid(R) - // sigmoid(U) 
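(Note: the scalar int8 fallback above dequantizes each weight on the fly as (float)w_int8 * descale, where descale is the per-gate-row reciprocal of the quantization scale stored alongside the packed int8 weights. A minimal sketch of that accumulate pattern, using illustrative names that are not part of this patch:)

// Dot product against int8 weights with on-the-fly dequantization
// (scalar sketch; w and descale correspond to one gate row).
static float dot_int8_descale(const signed char* w, float descale,
                              const float* x, int n, float bias)
{
    float sum = bias;
    for (int i = 0; i < n; i++)
        sum += (float)w[i] * descale * x[i]; // w[i] * descale recovers the float weight
    return sum;
}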
- R = 1.f / (1.f + expf(-R)); - U = 1.f / (1.f + expf(-U)); - - // gate new - float N = bfloat16_to_float32(bias_c_RUBNWN[2]); - - for (int i = 0; i < num_output; i++) - { - float h_cont = hidden_state[i]; - - N += weight_hc_int8_RUN[0] * descale_hc_N * h_cont; - - weight_hc_int8_RUN += 1; - } - - N = bfloat16_to_float32(bias_c_RUBNWN[3]) + R * N; - - for (int i = 0; i < size; i++) - { - float xi = bfloat16_to_float32(x[i]); - - N += weight_xc_int8_RUN[0] * descale_xc_N * xi; - - weight_xc_int8_RUN += 1; - } - - // tanh(N) - N = tanhf(N); - -#if __ARM_NEON - float* gates_data = gates.row(q / 4 + q % 4); -#else - float* gates_data = gates.row(q); -#endif - - gates_data[0] = U; - gates_data[1] = N; - } - - // h_t := (1 - update) .* new + update .* h_{t-1} - unsigned short* output_data = top_blob.row(ti); - - float* hidden_ptr = hidden_state; - -#if __ARM_NEON - nn_num_output = num_output >> 2; - remain_num_output_start = nn_num_output << 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_num_output; qq++) - { - int q = qq * 4; - - const float* gates_data = gates.row(q / 4); - - float32x4_t _gru_U = vld1q_f32(gates_data); - float32x4_t _gru_N = vld1q_f32(gates_data + 4); - - float32x4_t _gru_H = vaddq_f32(vmulq_f32(vsubq_f32(vdupq_n_f32(1.f), _gru_U), _gru_N), vmulq_f32(_gru_U, vld1q_f32(hidden_ptr + q))); - - vst1q_f32(hidden_ptr + q, _gru_H); - vst1_u16(output_data + q, float2bfloat(_gru_H)); - } -#endif // __ARM_NEON - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) - { -#if __ARM_NEON - const float* gates_data = gates.row(q / 4 + q % 4); -#else - const float* gates_data = gates.row(q); -#endif - - float U = gates_data[0]; - float N = gates_data[1]; - - float H = (1 - U) * N + U * hidden_ptr[q]; - - hidden_ptr[q] = H; - output_data[q] = float32_to_bfloat16(H); - } - } - - return 0; -} -#endif // NCNN_INT8 - -int GRU_arm::create_pipeline_bf16s(const Option& opt) -{ -#if NCNN_INT8 - if (int8_scale_term) - { - create_pipeline_int8(opt); - - ncnn::Mat tmp; - cast_float32_to_bfloat16(bias_c_data_packed, tmp, opt); - bias_c_data_packed = tmp; - - return 0; - } -#endif - - // pack RUN - int num_directions = direction == 2 ? 
2 : 1; - int size = weight_data_size / num_directions / num_output / 3; - -#if __ARM_NEON - weight_xc_data_packed.create(size * 12, num_output / 4 + num_output % 4, num_directions, 2u, 1); - bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4); - weight_hc_data_packed.create(num_output * 12, num_output / 4 + num_output % 4, num_directions, 2u, 1); -#else - weight_xc_data_packed.create(size * 3, num_output, num_directions, 2u, 1); - bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4); - weight_hc_data_packed.create(num_output * 3, num_output, num_directions, 2u, 1); -#endif - - #pragma omp parallel for num_threads(opt.num_threads) - for (int dr = 0; dr < num_directions; dr++) - { - const Mat weight_xc = weight_xc_data.channel(dr); - const Mat bias_c = bias_c_data.channel(dr); - const Mat weight_hc = weight_hc_data.channel(dr); - - Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr); - Mat bias_c_data_packed_dr = bias_c_data_packed.channel(dr); - Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr); - - const float* bias_c_R = bias_c.row(0); - const float* bias_c_U = bias_c.row(1); - const float* bias_c_WN = bias_c.row(2); - const float* bias_c_BN = bias_c.row(3); - - unsigned short* bias_c_RUBNWN = bias_c_data_packed_dr.row(0); - - int q = 0; -#if __ARM_NEON - for (; q + 3 < num_output; q += 4) - { - vst1_u16(bias_c_RUBNWN, float2bfloat(vld1q_f32(bias_c_R + q))); - vst1_u16(bias_c_RUBNWN + 4, float2bfloat(vld1q_f32(bias_c_U + q))); - vst1_u16(bias_c_RUBNWN + 8, float2bfloat(vld1q_f32(bias_c_BN + q))); - vst1_u16(bias_c_RUBNWN + 12, float2bfloat(vld1q_f32(bias_c_WN + q))); - - bias_c_RUBNWN += 16; - - const float* weight_xc_R = weight_xc.row(num_output * 0 + q); - const float* weight_xc_U = weight_xc.row(num_output * 1 + q); - const float* weight_xc_N = weight_xc.row(num_output * 2 + q); - - const float* weight_xc_R_1 = weight_xc.row(num_output * 0 + q + 1); - const float* weight_xc_U_1 = weight_xc.row(num_output * 1 + q + 1); - const float* weight_xc_N_1 = weight_xc.row(num_output * 2 + q + 1); - - const float* weight_xc_R_2 = weight_xc.row(num_output * 0 + q + 2); - const float* weight_xc_U_2 = weight_xc.row(num_output * 1 + q + 2); - const float* weight_xc_N_2 = weight_xc.row(num_output * 2 + q + 2); - - const float* weight_xc_R_3 = weight_xc.row(num_output * 0 + q + 3); - const float* weight_xc_U_3 = weight_xc.row(num_output * 1 + q + 3); - const float* weight_xc_N_3 = weight_xc.row(num_output * 2 + q + 3); - - const float* weight_hc_R = weight_hc.row(num_output * 0 + q); - const float* weight_hc_U = weight_hc.row(num_output * 1 + q); - const float* weight_hc_N = weight_hc.row(num_output * 2 + q); + return 0; +} - const float* weight_hc_R_1 = weight_hc.row(num_output * 0 + q + 1); - const float* weight_hc_U_1 = weight_hc.row(num_output * 1 + q + 1); - const float* weight_hc_N_1 = weight_hc.row(num_output * 2 + q + 1); +void GRU_arm::dynamic_quantize(const Mat& bottom_blob, int elemtype, Mat& bottom_blob_int8, Mat& bottom_blob_int8_descales, const Option& opt) const +{ + int size = bottom_blob.w; + int T = bottom_blob.h; - const float* weight_hc_R_2 = weight_hc.row(num_output * 0 + q + 2); - const float* weight_hc_U_2 = weight_hc.row(num_output * 1 + q + 2); - const float* weight_hc_N_2 = weight_hc.row(num_output * 2 + q + 2); + // dynamic quantize bottom_blob + bottom_blob_int8_descales.create(T, (size_t)4u, 1, opt.blob_allocator); - const float* weight_hc_R_3 = weight_hc.row(num_output * 0 + q + 3); - const float* weight_hc_U_3 = 
weight_hc.row(num_output * 1 + q + 3); - const float* weight_hc_N_3 = weight_hc.row(num_output * 2 + q + 3); + Mat bottom_blob_int8_scales(T, (size_t)4u, 1, opt.blob_allocator); - unsigned short* weight_xc_RUN = weight_xc_data_packed_dr.row(q / 4); - unsigned short* weight_hc_RUN = weight_hc_data_packed_dr.row(q / 4); + if (elemtype == 1) + { + // fp32 + for (int t = 0; t < T; t++) + { + const float* x = bottom_blob.row(t); + float absmax = 0.f; for (int i = 0; i < size; i++) { - weight_xc_RUN[0] = float32_to_bfloat16(weight_xc_R[i]); - weight_xc_RUN[1] = float32_to_bfloat16(weight_xc_R_1[i]); - weight_xc_RUN[2] = float32_to_bfloat16(weight_xc_R_2[i]); - weight_xc_RUN[3] = float32_to_bfloat16(weight_xc_R_3[i]); - weight_xc_RUN[4] = float32_to_bfloat16(weight_xc_U[i]); - weight_xc_RUN[5] = float32_to_bfloat16(weight_xc_U_1[i]); - weight_xc_RUN[6] = float32_to_bfloat16(weight_xc_U_2[i]); - weight_xc_RUN[7] = float32_to_bfloat16(weight_xc_U_3[i]); - - weight_xc_RUN += 8; + absmax = std::max(absmax, (float)fabs(x[i])); } - for (int i = 0; i < num_output; i++) - { - weight_hc_RUN[0] = float32_to_bfloat16(weight_hc_R[i]); - weight_hc_RUN[1] = float32_to_bfloat16(weight_hc_R_1[i]); - weight_hc_RUN[2] = float32_to_bfloat16(weight_hc_R_2[i]); - weight_hc_RUN[3] = float32_to_bfloat16(weight_hc_R_3[i]); - weight_hc_RUN[4] = float32_to_bfloat16(weight_hc_U[i]); - weight_hc_RUN[5] = float32_to_bfloat16(weight_hc_U_1[i]); - weight_hc_RUN[6] = float32_to_bfloat16(weight_hc_U_2[i]); - weight_hc_RUN[7] = float32_to_bfloat16(weight_hc_U_3[i]); - - weight_hc_RUN += 8; - } + bottom_blob_int8_scales[t] = 127.f / absmax; + bottom_blob_int8_descales[t] = absmax / 127.f; + } + } + if (elemtype == 2) + { + // fp16 + for (int t = 0; t < T; t++) + { + const unsigned short* x = bottom_blob.row(t); + float absmax = 0.f; for (int i = 0; i < size; i++) { - weight_xc_RUN[0] = float32_to_bfloat16(weight_xc_N[i]); - weight_xc_RUN[1] = float32_to_bfloat16(weight_xc_N_1[i]); - weight_xc_RUN[2] = float32_to_bfloat16(weight_xc_N_2[i]); - weight_xc_RUN[3] = float32_to_bfloat16(weight_xc_N_3[i]); - - weight_xc_RUN += 4; + absmax = std::max(absmax, (float)fabs(float16_to_float32(x[i]))); } - for (int i = 0; i < num_output; i++) - { - weight_hc_RUN[0] = float32_to_bfloat16(weight_hc_N[i]); - weight_hc_RUN[1] = float32_to_bfloat16(weight_hc_N_1[i]); - weight_hc_RUN[2] = float32_to_bfloat16(weight_hc_N_2[i]); - weight_hc_RUN[3] = float32_to_bfloat16(weight_hc_N_3[i]); - - weight_hc_RUN += 4; - } + bottom_blob_int8_scales[t] = 127.f / absmax; + bottom_blob_int8_descales[t] = absmax / 127.f; } -#endif // __ARM_NEON - for (; q < num_output; q++) + } + if (elemtype == 4) + { + // bf16 + for (int t = 0; t < T; t++) { - bias_c_RUBNWN[0] = float32_to_bfloat16(bias_c_R[q]); - bias_c_RUBNWN[1] = float32_to_bfloat16(bias_c_U[q]); - bias_c_RUBNWN[2] = float32_to_bfloat16(bias_c_BN[q]); - bias_c_RUBNWN[3] = float32_to_bfloat16(bias_c_WN[q]); - - bias_c_RUBNWN += 4; - - const float* weight_xc_R = weight_xc.row(num_output * 0 + q); - const float* weight_xc_U = weight_xc.row(num_output * 1 + q); - const float* weight_xc_N = weight_xc.row(num_output * 2 + q); - - const float* weight_hc_R = weight_hc.row(num_output * 0 + q); - const float* weight_hc_U = weight_hc.row(num_output * 1 + q); - const float* weight_hc_N = weight_hc.row(num_output * 2 + q); - -#if __ARM_NEON - unsigned short* weight_xc_RUN = weight_xc_data_packed_dr.row(q / 4 + q % 4); - unsigned short* weight_hc_RUN = weight_hc_data_packed_dr.row(q / 4 + q % 4); -#else - unsigned short* 
weight_xc_RUN = weight_xc_data_packed_dr.row(q); - unsigned short* weight_hc_RUN = weight_hc_data_packed_dr.row(q); -#endif // __ARM_NEON + const unsigned short* x = bottom_blob.row(t); + float absmax = 0.f; for (int i = 0; i < size; i++) { - weight_xc_RUN[0] = float32_to_bfloat16(weight_xc_R[i]); - weight_xc_RUN[1] = float32_to_bfloat16(weight_xc_U[i]); - - weight_xc_RUN += 2; + absmax = std::max(absmax, (float)fabs(bfloat16_to_float32(x[i]))); } - for (int i = 0; i < num_output; i++) - { - weight_hc_RUN[0] = float32_to_bfloat16(weight_hc_R[i]); - weight_hc_RUN[1] = float32_to_bfloat16(weight_hc_U[i]); - - weight_hc_RUN += 2; - } + bottom_blob_int8_scales[t] = 127.f / absmax; + bottom_blob_int8_descales[t] = absmax / 127.f; + } + } - for (int i = 0; i < size; i++) - { - weight_xc_RUN[0] = float32_to_bfloat16(weight_xc_N[i]); + quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt); +} - weight_xc_RUN += 1; - } +int GRU_arm::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + int elemtype = 1; // fp32 + { + int elembits = bottom_blob.elembits(); - for (int i = 0; i < num_output; i++) - { - weight_hc_RUN[0] = float32_to_bfloat16(weight_hc_N[i]); + // clang-format off + // *INDENT-OFF* - weight_hc_RUN += 1; - } +#if NCNN_ARM82 + if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) + { + elemtype = 2; // fp16 + } + else +#endif +#if NCNN_BF16 + if (opt.use_bf16_storage && elembits == 16) + { + elemtype = 4; // bf16 + } + else +#endif + { + // fp32 } - } - if (opt.lightmode) - { - weight_xc_data.release(); - bias_c_data.release(); - weight_hc_data.release(); + // *INDENT-ON* + // clang-format on } - return 0; -} - -int GRU_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const -{ int T = bottom_blob.h; + size_t elemsize = bottom_blob.elemsize; int num_directions = direction == 2 ? 
2 : 1; @@ -2583,99 +1654,118 @@ int GRU_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& return -100; hidden.fill(0.f); - top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator); + top_blob.create(num_output * num_directions, T, elemsize, opt.blob_allocator); if (top_blob.empty()) return -100; + // dynamic quantize bottom_blob + Mat bottom_blob_int8; + Mat bottom_blob_int8_descales; + { + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + dynamic_quantize(bottom_blob, elemtype, bottom_blob_int8, bottom_blob_int8_descales, opt_quant); + } + // Uni directional if (direction == 0 || direction == 1) { -#if NCNN_INT8 - if (int8_scale_term) - { - int ret = gru_bf16s_int8(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; - } - else -#endif - { - int ret = gru_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; - } + gru_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, elemtype, direction, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), hidden, opt); } if (direction == 2) { - Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator); + Mat top_blob_forward(num_output, T, elemsize, opt.workspace_allocator); if (top_blob_forward.empty()) return -100; - Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator); + Mat top_blob_reverse(num_output, T, elemsize, opt.workspace_allocator); if (top_blob_reverse.empty()) return -100; -#if NCNN_INT8 - if (int8_scale_term) { - int ret = gru_bf16s_int8(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; - } - else -#endif - { - int ret = gru_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; + gru_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_forward, elemtype, 0, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), hidden, opt); } hidden.fill(0.f); -#if NCNN_INT8 - if (int8_scale_term) - { - int ret = gru_bf16s_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), weight_xc_data_int8_descales_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), weight_hc_data_int8_descales_packed.channel(1), hidden, opt); - if (ret != 0) - return ret; - } - else -#endif { - int ret = gru_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); - if (ret != 0) - return ret; + gru_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_reverse, elemtype, 1, weight_data_tm.channel(1), weight_data_tm_int8_descales.channel(1), bias_c_data_packed.channel(1), hidden, opt); } // concat w for (int i = 0; i < T; i++) { - const unsigned short* pf = top_blob_forward.row(i); - const unsigned short* pr = 
top_blob_reverse.row(i); - unsigned short* ptr = top_blob.row(i); + const unsigned char* pf = top_blob_forward.row(i); + const unsigned char* pr = top_blob_reverse.row(i); + unsigned char* ptr = top_blob.row(i); - memcpy(ptr, pf, num_output * sizeof(unsigned short)); - memcpy(ptr + num_output, pr, num_output * sizeof(unsigned short)); + memcpy(ptr, pf, num_output * elemsize); + memcpy(ptr + num_output * elemsize, pr, num_output * elemsize); } } return 0; } -int GRU_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +int GRU_arm::forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { const Mat& bottom_blob = bottom_blobs[0]; + + int elemtype = 1; // fp32 + { + int elembits = bottom_blob.elembits(); + + // clang-format off + // *INDENT-OFF* + +#if NCNN_ARM82 + if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) + { + elemtype = 2; // fp16 + } + else +#endif +#if NCNN_BF16 + if (opt.use_bf16_storage && elembits == 16) + { + elemtype = 4; // bf16 + } + else +#endif + { + // fp32 + } + + // *INDENT-ON* + // clang-format on + } + int T = bottom_blob.h; + size_t elemsize = bottom_blob.elemsize; int num_directions = direction == 2 ? 2 : 1; Mat hidden; Allocator* hidden_allocator = top_blobs.size() == 2 ? opt.blob_allocator : opt.workspace_allocator; if (bottom_blobs.size() == 2) { - Option opt_cast = opt; - opt_cast.blob_allocator = hidden_allocator; - cast_bfloat16_to_float32(bottom_blobs[1], hidden, opt_cast); + if (elemtype == 1) + { + hidden = bottom_blobs[1].clone(); + } + if (elemtype == 2) + { + Option opt_cast = opt; + opt_cast.blob_allocator = hidden_allocator; + cast_float16_to_float32(bottom_blobs[1], hidden, opt_cast); + } + if (elemtype == 4) + { + Option opt_cast = opt; + opt_cast.blob_allocator = hidden_allocator; + cast_bfloat16_to_float32(bottom_blobs[1], hidden, opt_cast); + } } else { @@ -2686,90 +1776,76 @@ int GRU_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector(i); - const unsigned short* pr = top_blob_reverse.row(i); - unsigned short* ptr = top_blob.row(i); + const unsigned char* pf = top_blob_forward.row(i); + const unsigned char* pr = top_blob_reverse.row(i); + unsigned char* ptr = top_blob.row(i); - memcpy(ptr, pf, num_output * sizeof(unsigned short)); - memcpy(ptr + num_output, pr, num_output * sizeof(unsigned short)); + memcpy(ptr, pf, num_output * elemsize); + memcpy(ptr + num_output * elemsize, pr, num_output * elemsize); } } if (top_blobs.size() == 2) { - cast_float32_to_bfloat16(hidden, top_blobs[1], opt); + if (elemtype == 1) + { + top_blobs[1] = hidden; + } + if (elemtype == 2) + { + cast_float32_to_float16(hidden, top_blobs[1], opt); + } + if (elemtype == 4) + { + cast_float32_to_bfloat16(hidden, top_blobs[1], opt); + } } return 0; } -#endif // NCNN_BF16 +#endif // NCNN_INT8 } // namespace ncnn diff --git a/src/layer/arm/gru_arm.h b/src/layer/arm/gru_arm.h index b44a1f38be7..aba1608df90 100644 --- a/src/layer/arm/gru_arm.h +++ b/src/layer/arm/gru_arm.h @@ -29,9 +29,6 @@ class GRU_arm : public GRU virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; protected: -#if NCNN_INT8 - int create_pipeline_int8(const Option& opt); -#endif #if NCNN_ARM82 int create_pipeline_fp16s(const Option& opt); int forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; @@ -42,15 +39,22 @@ class GRU_arm : public GRU int forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) 
const; int forward_bf16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; #endif +#if NCNN_INT8 + int create_pipeline_int8(const Option& opt); + void dynamic_quantize(const Mat& bottom_blob, int elemtype, Mat& bottom_blob_int8, Mat& bottom_blob_int8_descales, const Option& opt) const; + int forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; + int forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +#endif public: Mat weight_xc_data_packed; Mat bias_c_data_packed; Mat weight_hc_data_packed; + Mat weight_data_tm; + #if NCNN_INT8 - Mat weight_hc_data_int8_descales_packed; - Mat weight_xc_data_int8_descales_packed; + Mat weight_data_tm_int8_descales; #endif }; diff --git a/src/layer/arm/gru_arm_asimddp.cpp b/src/layer/arm/gru_arm_asimddp.cpp new file mode 100644 index 00000000000..3de7ed84ead --- /dev/null +++ b/src/layer/arm/gru_arm_asimddp.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "cpu.h" +#include "mat.h" +#include "layer.h" +#include "arm_activation.h" +#include "arm_usability.h" + +namespace ncnn { + +#include "gru_int8.h" + +void gru_transform_weight_int8_asimddp(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, const Option& opt) +{ + gru_transform_weight_int8(weight_xc, weight_xc_int8_scales, weight_hc, weight_hc_int8_scales, bias_c, weight_data_tm, weight_data_tm_int8_descales, bias_c_tm, size, num_output, num_directions, opt); +} + +void gru_int8_asimddp(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, Mat& hidden_state, const Option& opt) +{ + gru_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, elemtype, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, hidden_state, opt); +} + +} // namespace ncnn diff --git a/src/layer/arm/gru_arm_asimdhp.cpp b/src/layer/arm/gru_arm_asimdhp.cpp index 4278d2289ea..3a3d92d5d57 100644 --- a/src/layer/arm/gru_arm_asimdhp.cpp +++ b/src/layer/arm/gru_arm_asimdhp.cpp @@ -732,900 +732,8 @@ static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M return 0; } -#if NCNN_INT8 -static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, Mat& hidden_state, const Option& opt) -{ - int size = bottom_blob.w; - int T = bottom_blob.h; - - int num_output = top_blob.w; - - // 2 x num_output - Mat gates(4 * 2, 
num_output / 4 + num_output % 4, 4u, opt.workspace_allocator); - if (gates.empty()) - return -100; - - // unroll - for (int t = 0; t < T; t++) - { - int ti = reverse ? T - 1 - t : t; - - int nn_num_output = num_output >> 2; - int remain_num_output_start = nn_num_output << 2; - #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_num_output; qq++) - { - int q = qq * 4; - - const __fp16* x = bottom_blob.row(ti); - - // gate reset update - const __fp16* bias_c_RUBNWN = (const __fp16*)bias_c + q * 4; - - const signed char* weight_xc_int8_RUN = weight_xc_int8.row(q / 4); - const signed char* weight_hc_int8_RUN = weight_hc_int8.row(q / 4); - - const __fp16* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q / 4); - const __fp16* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q / 4); - - float16x8_t _descale_xc_RU = vld1q_f16(weight_xc_int8_descales_RUN); - float16x8_t _descale_hc_RU = vld1q_f16(weight_hc_int8_descales_RUN); - - float16x8_t _RU = vld1q_f16(bias_c_RUBNWN); - float16x8_t _sum1 = vdupq_n_f16((__fp16)0.f); - float16x8_t _sum2 = vdupq_n_f16((__fp16)0.f); - float16x8_t _sum3 = vdupq_n_f16((__fp16)0.f); - - int i = 0; - for (; i + 3 < size; i += 4) - { -#if NCNN_GNU_INLINE_ASM - asm volatile( - "ld1 {v6.16b, v7.16b}, [%1], #32 \n" - "ld1 {v4.4h}, [%0], #8 \n" - "sxtl v0.8h, v6.8b \n" - "sxtl2 v1.8h, v6.16b \n" - "sxtl v2.8h, v7.8b \n" - "sxtl2 v3.8h, v7.16b \n" - "scvtf v0.8h, v0.8h \n" - "scvtf v1.8h, v1.8h \n" - "scvtf v2.8h, v2.8h \n" - "scvtf v3.8h, v3.8h \n" - "fmul v0.8h, v0.8h, %12.8h \n" - "fmul v1.8h, v1.8h, %12.8h \n" - "fmul v2.8h, v2.8h, %12.8h \n" - "fmul v3.8h, v3.8h, %12.8h \n" - "fmla %2.8h, v0.8h, v4.h[0] \n" - "fmla %3.8h, v1.8h, v4.h[1] \n" - "fmla %4.8h, v2.8h, v4.h[2] \n" - "fmla %5.8h, v3.8h, v4.h[3] \n" - : "=r"(x), - "=r"(weight_xc_int8_RUN), - "=w"(_RU), - "=w"(_sum1), - "=w"(_sum2), - "=w"(_sum3) - : "0"(x), - "1"(weight_xc_int8_RUN), - "2"(_RU), - "3"(_sum1), - "4"(_sum2), - "5"(_sum3), - "w"(_descale_xc_RU) - : "memory", "v0", "v1", "v2", "v3", "v4", "v6", "v7"); -#else // NCNN_GNU_INLINE_ASM - float16x4_t _x = vld1_f16(x); - - int8x16_t _weight_xc_RU01 = vld1q_s8(weight_xc_int8_RUN); - int8x16_t _weight_xc_RU23 = vld1q_s8(weight_xc_int8_RUN + 16); - - float16x8_t _w0 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_RU01))), _descale_xc_RU); - float16x8_t _w1 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_RU01))), _descale_xc_RU); - float16x8_t _w2 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_RU23))), _descale_xc_RU); - float16x8_t _w3 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_RU23))), _descale_xc_RU); - - _RU = vfmaq_lane_f16(_RU, _w0, _x, 0); - _sum1 = vfmaq_lane_f16(_sum1, _w1, _x, 1); - _sum2 = vfmaq_lane_f16(_sum2, _w2, _x, 2); - _sum3 = vfmaq_lane_f16(_sum3, _w3, _x, 3); - - x += 4; - weight_xc_int8_RUN += 32; -#endif // NCNN_GNU_INLINE_ASM - } - for (; i < size; i++) - { - __fp16 xi = *x++; - - float16x8_t _xi = vdupq_n_f16(xi); - - float16x8_t _weight_xc_RU = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vld1_s8(weight_xc_int8_RUN))), _descale_xc_RU); - - _RU = vfmaq_f16(_RU, _weight_xc_RU, _xi); - - weight_xc_int8_RUN += 8; - } - - const float* hidden_ptr = hidden_state; - - i = 0; - for (; i + 3 < num_output; i += 4) - { -#if NCNN_GNU_INLINE_ASM - asm volatile( - "ld1 {v6.8h, v7.8h}, [%1], #32 \n" - "ld1 {v4.4s}, [%0], #16 \n" - "sxtl v0.8h, v6.8b \n" - "sxtl2 v1.8h, v6.16b \n" - "sxtl v2.8h, v7.8b \n" - "sxtl2 v3.8h, v7.16b \n" - "scvtf v0.8h, v0.8h \n" - "scvtf 
v1.8h, v1.8h \n" - "scvtf v2.8h, v2.8h \n" - "scvtf v3.8h, v3.8h \n" - "fcvtn v4.4h, v4.4s \n" - "fmul v0.8h, v0.8h, %12.8h \n" - "fmul v1.8h, v1.8h, %12.8h \n" - "fmul v2.8h, v2.8h, %12.8h \n" - "fmul v3.8h, v3.8h, %12.8h \n" - "fmla %2.8h, v0.8h, v4.h[0] \n" - "fmla %3.8h, v1.8h, v4.h[1] \n" - "fmla %4.8h, v2.8h, v4.h[2] \n" - "fmla %5.8h, v3.8h, v4.h[3] \n" - : "=r"(hidden_ptr), - "=r"(weight_hc_int8_RUN), - "=w"(_RU), - "=w"(_sum1), - "=w"(_sum2), - "=w"(_sum3) - : "0"(hidden_ptr), - "1"(weight_hc_int8_RUN), - "2"(_RU), - "3"(_sum1), - "4"(_sum2), - "5"(_sum3), - "w"(_descale_hc_RU) - : "memory", "v0", "v1", "v2", "v3", "v4", "v6", "v7"); -#else // NCNN_GNU_INLINE_ASM - float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr)); - - int8x16_t _weight_hc_RU01 = vld1q_s8(weight_hc_int8_RUN); - int8x16_t _weight_hc_RU23 = vld1q_s8(weight_hc_int8_RUN + 16); - - float16x8_t _w0 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_RU01))), _descale_hc_RU); - float16x8_t _w1 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_RU01))), _descale_hc_RU); - float16x8_t _w2 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_RU23))), _descale_hc_RU); - float16x8_t _w3 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_RU23))), _descale_hc_RU); - - _RU = vfmaq_lane_f16(_RU, _w0, _h_cont, 0); - _sum1 = vfmaq_lane_f16(_sum1, _w1, _h_cont, 1); - _sum2 = vfmaq_lane_f16(_sum2, _w2, _h_cont, 2); - _sum3 = vfmaq_lane_f16(_sum3, _w3, _h_cont, 3); - - hidden_ptr += 4; - weight_hc_int8_RUN += 32; -#endif // NCNN_GNU_INLINE_ASM - } - for (; i < num_output; i++) - { - float h_cont = *hidden_ptr++; - - float16x8_t _h_cont = vdupq_n_f16((__fp16)h_cont); - - float16x8_t _weight_hc_RU = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vld1_s8(weight_hc_int8_RUN))), _descale_hc_RU); - - _RU = vfmaq_f16(_RU, _weight_hc_RU, _h_cont); - - weight_hc_int8_RUN += 8; - } - - _RU = vaddq_f16(_RU, _sum1); - _sum2 = vaddq_f16(_sum2, _sum3); - _RU = vaddq_f16(_RU, _sum2); - - // sigmoid(R) - // sigmoid(U) - float32x4_t _R32 = sigmoid_ps(vcvt_f32_f16(vget_low_f16(_RU))); - float32x4_t _U32 = sigmoid_ps(vcvt_f32_f16(vget_high_f16(_RU))); - - x -= size; - hidden_ptr = hidden_state; - - // gate new - float16x4_t _gru_N = vld1_f16(bias_c_RUBNWN + 8); - float16x4_t _sum4 = vdup_n_f16((__fp16)0.f); - float16x4_t _sum5 = vdup_n_f16((__fp16)0.f); - float16x4_t _sum6 = vdup_n_f16((__fp16)0.f); - - float16x4_t _descale_xc_N = vld1_f16(weight_xc_int8_descales_RUN + 8); - float16x4_t _descale_hc_N = vld1_f16(weight_hc_int8_descales_RUN + 8); - float16x8_t _descale_xc_NN = vcombine_f16(_descale_xc_N, _descale_xc_N); - float16x8_t _descale_hc_NN = vcombine_f16(_descale_hc_N, _descale_hc_N); - - i = 0; - for (; i + 3 < num_output; i += 4) - { -#if NCNN_GNU_INLINE_ASM - asm volatile( - "ld1 {v5.16b}, [%1], #16 \n" - "ld1 {v4.4s}, [%0], #16 \n" - "sxtl v0.8h, v5.8b \n" - "sxtl2 v2.8h, v5.16b \n" - "scvtf v0.8h, v0.8h \n" - "scvtf v2.8h, v2.8h \n" - "fcvtn v4.4h, v4.4s \n" - "fmul v0.8h, v0.8h, %12.8h \n" - "fmul v2.8h, v2.8h, %12.8h \n" - "mov v1.d[0], v0.d[1] \n" - "mov v3.d[0], v2.d[1] \n" - "fmla %2.4h, v0.4h, v4.h[0] \n" - "fmla %3.4h, v1.4h, v4.h[1] \n" - "fmla %4.4h, v2.4h, v4.h[2] \n" - "fmla %5.4h, v3.4h, v4.h[3] \n" - : "=r"(hidden_ptr), - "=r"(weight_hc_int8_RUN), - "=w"(_gru_N), - "=w"(_sum4), - "=w"(_sum5), - "=w"(_sum6) - : "0"(hidden_ptr), - "1"(weight_hc_int8_RUN), - "2"(_gru_N), - "3"(_sum4), - "4"(_sum5), - "5"(_sum6), - "w"(_descale_hc_NN) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5"); -#else // 
NCNN_GNU_INLINE_ASM - float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr)); - - int8x16_t _weight_hc_N0123 = vld1q_s8(weight_hc_int8_RUN); - float16x8_t _weight_hc_N01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_N0123))), _descale_hc_NN); - float16x8_t _weight_hc_N23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_N0123))), _descale_hc_NN); - - float16x4_t _w0 = vget_low_f16(_weight_hc_N01); - float16x4_t _w1 = vget_high_f16(_weight_hc_N01); - float16x4_t _w2 = vget_low_f16(_weight_hc_N23); - float16x4_t _w3 = vget_high_f16(_weight_hc_N23); - - _gru_N = vfma_lane_f16(_gru_N, _w0, _h_cont, 0); - _sum4 = vfma_lane_f16(_sum4, _w1, _h_cont, 1); - _sum5 = vfma_lane_f16(_sum5, _w2, _h_cont, 2); - _sum6 = vfma_lane_f16(_sum6, _w3, _h_cont, 3); - - hidden_ptr += 4; - weight_hc_int8_RUN += 16; -#endif // NCNN_GNU_INLINE_ASM - } - for (; i < num_output; i++) - { - float h_cont = *hidden_ptr++; - - float16x4_t _h_cont = vdup_n_f16((__fp16)h_cont); - int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_RUN)[0])); - float16x4_t _weight_hc_N = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(_w))), _descale_hc_N); - _gru_N = vfma_f16(_gru_N, _weight_hc_N, _h_cont); - - weight_hc_int8_RUN += 4; - } - - _gru_N = vadd_f16(_gru_N, _sum4); - _sum5 = vadd_f16(_sum5, _sum6); - _gru_N = vadd_f16(_gru_N, _sum5); - - _gru_N = vfma_f16(vld1_f16(bias_c_RUBNWN + 12), vcvt_f16_f32(_R32), _gru_N); - _sum4 = vdup_n_f16((__fp16)0.f); - _sum5 = vdup_n_f16((__fp16)0.f); - _sum6 = vdup_n_f16((__fp16)0.f); - - i = 0; - for (; i + 3 < size; i += 4) - { -#if NCNN_GNU_INLINE_ASM - asm volatile( - "ld1 {v5.16b}, [%1], #16 \n" - "ld1 {v4.4h}, [%0], #8 \n" - "sxtl v0.8h, v5.8b \n" - "sxtl2 v2.8h, v5.16b \n" - "scvtf v0.8h, v0.8h \n" - "scvtf v2.8h, v2.8h \n" - "fmul v0.8h, v0.8h, %12.8h \n" - "fmul v2.8h, v2.8h, %12.8h \n" - "mov v1.d[0], v0.d[1] \n" - "mov v3.d[0], v2.d[1] \n" - "fmla %2.4h, v0.4h, v4.h[0] \n" - "fmla %3.4h, v1.4h, v4.h[1] \n" - "fmla %4.4h, v2.4h, v4.h[2] \n" - "fmla %5.4h, v3.4h, v4.h[3] \n" - : "=r"(x), - "=r"(weight_xc_int8_RUN), - "=w"(_gru_N), - "=w"(_sum4), - "=w"(_sum5), - "=w"(_sum6) - : "0"(x), - "1"(weight_xc_int8_RUN), - "2"(_gru_N), - "3"(_sum4), - "4"(_sum5), - "5"(_sum6), - "w"(_descale_xc_NN) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5"); -#else // NCNN_GNU_INLINE_ASM - float16x4_t _x = vld1_f16(x); - - int8x16_t _weight_xc_N0123 = vld1q_s8(weight_xc_int8_RUN); - float16x8_t _weight_xc_N01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_N0123))), _descale_xc_NN); - float16x8_t _weight_xc_N23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_N0123))), _descale_xc_NN); - - float16x4_t _w0 = vget_low_f16(_weight_xc_N01); - float16x4_t _w1 = vget_high_f16(_weight_xc_N01); - float16x4_t _w2 = vget_low_f16(_weight_xc_N23); - float16x4_t _w3 = vget_high_f16(_weight_xc_N23); - - _gru_N = vfma_lane_f16(_gru_N, _w0, _x, 0); - _sum4 = vfma_lane_f16(_sum4, _w1, _x, 1); - _sum5 = vfma_lane_f16(_sum5, _w2, _x, 2); - _sum6 = vfma_lane_f16(_sum6, _w3, _x, 3); - - x += 4; - weight_xc_int8_RUN += 16; -#endif // NCNN_GNU_INLINE_ASM - } - for (; i < size; i++) - { - __fp16 xi = *x++; - - float16x4_t _xi = vdup_n_f16(xi); - int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_RUN)[0])); - float16x4_t _weight_xc_N = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(_w))), _descale_xc_N); - _gru_N = vfma_f16(_gru_N, _weight_xc_N, _xi); - - weight_xc_int8_RUN += 4; - } - - _gru_N = vadd_f16(_gru_N, _sum4); - _sum5 = vadd_f16(_sum5, _sum6); 
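(Note: for reference, the per-output arithmetic that both the NEON and scalar paths in these kernels compute can be restated in plain C++; this is a sketch with illustrative names, mirroring the R/U/BN/WN bias packing used above, not code from the patch:)

#include <math.h>

// One GRU time step for output element q (scalar sketch, illustrative names).
// Wx* are input weights (length size), Wh* recurrent weights (length num_output);
// br/bu are the R/U biases, bbn/bwn the hidden-/input-side biases of the N gate.
static float gru_cell_element(const float* x, int size,
                              const float* h, int num_output, int q,
                              const float* Wxr, const float* Wxu, const float* Wxn,
                              const float* Whr, const float* Whu, const float* Whn,
                              float br, float bu, float bbn, float bwn)
{
    float R = br;
    float U = bu;
    float N = bbn;

    for (int i = 0; i < size; i++)
    {
        R += Wxr[i] * x[i];
        U += Wxu[i] * x[i];
    }
    for (int i = 0; i < num_output; i++)
    {
        R += Whr[i] * h[i];
        U += Whu[i] * h[i];
        N += Whn[i] * h[i];
    }

    // sigmoid(R), sigmoid(U)
    R = 1.f / (1.f + expf(-R));
    U = 1.f / (1.f + expf(-U));

    // reset gate scales only the recurrent contribution of the new gate
    N = bwn + R * N;
    for (int i = 0; i < size; i++)
        N += Wxn[i] * x[i];
    N = tanhf(N);

    // h_t := (1 - update) .* new + update .* h_{t-1}
    return (1.f - U) * N + U * h[q];
}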
- _gru_N = vadd_f16(_gru_N, _sum5); - - // tanh(N) - float32x4_t _N32 = tanh_ps(vcvt_f32_f16(_gru_N)); - - float* gates_data = gates.row(q / 4); - - vst1q_f32(gates_data, _U32); - vst1q_f32(gates_data + 4, _N32); - } - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) - { - const __fp16* x = bottom_blob.row(ti); - - // gate reset update - const __fp16* bias_c_RUBNWN = (const __fp16*)bias_c + q * 4; - - const signed char* weight_xc_int8_RUN = weight_xc_int8.row(q / 4 + q % 4); - const signed char* weight_hc_int8_RUN = weight_hc_int8.row(q / 4 + q % 4); - const __fp16* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q / 4 + q % 4); - const __fp16* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q / 4 + q % 4); - - const __fp16 descale_xc_R = weight_xc_int8_descales_RUN[0]; - const __fp16 descale_xc_U = weight_xc_int8_descales_RUN[1]; - const __fp16 descale_xc_N = weight_xc_int8_descales_RUN[2]; - - const __fp16 descale_hc_R = weight_hc_int8_descales_RUN[0]; - const __fp16 descale_hc_U = weight_hc_int8_descales_RUN[1]; - const __fp16 descale_hc_N = weight_hc_int8_descales_RUN[2]; - - __fp16 R = bias_c_RUBNWN[0]; - __fp16 U = bias_c_RUBNWN[1]; - - for (int i = 0; i < size; i++) - { - __fp16 xi = x[i]; - - R += (__fp16)weight_xc_int8_RUN[0] * descale_xc_R * xi; - U += (__fp16)weight_xc_int8_RUN[1] * descale_xc_U * xi; - - weight_xc_int8_RUN += 2; - } - - for (int i = 0; i < num_output; i++) - { - __fp16 h_cont = (__fp16)hidden_state[i]; - - R += (__fp16)weight_hc_int8_RUN[0] * descale_hc_R * h_cont; - U += (__fp16)weight_hc_int8_RUN[1] * descale_hc_U * h_cont; - - weight_hc_int8_RUN += 2; - } - - // sigmoid(R) - // sigmoid(U) - float R32 = 1.f / (1.f + expf((float)-R)); - float U32 = 1.f / (1.f + expf((float)-U)); - - // gate new - __fp16 N = bias_c_RUBNWN[2]; - - for (int i = 0; i < num_output; i++) - { - __fp16 h_cont = (__fp16)hidden_state[i]; - - N += (__fp16)weight_hc_int8_RUN[0] * descale_hc_N * h_cont; - - weight_hc_int8_RUN += 1; - } - - N = bias_c_RUBNWN[3] + (__fp16)R32 * N; - - for (int i = 0; i < size; i++) - { - __fp16 xi = x[i]; - - N += (__fp16)weight_xc_int8_RUN[0] * descale_xc_N * xi; - - weight_xc_int8_RUN += 1; - } - - // tanh(N) - float N32 = tanhf((float)N); - - float* gates_data = gates.row(q / 4 + q % 4); - - gates_data[0] = U32; - gates_data[1] = N32; - } - - // h_t := (1 - update) .* new + update .* h_{t-1} - __fp16* output_data = top_blob.row<__fp16>(ti); - - float* hidden_ptr = hidden_state; - - nn_num_output = num_output >> 2; - remain_num_output_start = nn_num_output << 2; - #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_num_output; qq++) - { - int q = qq * 4; - - const float* gates_data = gates.row(q / 4); - - float32x4_t _gru_U = vld1q_f32(gates_data); - float32x4_t _gru_N = vld1q_f32(gates_data + 4); - - float32x4_t _gru_H = vaddq_f32(vmulq_f32(vsubq_f32(vdupq_n_f32(1.f), _gru_U), _gru_N), vmulq_f32(_gru_U, vld1q_f32(hidden_ptr + q))); - - vst1q_f32(hidden_ptr + q, _gru_H); - vst1_f16(output_data + q, vcvt_f16_f32(_gru_H)); - } - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) - { - const float* gates_data = gates.row(q / 4 + q % 4); - - float U = gates_data[0]; - float N = gates_data[1]; - - float H = (1 - U) * N + U * hidden_ptr[q]; - - hidden_ptr[q] = H; - output_data[q] = (__fp16)H; - } - } - - return 0; -} - -static int gru_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, 
int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, Mat& hidden_state, const Option& opt) -{ - if (opt.use_fp16_arithmetic) - return gru_fp16sa_int8(bottom_blob, top_blob, reverse, weight_xc_int8, weight_xc_int8_descales, bias_c, weight_hc_int8, weight_hc_int8_descales, hidden_state, opt); - - int size = bottom_blob.w; - int T = bottom_blob.h; - - int num_output = top_blob.w; - - // 2 x num_output - Mat gates(4 * 2, num_output / 4 + num_output % 4, 4u, opt.workspace_allocator); - if (gates.empty()) - return -100; - - // unroll - for (int t = 0; t < T; t++) - { - int ti = reverse ? T - 1 - t : t; - - int nn_num_output = num_output >> 2; - int remain_num_output_start = nn_num_output << 2; - #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_num_output; qq++) - { - int q = qq * 4; - - const __fp16* x = bottom_blob.row(ti); - - // gate reset update - const __fp16* bias_c_RUBNWN = (const __fp16*)bias_c + q * 4; - - const signed char* weight_xc_int8_RUN = weight_xc_int8.row(q / 4); - const signed char* weight_hc_int8_RUN = weight_hc_int8.row(q / 4); - - const float* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q / 4); - const float* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q / 4); - - float32x4_t _descale_xc_R = vld1q_f32(weight_xc_int8_descales_RUN); - float32x4_t _descale_xc_U = vld1q_f32(weight_xc_int8_descales_RUN + 4); - float32x4_t _descale_hc_R = vld1q_f32(weight_hc_int8_descales_RUN); - float32x4_t _descale_hc_U = vld1q_f32(weight_hc_int8_descales_RUN + 4); - - float32x4_t _gru_R = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN)); - float32x4_t _gru_U = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 4)); - float32x4_t _sum1 = vdupq_n_f32(0.f); - float32x4_t _sum2 = vdupq_n_f32(0.f); - float32x4_t _sum3 = vdupq_n_f32(0.f); - float32x4_t _sum4 = vdupq_n_f32(0.f); - float32x4_t _sum5 = vdupq_n_f32(0.f); - float32x4_t _sum6 = vdupq_n_f32(0.f); - - int i = 0; - for (; i + 3 < size; i += 4) - { - float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i)); - - int8x16_t _weight_xc_RU01 = vld1q_s8(weight_xc_int8_RUN); - int8x16_t _weight_xc_RU23 = vld1q_s8(weight_xc_int8_RUN + 16); - - int16x8_t _weight_xc_RU0 = vmovl_s8(vget_low_s8(_weight_xc_RU01)); - int16x8_t _weight_xc_RU1 = vmovl_s8(vget_high_s8(_weight_xc_RU01)); - int16x8_t _weight_xc_RU2 = vmovl_s8(vget_low_s8(_weight_xc_RU23)); - int16x8_t _weight_xc_RU3 = vmovl_s8(vget_high_s8(_weight_xc_RU23)); - - float32x4_t _weight_xc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU0))), _descale_xc_R); - float32x4_t _weight_xc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU0))), _descale_xc_U); - float32x4_t _weight_xc_R_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU1))), _descale_xc_R); - float32x4_t _weight_xc_U_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU1))), _descale_xc_U); - float32x4_t _weight_xc_R_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU2))), _descale_xc_R); - float32x4_t _weight_xc_U_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU2))), _descale_xc_U); - float32x4_t _weight_xc_R_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU3))), _descale_xc_R); - float32x4_t _weight_xc_U_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU3))), _descale_xc_U); - - _gru_R = vfmaq_laneq_f32(_gru_R, _weight_xc_R, _xi, 0); - _gru_U = vfmaq_laneq_f32(_gru_U, _weight_xc_U, _xi, 0); - 
_sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_R_1, _xi, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_U_1, _xi, 1); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_R_2, _xi, 2); - _sum4 = vfmaq_laneq_f32(_sum4, _weight_xc_U_2, _xi, 2); - _sum5 = vfmaq_laneq_f32(_sum5, _weight_xc_R_3, _xi, 3); - _sum6 = vfmaq_laneq_f32(_sum6, _weight_xc_U_3, _xi, 3); - - weight_xc_int8_RUN += 32; - } - for (; i < size; i++) - { - __fp16 xi = x[i]; - - float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); - - int16x8_t _weight_xc_RU = vmovl_s8(vld1_s8(weight_xc_int8_RUN)); - float32x4_t _weight_xc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU))), _descale_xc_R); - float32x4_t _weight_xc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU))), _descale_xc_U); - - _gru_R = vmlaq_f32(_gru_R, _weight_xc_R, _xi); - _gru_U = vmlaq_f32(_gru_U, _weight_xc_U, _xi); - - weight_xc_int8_RUN += 8; - } - - i = 0; - for (; i + 3 < num_output; i += 4) - { - float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); - - int8x16_t _weight_hc_RU01 = vld1q_s8(weight_hc_int8_RUN); - int8x16_t _weight_hc_RU23 = vld1q_s8(weight_hc_int8_RUN + 16); - - int16x8_t _weight_hc_RU0 = vmovl_s8(vget_low_s8(_weight_hc_RU01)); - int16x8_t _weight_hc_RU1 = vmovl_s8(vget_high_s8(_weight_hc_RU01)); - int16x8_t _weight_hc_RU2 = vmovl_s8(vget_low_s8(_weight_hc_RU23)); - int16x8_t _weight_hc_RU3 = vmovl_s8(vget_high_s8(_weight_hc_RU23)); - - float32x4_t _weight_hc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU0))), _descale_hc_R); - float32x4_t _weight_hc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU0))), _descale_hc_U); - float32x4_t _weight_hc_R_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU1))), _descale_hc_R); - float32x4_t _weight_hc_U_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU1))), _descale_hc_U); - float32x4_t _weight_hc_R_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU2))), _descale_hc_R); - float32x4_t _weight_hc_U_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU2))), _descale_hc_U); - float32x4_t _weight_hc_R_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU3))), _descale_hc_R); - float32x4_t _weight_hc_U_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU3))), _descale_hc_U); - - _gru_R = vfmaq_laneq_f32(_gru_R, _weight_hc_R, _h_cont, 0); - _gru_U = vfmaq_laneq_f32(_gru_U, _weight_hc_U, _h_cont, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_R_1, _h_cont, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_U_1, _h_cont, 1); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_R_2, _h_cont, 2); - _sum4 = vfmaq_laneq_f32(_sum4, _weight_hc_U_2, _h_cont, 2); - _sum5 = vfmaq_laneq_f32(_sum5, _weight_hc_R_3, _h_cont, 3); - _sum6 = vfmaq_laneq_f32(_sum6, _weight_hc_U_3, _h_cont, 3); - - weight_hc_int8_RUN += 32; - } - for (; i < num_output; i++) - { - float h_cont = hidden_state[i]; - - float32x4_t _h_cont = vdupq_n_f32(h_cont); - - int16x8_t _weight_hc_RU = vmovl_s8(vld1_s8(weight_hc_int8_RUN)); - float32x4_t _weight_hc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU))), _descale_hc_R); - float32x4_t _weight_hc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU))), _descale_hc_U); - - _gru_R = vmlaq_f32(_gru_R, _weight_hc_R, _h_cont); - _gru_U = vmlaq_f32(_gru_U, _weight_hc_U, _h_cont); - - weight_hc_int8_RUN += 8; - } - - _gru_R = vaddq_f32(_gru_R, _sum1); - _gru_U = vaddq_f32(_gru_U, _sum2); - _sum3 = vaddq_f32(_sum3, _sum5); - _sum4 = vaddq_f32(_sum4, 
_sum6); - _gru_R = vaddq_f32(_gru_R, _sum3); - _gru_U = vaddq_f32(_gru_U, _sum4); - - // sigmoid(R) - // sigmoid(U) - _gru_R = sigmoid_ps(_gru_R); - _gru_U = sigmoid_ps(_gru_U); - - // gate new - float32x4_t _gru_N = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 8)); - _sum1 = vdupq_n_f32(0.f); - _sum2 = vdupq_n_f32(0.f); - _sum3 = vdupq_n_f32(0.f); - - float32x4_t _descale_xc_N = vld1q_f32(weight_xc_int8_descales_RUN + 8); - float32x4_t _descale_hc_N = vld1q_f32(weight_hc_int8_descales_RUN + 8); - - i = 0; - for (; i + 3 < num_output; i += 4) - { - float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); - - int8x16_t _weight_hc_N0123 = vld1q_s8(weight_hc_int8_RUN); - int16x8_t _weight_hc_N01 = vmovl_s8(vget_low_s8(_weight_hc_N0123)); - int16x8_t _weight_hc_N23 = vmovl_s8(vget_high_s8(_weight_hc_N0123)); - float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_N01))), _descale_hc_N); - float32x4_t _weight_hc_N_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_N01))), _descale_hc_N); - float32x4_t _weight_hc_N_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_N23))), _descale_hc_N); - float32x4_t _weight_hc_N_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_N23))), _descale_hc_N); - - _gru_N = vfmaq_laneq_f32(_gru_N, _weight_hc_N, _h_cont, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_N_1, _h_cont, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_N_2, _h_cont, 2); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_N_3, _h_cont, 3); - - weight_hc_int8_RUN += 16; - } - for (; i < num_output; i++) - { - float h_cont = hidden_state[i]; - - float32x4_t _h_cont = vdupq_n_f32(h_cont); - int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_RUN)[0])); - float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_hc_N); - _gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont); - - weight_hc_int8_RUN += 4; - } - - _gru_N = vaddq_f32(_gru_N, _sum1); - _sum2 = vaddq_f32(_sum2, _sum3); - _gru_N = vaddq_f32(_gru_N, _sum2); - - _gru_N = vmlaq_f32(vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 12)), _gru_R, _gru_N); - _sum1 = vdupq_n_f32(0.f); - _sum2 = vdupq_n_f32(0.f); - _sum3 = vdupq_n_f32(0.f); - - i = 0; - for (; i + 3 < size; i += 4) - { - float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i)); - - int8x16_t _weight_xc_N0123 = vld1q_s8(weight_xc_int8_RUN); - int16x8_t _weight_xc_N01 = vmovl_s8(vget_low_s8(_weight_xc_N0123)); - int16x8_t _weight_xc_N23 = vmovl_s8(vget_high_s8(_weight_xc_N0123)); - float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_N01))), _descale_xc_N); - float32x4_t _weight_xc_N_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_N01))), _descale_xc_N); - float32x4_t _weight_xc_N_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_N23))), _descale_xc_N); - float32x4_t _weight_xc_N_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_N23))), _descale_xc_N); - - _gru_N = vfmaq_laneq_f32(_gru_N, _weight_xc_N, _xi, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_N_1, _xi, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_N_2, _xi, 2); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_N_3, _xi, 3); - - weight_xc_int8_RUN += 16; - } - for (; i < size; i++) - { - __fp16 xi = x[i]; - - float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); - int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_RUN)[0])); - float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_xc_N); - _gru_N = 
vmlaq_f32(_gru_N, _weight_xc_N, _xi); - - weight_xc_int8_RUN += 4; - } - - _gru_N = vaddq_f32(_gru_N, _sum1); - _sum2 = vaddq_f32(_sum2, _sum3); - _gru_N = vaddq_f32(_gru_N, _sum2); - - // tanh(N) - _gru_N = tanh_ps(_gru_N); - - float* gates_data = gates.row(q / 4); - - vst1q_f32(gates_data, _gru_U); - vst1q_f32(gates_data + 4, _gru_N); - } - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) - { - const __fp16* x = bottom_blob.row(ti); - - // gate reset update - const __fp16* bias_c_RUBNWN = (const __fp16*)bias_c + q * 4; - - const signed char* weight_xc_int8_RUN = weight_xc_int8.row(q / 4 + q % 4); - const signed char* weight_hc_int8_RUN = weight_hc_int8.row(q / 4 + q % 4); - const float* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q / 4 + q % 4); - const float* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q / 4 + q % 4); - - const float descale_xc_R = weight_xc_int8_descales_RUN[0]; - const float descale_xc_U = weight_xc_int8_descales_RUN[1]; - const float descale_xc_N = weight_xc_int8_descales_RUN[2]; - - const float descale_hc_R = weight_hc_int8_descales_RUN[0]; - const float descale_hc_U = weight_hc_int8_descales_RUN[1]; - const float descale_hc_N = weight_hc_int8_descales_RUN[2]; - - float R = (float)bias_c_RUBNWN[0]; - float U = (float)bias_c_RUBNWN[1]; - - for (int i = 0; i < size; i++) - { - float xi = (float)x[i]; - - R += weight_xc_int8_RUN[0] * descale_xc_R * xi; - U += weight_xc_int8_RUN[1] * descale_xc_U * xi; - - weight_xc_int8_RUN += 2; - } - - for (int i = 0; i < num_output; i++) - { - float h_cont = hidden_state[i]; - - R += weight_hc_int8_RUN[0] * descale_hc_R * h_cont; - U += weight_hc_int8_RUN[1] * descale_hc_U * h_cont; - - weight_hc_int8_RUN += 2; - } - - // sigmoid(R) - // sigmoid(U) - R = 1.f / (1.f + expf(-R)); - U = 1.f / (1.f + expf(-U)); - - // gate new - float N = (float)bias_c_RUBNWN[2]; - - for (int i = 0; i < num_output; i++) - { - float h_cont = hidden_state[i]; - - N += weight_hc_int8_RUN[0] * descale_hc_N * h_cont; - - weight_hc_int8_RUN += 1; - } - - N = (float)bias_c_RUBNWN[3] + R * N; - - for (int i = 0; i < size; i++) - { - float xi = (float)x[i]; - - N += weight_xc_int8_RUN[0] * descale_xc_N * xi; - - weight_xc_int8_RUN += 1; - } - - // tanh(N) - N = tanhf(N); - - float* gates_data = gates.row(q / 4 + q % 4); - - gates_data[0] = U; - gates_data[1] = N; - } - - // h_t := (1 - update) .* new + update .* h_{t-1} - __fp16* output_data = top_blob.row<__fp16>(ti); - - float* hidden_ptr = hidden_state; - - nn_num_output = num_output >> 2; - remain_num_output_start = nn_num_output << 2; - #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_num_output; qq++) - { - int q = qq * 4; - - const float* gates_data = gates.row(q / 4); - - float32x4_t _gru_U = vld1q_f32(gates_data); - float32x4_t _gru_N = vld1q_f32(gates_data + 4); - - float32x4_t _gru_H = vaddq_f32(vmulq_f32(vsubq_f32(vdupq_n_f32(1.f), _gru_U), _gru_N), vmulq_f32(_gru_U, vld1q_f32(hidden_ptr + q))); - - vst1q_f32(hidden_ptr + q, _gru_H); - vst1_f16(output_data + q, vcvt_f16_f32(_gru_H)); - } - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) - { - const float* gates_data = gates.row(q / 4 + q % 4); - - float U = gates_data[0]; - float N = gates_data[1]; - - float H = (1 - U) * N + U * hidden_ptr[q]; - - hidden_ptr[q] = H; - output_data[q] = (__fp16)H; - } - } - - return 0; -} -#endif // NCNN_INT8 - int 
GRU_arm::create_pipeline_fp16s(const Option& opt) { -#if NCNN_INT8 - if (int8_scale_term) - { - create_pipeline_int8(opt); - - { - ncnn::Mat tmp; - cast_float32_to_float16(bias_c_data_packed, tmp, opt); - bias_c_data_packed = tmp; - } - - if (opt.use_fp16_arithmetic) - { - ncnn::Mat tmp; - cast_float32_to_float16(weight_xc_data_int8_descales_packed, tmp, opt); - weight_xc_data_int8_descales_packed = tmp; - } - - if (opt.use_fp16_arithmetic) - { - ncnn::Mat tmp; - cast_float32_to_float16(weight_hc_data_int8_descales_packed, tmp, opt); - weight_hc_data_int8_descales_packed = tmp; - } - - return 0; - } -#endif - // pack RUN int num_directions = direction == 2 ? 2 : 1; int size = weight_data_size / num_directions / num_output / 3; @@ -1838,20 +946,9 @@ int GRU_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& // Uni directional if (direction == 0 || direction == 1) { -#if NCNN_INT8 - if (int8_scale_term) - { - int ret = gru_fp16s_int8(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; - } - else -#endif - { - int ret = gru_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; - } + int ret = gru_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; } if (direction == 2) @@ -1864,15 +961,6 @@ int GRU_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& if (top_blob_reverse.empty()) return -100; -#if NCNN_INT8 - if (int8_scale_term) - { - int ret = gru_fp16s_int8(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; - } - else -#endif { int ret = gru_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); if (ret != 0) @@ -1881,15 +969,6 @@ int GRU_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& hidden.fill(0.f); -#if NCNN_INT8 - if (int8_scale_term) - { - int ret = gru_fp16s_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), weight_xc_data_int8_descales_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), weight_hc_data_int8_descales_packed.channel(1), hidden, opt); - if (ret != 0) - return ret; - } - else -#endif { int ret = gru_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); if (ret != 0) @@ -1941,20 +1020,9 @@ int GRU_arm::forward_fp16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector(num_output * 0 + q); + const signed char* weight_xc_U_0 = weight_xc_dr.row(num_output * 1 + q); + const signed char* weight_xc_N_0 = weight_xc_dr.row(num_output * 2 + q); + + const signed char* weight_xc_R_1 = weight_xc_dr.row(num_output * 0 + q + 1); + const signed char* weight_xc_U_1 = weight_xc_dr.row(num_output * 1 + q + 1); + const signed char* weight_xc_N_1 
= weight_xc_dr.row(num_output * 2 + q + 1); + + const signed char* weight_xc_R_2 = weight_xc_dr.row(num_output * 0 + q + 2); + const signed char* weight_xc_U_2 = weight_xc_dr.row(num_output * 1 + q + 2); + const signed char* weight_xc_N_2 = weight_xc_dr.row(num_output * 2 + q + 2); + + const signed char* weight_xc_R_3 = weight_xc_dr.row(num_output * 0 + q + 3); + const signed char* weight_xc_U_3 = weight_xc_dr.row(num_output * 1 + q + 3); + const signed char* weight_xc_N_3 = weight_xc_dr.row(num_output * 2 + q + 3); + + const signed char* weight_hc_R_0 = weight_hc_dr.row(num_output * 0 + q); + const signed char* weight_hc_U_0 = weight_hc_dr.row(num_output * 1 + q); + const signed char* weight_hc_N_0 = weight_hc_dr.row(num_output * 2 + q); + + const signed char* weight_hc_R_1 = weight_hc_dr.row(num_output * 0 + q + 1); + const signed char* weight_hc_U_1 = weight_hc_dr.row(num_output * 1 + q + 1); + const signed char* weight_hc_N_1 = weight_hc_dr.row(num_output * 2 + q + 1); + + const signed char* weight_hc_R_2 = weight_hc_dr.row(num_output * 0 + q + 2); + const signed char* weight_hc_U_2 = weight_hc_dr.row(num_output * 1 + q + 2); + const signed char* weight_hc_N_2 = weight_hc_dr.row(num_output * 2 + q + 2); + + const signed char* weight_hc_R_3 = weight_hc_dr.row(num_output * 0 + q + 3); + const signed char* weight_hc_U_3 = weight_hc_dr.row(num_output * 1 + q + 3); + const signed char* weight_hc_N_3 = weight_hc_dr.row(num_output * 2 + q + 3); + + signed char* kptr = weight_data_tm_dr.row(q / 4); + float* descales_ptr = weight_data_tm_int8_descales_dr.row(q / 4); + + int i = 0; +#if __ARM_FEATURE_DOTPROD + for (; i + 3 < size; i += 4) + { + kptr[0] = weight_xc_R_0[i]; + kptr[1] = weight_xc_R_0[i + 1]; + kptr[2] = weight_xc_R_0[i + 2]; + kptr[3] = weight_xc_R_0[i + 3]; + kptr[4] = weight_xc_R_1[i]; + kptr[5] = weight_xc_R_1[i + 1]; + kptr[6] = weight_xc_R_1[i + 2]; + kptr[7] = weight_xc_R_1[i + 3]; + kptr[8 + 0] = weight_xc_R_2[i]; + kptr[8 + 1] = weight_xc_R_2[i + 1]; + kptr[8 + 2] = weight_xc_R_2[i + 2]; + kptr[8 + 3] = weight_xc_R_2[i + 3]; + kptr[8 + 4] = weight_xc_R_3[i]; + kptr[8 + 5] = weight_xc_R_3[i + 1]; + kptr[8 + 6] = weight_xc_R_3[i + 2]; + kptr[8 + 7] = weight_xc_R_3[i + 3]; + kptr[16 + 0] = weight_xc_U_0[i]; + kptr[16 + 1] = weight_xc_U_0[i + 1]; + kptr[16 + 2] = weight_xc_U_0[i + 2]; + kptr[16 + 3] = weight_xc_U_0[i + 3]; + kptr[16 + 4] = weight_xc_U_1[i]; + kptr[16 + 5] = weight_xc_U_1[i + 1]; + kptr[16 + 6] = weight_xc_U_1[i + 2]; + kptr[16 + 7] = weight_xc_U_1[i + 3]; + kptr[24 + 0] = weight_xc_U_2[i]; + kptr[24 + 1] = weight_xc_U_2[i + 1]; + kptr[24 + 2] = weight_xc_U_2[i + 2]; + kptr[24 + 3] = weight_xc_U_2[i + 3]; + kptr[24 + 4] = weight_xc_U_3[i]; + kptr[24 + 5] = weight_xc_U_3[i + 1]; + kptr[24 + 6] = weight_xc_U_3[i + 2]; + kptr[24 + 7] = weight_xc_U_3[i + 3]; + + kptr += 32; + } +#endif // __ARM_FEATURE_DOTPROD + for (; i + 1 < size; i += 2) + { + kptr[0] = weight_xc_R_0[i]; + kptr[1] = weight_xc_R_0[i + 1]; + kptr[2] = weight_xc_R_1[i]; + kptr[3] = weight_xc_R_1[i + 1]; + kptr[4] = weight_xc_R_2[i]; + kptr[5] = weight_xc_R_2[i + 1]; + kptr[6] = weight_xc_R_3[i]; + kptr[7] = weight_xc_R_3[i + 1]; + kptr[8 + 0] = weight_xc_U_0[i]; + kptr[8 + 1] = weight_xc_U_0[i + 1]; + kptr[8 + 2] = weight_xc_U_1[i]; + kptr[8 + 3] = weight_xc_U_1[i + 1]; + kptr[8 + 4] = weight_xc_U_2[i]; + kptr[8 + 5] = weight_xc_U_2[i + 1]; + kptr[8 + 6] = weight_xc_U_3[i]; + kptr[8 + 7] = weight_xc_U_3[i + 1]; + + kptr += 16; + } + for (; i < size; i++) + { + kptr[0] = weight_xc_R_0[i]; + kptr[1] = 
weight_xc_R_1[i]; + kptr[2] = weight_xc_R_2[i]; + kptr[3] = weight_xc_R_3[i]; + kptr[4] = weight_xc_U_0[i]; + kptr[5] = weight_xc_U_1[i]; + kptr[6] = weight_xc_U_2[i]; + kptr[7] = weight_xc_U_3[i]; + + kptr += 8; + } + + i = 0; +#if __ARM_FEATURE_DOTPROD + for (; i + 3 < num_output; i += 4) + { + kptr[0] = weight_hc_R_0[i]; + kptr[1] = weight_hc_R_0[i + 1]; + kptr[2] = weight_hc_R_0[i + 2]; + kptr[3] = weight_hc_R_0[i + 3]; + kptr[4] = weight_hc_R_1[i]; + kptr[5] = weight_hc_R_1[i + 1]; + kptr[6] = weight_hc_R_1[i + 2]; + kptr[7] = weight_hc_R_1[i + 3]; + kptr[8 + 0] = weight_hc_R_2[i]; + kptr[8 + 1] = weight_hc_R_2[i + 1]; + kptr[8 + 2] = weight_hc_R_2[i + 2]; + kptr[8 + 3] = weight_hc_R_2[i + 3]; + kptr[8 + 4] = weight_hc_R_3[i]; + kptr[8 + 5] = weight_hc_R_3[i + 1]; + kptr[8 + 6] = weight_hc_R_3[i + 2]; + kptr[8 + 7] = weight_hc_R_3[i + 3]; + kptr[16 + 0] = weight_hc_U_0[i]; + kptr[16 + 1] = weight_hc_U_0[i + 1]; + kptr[16 + 2] = weight_hc_U_0[i + 2]; + kptr[16 + 3] = weight_hc_U_0[i + 3]; + kptr[16 + 4] = weight_hc_U_1[i]; + kptr[16 + 5] = weight_hc_U_1[i + 1]; + kptr[16 + 6] = weight_hc_U_1[i + 2]; + kptr[16 + 7] = weight_hc_U_1[i + 3]; + kptr[24 + 0] = weight_hc_U_2[i]; + kptr[24 + 1] = weight_hc_U_2[i + 1]; + kptr[24 + 2] = weight_hc_U_2[i + 2]; + kptr[24 + 3] = weight_hc_U_2[i + 3]; + kptr[24 + 4] = weight_hc_U_3[i]; + kptr[24 + 5] = weight_hc_U_3[i + 1]; + kptr[24 + 6] = weight_hc_U_3[i + 2]; + kptr[24 + 7] = weight_hc_U_3[i + 3]; + + kptr += 32; + } +#endif // __ARM_FEATURE_DOTPROD + for (; i + 1 < num_output; i += 2) + { + kptr[0] = weight_hc_R_0[i]; + kptr[1] = weight_hc_R_0[i + 1]; + kptr[2] = weight_hc_R_1[i]; + kptr[3] = weight_hc_R_1[i + 1]; + kptr[4] = weight_hc_R_2[i]; + kptr[5] = weight_hc_R_2[i + 1]; + kptr[6] = weight_hc_R_3[i]; + kptr[7] = weight_hc_R_3[i + 1]; + kptr[8 + 0] = weight_hc_U_0[i]; + kptr[8 + 1] = weight_hc_U_0[i + 1]; + kptr[8 + 2] = weight_hc_U_1[i]; + kptr[8 + 3] = weight_hc_U_1[i + 1]; + kptr[8 + 4] = weight_hc_U_2[i]; + kptr[8 + 5] = weight_hc_U_2[i + 1]; + kptr[8 + 6] = weight_hc_U_3[i]; + kptr[8 + 7] = weight_hc_U_3[i + 1]; + + kptr += 16; + } + for (; i < num_output; i++) + { + kptr[0] = weight_hc_R_0[i]; + kptr[1] = weight_hc_R_1[i]; + kptr[2] = weight_hc_R_2[i]; + kptr[3] = weight_hc_R_3[i]; + kptr[4] = weight_hc_U_0[i]; + kptr[5] = weight_hc_U_1[i]; + kptr[6] = weight_hc_U_2[i]; + kptr[7] = weight_hc_U_3[i]; + + kptr += 8; + } + + i = 0; +#if __ARM_FEATURE_DOTPROD + for (; i + 3 < num_output; i += 4) + { + kptr[0] = weight_hc_N_0[i]; + kptr[1] = weight_hc_N_0[i + 1]; + kptr[2] = weight_hc_N_0[i + 2]; + kptr[3] = weight_hc_N_0[i + 3]; + kptr[4] = weight_hc_N_1[i]; + kptr[5] = weight_hc_N_1[i + 1]; + kptr[6] = weight_hc_N_1[i + 2]; + kptr[7] = weight_hc_N_1[i + 3]; + kptr[8 + 0] = weight_hc_N_2[i]; + kptr[8 + 1] = weight_hc_N_2[i + 1]; + kptr[8 + 2] = weight_hc_N_2[i + 2]; + kptr[8 + 3] = weight_hc_N_2[i + 3]; + kptr[8 + 4] = weight_hc_N_3[i]; + kptr[8 + 5] = weight_hc_N_3[i + 1]; + kptr[8 + 6] = weight_hc_N_3[i + 2]; + kptr[8 + 7] = weight_hc_N_3[i + 3]; + + kptr += 16; + } +#endif // __ARM_FEATURE_DOTPROD + for (; i + 1 < num_output; i += 2) + { + kptr[0] = weight_hc_N_0[i]; + kptr[1] = weight_hc_N_0[i + 1]; + kptr[2] = weight_hc_N_1[i]; + kptr[3] = weight_hc_N_1[i + 1]; + kptr[4] = weight_hc_N_2[i]; + kptr[5] = weight_hc_N_2[i + 1]; + kptr[6] = weight_hc_N_3[i]; + kptr[7] = weight_hc_N_3[i + 1]; + + kptr += 8; + } + for (; i < num_output; i++) + { + kptr[0] = weight_hc_N_0[i]; + kptr[1] = weight_hc_N_1[i]; + kptr[2] = weight_hc_N_2[i]; + 
kptr[3] = weight_hc_N_3[i]; + + kptr += 4; + } + + i = 0; +#if __ARM_FEATURE_DOTPROD + for (; i + 3 < size; i += 4) + { + kptr[0] = weight_xc_N_0[i]; + kptr[1] = weight_xc_N_0[i + 1]; + kptr[2] = weight_xc_N_0[i + 2]; + kptr[3] = weight_xc_N_0[i + 3]; + kptr[4] = weight_xc_N_1[i]; + kptr[5] = weight_xc_N_1[i + 1]; + kptr[6] = weight_xc_N_1[i + 2]; + kptr[7] = weight_xc_N_1[i + 3]; + kptr[8 + 0] = weight_xc_N_2[i]; + kptr[8 + 1] = weight_xc_N_2[i + 1]; + kptr[8 + 2] = weight_xc_N_2[i + 2]; + kptr[8 + 3] = weight_xc_N_2[i + 3]; + kptr[8 + 4] = weight_xc_N_3[i]; + kptr[8 + 5] = weight_xc_N_3[i + 1]; + kptr[8 + 6] = weight_xc_N_3[i + 2]; + kptr[8 + 7] = weight_xc_N_3[i + 3]; + + kptr += 16; + } +#endif // __ARM_FEATURE_DOTPROD + for (; i + 1 < size; i += 2) + { + kptr[0] = weight_xc_N_0[i]; + kptr[1] = weight_xc_N_0[i + 1]; + kptr[2] = weight_xc_N_1[i]; + kptr[3] = weight_xc_N_1[i + 1]; + kptr[4] = weight_xc_N_2[i]; + kptr[5] = weight_xc_N_2[i + 1]; + kptr[6] = weight_xc_N_3[i]; + kptr[7] = weight_xc_N_3[i + 1]; + + kptr += 8; + } + for (; i < size; i++) + { + kptr[0] = weight_xc_N_0[i]; + kptr[1] = weight_xc_N_1[i]; + kptr[2] = weight_xc_N_2[i]; + kptr[3] = weight_xc_N_3[i]; + + kptr += 4; + } + + float32x4_t _xc_R0 = vld1q_f32(weight_xc_int8_scales_ptr + q); + float32x4_t _xc_U0 = vld1q_f32(weight_xc_int8_scales_ptr + num_output + q); + float32x4_t _xc_N0 = vld1q_f32(weight_xc_int8_scales_ptr + num_output * 2 + q); + float32x4_t _hc_R0 = vld1q_f32(weight_hc_int8_scales_ptr + q); + float32x4_t _hc_U0 = vld1q_f32(weight_hc_int8_scales_ptr + num_output + q); + float32x4_t _hc_N0 = vld1q_f32(weight_hc_int8_scales_ptr + num_output * 2 + q); + +#if __aarch64__ + float32x4_t _one = vdupq_n_f32(1.f); + float32x4_t _reciprocal_xc_R0 = vdivq_f32(_one, _xc_R0); + float32x4_t _reciprocal_xc_U0 = vdivq_f32(_one, _xc_U0); + float32x4_t _reciprocal_xc_N0 = vdivq_f32(_one, _xc_N0); + float32x4_t _reciprocal_hc_R0 = vdivq_f32(_one, _hc_R0); + float32x4_t _reciprocal_hc_U0 = vdivq_f32(_one, _hc_U0); + float32x4_t _reciprocal_hc_N0 = vdivq_f32(_one, _hc_N0); +#else + float32x4_t _reciprocal_xc_R0 = vrecpeq_f32(_xc_R0); + float32x4_t _reciprocal_xc_U0 = vrecpeq_f32(_xc_U0); + float32x4_t _reciprocal_xc_N0 = vrecpeq_f32(_xc_N0); + _reciprocal_xc_R0 = vmulq_f32(vrecpsq_f32(_xc_R0, _reciprocal_xc_R0), _reciprocal_xc_R0); + _reciprocal_xc_U0 = vmulq_f32(vrecpsq_f32(_xc_U0, _reciprocal_xc_U0), _reciprocal_xc_U0); + _reciprocal_xc_N0 = vmulq_f32(vrecpsq_f32(_xc_N0, _reciprocal_xc_N0), _reciprocal_xc_N0); + float32x4_t _reciprocal_hc_R0 = vrecpeq_f32(_hc_R0); + float32x4_t _reciprocal_hc_U0 = vrecpeq_f32(_hc_U0); + float32x4_t _reciprocal_hc_N0 = vrecpeq_f32(_hc_N0); + _reciprocal_hc_R0 = vmulq_f32(vrecpsq_f32(_hc_R0, _reciprocal_hc_R0), _reciprocal_hc_R0); + _reciprocal_hc_U0 = vmulq_f32(vrecpsq_f32(_hc_U0, _reciprocal_hc_U0), _reciprocal_hc_U0); + _reciprocal_hc_N0 = vmulq_f32(vrecpsq_f32(_hc_N0, _reciprocal_hc_N0), _reciprocal_hc_N0); +#endif + + vst1q_f32(descales_ptr, _reciprocal_xc_R0); + vst1q_f32(descales_ptr + 4, _reciprocal_xc_U0); + vst1q_f32(descales_ptr + 8, _reciprocal_hc_R0); + vst1q_f32(descales_ptr + 12, _reciprocal_hc_U0); + vst1q_f32(descales_ptr + 16, _reciprocal_hc_N0); + vst1q_f32(descales_ptr + 20, _reciprocal_xc_N0); + } +#endif // __ARM_NEON + for (; q < num_output; q++) + { + bias_c_RUBNWN[0] = bias_c_R[q]; + bias_c_RUBNWN[1] = bias_c_U[q]; + bias_c_RUBNWN[2] = bias_c_BN[q]; + bias_c_RUBNWN[3] = bias_c_WN[q]; + + bias_c_RUBNWN += 4; + + const signed char* weight_xc_R = 
weight_xc_dr.row(num_output * 0 + q); + const signed char* weight_xc_U = weight_xc_dr.row(num_output * 1 + q); + const signed char* weight_xc_N = weight_xc_dr.row(num_output * 2 + q); + + const signed char* weight_hc_R = weight_hc_dr.row(num_output * 0 + q); + const signed char* weight_hc_U = weight_hc_dr.row(num_output * 1 + q); + const signed char* weight_hc_N = weight_hc_dr.row(num_output * 2 + q); + +#if __ARM_NEON + signed char* kptr = weight_data_tm_dr.row(q / 4 + q % 4); + float* descales_ptr = weight_data_tm_int8_descales_dr.row(q / 4 + q % 4); +#else + signed char* kptr = weight_data_tm_dr.row(q); + float* descales_ptr = weight_data_tm_int8_descales_dr.row(q); +#endif // __ARM_NEON + + for (int i = 0; i < size; i++) + { + kptr[0] = weight_xc_R[i]; + kptr[1] = weight_xc_U[i]; + + kptr += 2; + } + + for (int i = 0; i < num_output; i++) + { + kptr[0] = weight_hc_R[i]; + kptr[1] = weight_hc_U[i]; + + kptr += 2; + } + + for (int i = 0; i < num_output; i++) + { + kptr[0] = weight_hc_N[i]; + + kptr += 1; + } + + for (int i = 0; i < size; i++) + { + kptr[0] = weight_xc_N[i]; + + kptr += 1; + } + + descales_ptr[0] = 1.f / weight_xc_int8_scales_ptr[num_output * 0 + q]; + descales_ptr[1] = 1.f / weight_xc_int8_scales_ptr[num_output * 1 + q]; + descales_ptr[2] = 1.f / weight_hc_int8_scales_ptr[num_output * 0 + q]; + descales_ptr[3] = 1.f / weight_hc_int8_scales_ptr[num_output * 1 + q]; + descales_ptr[4] = 1.f / weight_hc_int8_scales_ptr[num_output * 2 + q]; + descales_ptr[5] = 1.f / weight_xc_int8_scales_ptr[num_output * 2 + q]; + } + } +} + +static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, Mat& hidden_state, const Option& opt) +{ + // TODO dispatch for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + // TODO dispatch for __ARM_FEATURE_DOTPROD + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD + if (ncnn::cpu_support_arm_asimddp()) + { + gru_int8_asimddp(bottom_blob_int8, bottom_blob_int8_descales, top_blob, elemtype, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, hidden_state, opt); + return; + } +#endif + + int size = bottom_blob_int8.w; + int T = bottom_blob_int8.h; + + int num_output = top_blob.w; + + // 2 x num_output +#if __ARM_NEON + Mat gates(4 * 2, num_output / 4 + num_output % 4, 4u, opt.workspace_allocator); +#else + Mat gates(2, num_output, 4u, opt.workspace_allocator); +#endif + + Mat hidden_state_int8(num_output, (size_t)1u, 1, opt.workspace_allocator); + float hidden_state_int8_scale = 1.f; + float hidden_state_int8_descale = 1.f; + + // unroll + for (int t = 0; t < T; t++) + { + int ti = reverse ? 
T - 1 - t : t; + + // dynamic quantize hidden_state + { + float absmax = 0.f; + for (int i = 0; i < num_output; i++) + { + absmax = std::max(absmax, (float)fabs(hidden_state[i])); + } + + if (absmax == 0.f) + { + hidden_state_int8.fill(0); + } + else + { + hidden_state_int8_scale = 127.f / absmax; + hidden_state_int8_descale = absmax / 127.f; + + signed char* hs = hidden_state_int8; + for (int i = 0; i < num_output; i++) + { + hs[i] = float2int8(hidden_state[i] * hidden_state_int8_scale); + } + } + } + + int remain_num_output_start = 0; +#if __ARM_NEON + int nn_num_output = num_output >> 2; + remain_num_output_start = nn_num_output << 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_num_output; qq++) + { + int q = qq * 4; + + const signed char* x = bottom_blob_int8.row(ti); + const signed char* hs = hidden_state_int8; + const float descale_x = bottom_blob_int8_descales[ti]; + const float descale_h = hidden_state_int8_descale; + + // gate reset update + const float* bias_c_RUBNWN = (const float*)bias_c + q * 4; + + const signed char* kptr = weight_data_tm.row(q / 4); + + const float* descales_ptr = weight_data_tm_int8_descales.row(q / 4); + + int32x4_t _gru_Rx0 = vdupq_n_s32(0); + int32x4_t _gru_Ux0 = vdupq_n_s32(0); + int i = 0; +#if __ARM_FEATURE_DOTPROD + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + for (; i + 7 < size; i += 8) + { + int32x2_t _xi01 = vreinterpret_s32_s8(vld1_s8(x + i)); + int8x16_t _xi0 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 0)); + int8x16_t _xi1 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 1)); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + _gru_Rx0 = vdotq_s32(_gru_Rx0, _w0, _xi0); + _gru_Ux0 = vdotq_s32(_gru_Ux0, _w1, _xi0); + _sum2 = vdotq_s32(_sum2, _w2, _xi1); + _sum3 = vdotq_s32(_sum3, _w3, _xi1); + + kptr += 64; + } + _gru_Rx0 = vaddq_s32(_gru_Rx0, _sum2); + _gru_Ux0 = vaddq_s32(_gru_Ux0, _sum3); +#endif // __ARM_FEATURE_DOTPROD + for (; i + 3 < size; i += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _xi = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(x + i)), 0)); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + _gru_Rx0 = vdotq_s32(_gru_Rx0, _w0, _xi); + _gru_Ux0 = vdotq_s32(_gru_Ux0, _w1, _xi); +#else + int16x4_t _xi01 = vreinterpret_s16_s8(vld1_s8(x + i)); + int8x8_t _xi0 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 0)); + int8x8_t _xi1 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 1)); + int8x16_t _weight_xc_RU0 = vld1q_s8(kptr); + int8x16_t _weight_xc_RU1 = vld1q_s8(kptr + 16); + + int16x8_t _gru_Rx = vmull_s8(vget_low_s8(_weight_xc_RU0), _xi0); + int16x8_t _gru_Ux = vmull_s8(vget_high_s8(_weight_xc_RU0), _xi0); + _gru_Rx = vmlal_s8(_gru_Rx, vget_low_s8(_weight_xc_RU1), _xi1); + _gru_Ux = vmlal_s8(_gru_Ux, vget_high_s8(_weight_xc_RU1), _xi1); + + _gru_Rx0 = vpadalq_s16(_gru_Rx0, _gru_Rx); + _gru_Ux0 = vpadalq_s16(_gru_Ux0, _gru_Ux); +#endif // __ARM_FEATURE_DOTPROD + + kptr += 32; + } + for (; i + 1 < size; i += 2) + { + int8x8_t _xi = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(x + i)), 0)); + int8x16_t _weight_xc_RU = vld1q_s8(kptr); + + int16x8_t _gru_Rx = vmull_s8(vget_low_s8(_weight_xc_RU), _xi); + int16x8_t _gru_Ux = vmull_s8(vget_high_s8(_weight_xc_RU), _xi); + + _gru_Rx0 = vpadalq_s16(_gru_Rx0, _gru_Rx); + _gru_Ux0 = vpadalq_s16(_gru_Ux0, _gru_Ux); + + kptr += 16; + } + for (; i < size; i++) + { + int8x8_t _xi = 
vdup_n_s8(x[i]); + int8x8_t _weight_xc_RU = vld1_s8(kptr); + + int16x8_t _gru_RxUx = vmull_s8(_weight_xc_RU, _xi); + _gru_Rx0 = vaddw_s16(_gru_Rx0, vget_low_s16(_gru_RxUx)); + _gru_Ux0 = vaddw_s16(_gru_Ux0, vget_high_s16(_gru_RxUx)); + + kptr += 8; + } + + int32x4_t _gru_Rh0 = vdupq_n_s32(0); + int32x4_t _gru_Uh0 = vdupq_n_s32(0); + i = 0; +#if __ARM_FEATURE_DOTPROD + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + for (; i + 7 < num_output; i += 8) + { + int32x2_t _h_cont01 = vreinterpret_s32_s8(vld1_s8(hs + i)); + int8x16_t _h_cont0 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 0)); + int8x16_t _h_cont1 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 1)); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + _gru_Rh0 = vdotq_s32(_gru_Rh0, _w0, _h_cont0); + _gru_Uh0 = vdotq_s32(_gru_Uh0, _w1, _h_cont0); + _sum2 = vdotq_s32(_sum2, _w2, _h_cont1); + _sum3 = vdotq_s32(_sum3, _w3, _h_cont1); + + kptr += 64; + } + _gru_Rh0 = vaddq_s32(_gru_Rh0, _sum2); + _gru_Uh0 = vaddq_s32(_gru_Uh0, _sum3); +#endif // __ARM_FEATURE_DOTPROD + for (; i + 3 < num_output; i += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _h_cont = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(hs + i)), 0)); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + _gru_Rh0 = vdotq_s32(_gru_Rh0, _w0, _h_cont); + _gru_Uh0 = vdotq_s32(_gru_Uh0, _w1, _h_cont); +#else + int16x4_t _h_cont01 = vreinterpret_s16_s8(vld1_s8(hs + i)); + int8x8_t _h_cont0 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 0)); + int8x8_t _h_cont1 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 1)); + int8x16_t _weight_hc_RU0 = vld1q_s8(kptr); + int8x16_t _weight_hc_RU1 = vld1q_s8(kptr + 16); + + int16x8_t _gru_Rh = vmull_s8(vget_low_s8(_weight_hc_RU0), _h_cont0); + int16x8_t _gru_Uh = vmull_s8(vget_high_s8(_weight_hc_RU0), _h_cont0); + _gru_Rh = vmlal_s8(_gru_Rh, vget_low_s8(_weight_hc_RU1), _h_cont1); + _gru_Uh = vmlal_s8(_gru_Uh, vget_high_s8(_weight_hc_RU1), _h_cont1); + + _gru_Rh0 = vpadalq_s16(_gru_Rh0, _gru_Rh); + _gru_Uh0 = vpadalq_s16(_gru_Uh0, _gru_Uh); +#endif // __ARM_FEATURE_DOTPROD + + kptr += 32; + } + for (; i + 1 < num_output; i += 2) + { + int8x8_t _h_cont = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(hs + i)), 0)); + int8x16_t _weight_hc_RU = vld1q_s8(kptr); + + int16x8_t _gru_Rh = vmull_s8(vget_low_s8(_weight_hc_RU), _h_cont); + int16x8_t _gru_Uh = vmull_s8(vget_high_s8(_weight_hc_RU), _h_cont); + + _gru_Rh0 = vpadalq_s16(_gru_Rh0, _gru_Rh); + _gru_Uh0 = vpadalq_s16(_gru_Uh0, _gru_Uh); + + kptr += 16; + } + for (; i < num_output; i++) + { + int8x8_t _h_cont = vdup_n_s8(hs[i]); + int8x8_t _weight_hc_RU = vld1_s8(kptr); + + int16x8_t _gru_RhUh = vmull_s8(_weight_hc_RU, _h_cont); + _gru_Rh0 = vaddw_s16(_gru_Rh0, vget_low_s16(_gru_RhUh)); + _gru_Uh0 = vaddw_s16(_gru_Uh0, vget_high_s16(_gru_RhUh)); + + kptr += 8; + } + + float32x4_t _descale_x = vdupq_n_f32(descale_x); + float32x4_t _descale_h = vdupq_n_f32(descale_h); + + float32x4_t _gru_R0 = vld1q_f32(bias_c_RUBNWN); + float32x4_t _gru_U0 = vld1q_f32(bias_c_RUBNWN + 4); + + float32x4_t _descale_xc_R0 = vld1q_f32(descales_ptr); + float32x4_t _descale_xc_U0 = vld1q_f32(descales_ptr + 4); + + _gru_R0 = vmlaq_f32(_gru_R0, vcvtq_f32_s32(_gru_Rx0), vmulq_f32(_descale_x, _descale_xc_R0)); + _gru_U0 = vmlaq_f32(_gru_U0, vcvtq_f32_s32(_gru_Ux0), vmulq_f32(_descale_x, _descale_xc_U0)); + + float32x4_t _descale_hc_R0 = vld1q_f32(descales_ptr + 8); 
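// Editor's note, not part of the patch: the surrounding statements apply the int8
// dequantization identity used throughout gru_int8(). With xq[i] = round(x[i] * (127 / absmax_x))
// and wq[i] = round(w[i] * scale_w), the int32 accumulator sum(wq[i] * xq[i]) recovers the float
// dot product as sum * descale_x * descale_w, where descale_x = absmax_x / 127 and
// descale_w = 1 / scale_w. Each gate therefore converts its accumulator once with vcvtq_f32_s32
// and scales it by vmulq_f32(_descale_x, _descale_xc_*) or vmulq_f32(_descale_h, _descale_hc_*),
// on top of the bias preloaded from bias_c_RUBNWN.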
+ float32x4_t _descale_hc_U0 = vld1q_f32(descales_ptr + 12); + + _gru_R0 = vmlaq_f32(_gru_R0, vcvtq_f32_s32(_gru_Rh0), vmulq_f32(_descale_h, _descale_hc_R0)); + _gru_U0 = vmlaq_f32(_gru_U0, vcvtq_f32_s32(_gru_Uh0), vmulq_f32(_descale_h, _descale_hc_U0)); + + // sigmoid(R) + // sigmoid(U) + _gru_R0 = sigmoid_ps(_gru_R0); + _gru_U0 = sigmoid_ps(_gru_U0); + + // gate new + + int32x4_t _gru_Nh0 = vdupq_n_s32(0); + i = 0; +#if __ARM_FEATURE_DOTPROD + _sum2 = vdupq_n_s32(0); + for (; i + 7 < num_output; i += 8) + { + int32x2_t _h_cont01 = vreinterpret_s32_s8(vld1_s8(hs + i)); + int8x16_t _h_cont0 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 0)); + int8x16_t _h_cont1 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 1)); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + _gru_Nh0 = vdotq_s32(_gru_Nh0, _w0, _h_cont0); + _sum2 = vdotq_s32(_sum2, _w1, _h_cont1); + + kptr += 32; + } + _gru_Nh0 = vaddq_s32(_gru_Nh0, _sum2); +#endif // __ARM_FEATURE_DOTPROD + for (; i + 3 < num_output; i += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _h_cont = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(hs + i)), 0)); + int8x16_t _w = vld1q_s8(kptr); + _gru_Nh0 = vdotq_s32(_gru_Nh0, _w, _h_cont); +#else + int16x4_t _h_cont01 = vreinterpret_s16_s8(vld1_s8(hs + i)); + int8x8_t _h_cont0 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 0)); + int8x8_t _h_cont1 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 1)); + int8x16_t _w01 = vld1q_s8(kptr); + + int16x8_t _gru_Nh = vmull_s8(vget_low_s8(_w01), _h_cont0); + _gru_Nh = vmlal_s8(_gru_Nh, vget_high_s8(_w01), _h_cont1); + _gru_Nh0 = vpadalq_s16(_gru_Nh0, _gru_Nh); +#endif // __ARM_FEATURE_DOTPROD + + kptr += 16; + } + for (; i + 1 < num_output; i += 2) + { + int8x8_t _h_cont = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(hs + i)), 0)); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _gru_Nh = vmull_s8(_w, _h_cont); + _gru_Nh0 = vpadalq_s16(_gru_Nh0, _gru_Nh); + + kptr += 8; + } + for (; i < num_output; i++) + { + int8x8_t _h_cont = vdup_n_s8(hs[i]); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _gru_Nh = vmull_s8(_w, _h_cont); + _gru_Nh0 = vaddw_s16(_gru_Nh0, vget_low_s16(_gru_Nh)); + + kptr += 4; + } + + int32x4_t _gru_Nx0 = vdupq_n_s32(0); + i = 0; +#if __ARM_FEATURE_DOTPROD + _sum2 = vdupq_n_s32(0); + for (; i + 7 < size; i += 8) + { + int32x2_t _xi01 = vreinterpret_s32_s8(vld1_s8(x + i)); + int8x16_t _xi0 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 0)); + int8x16_t _xi1 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 1)); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + _gru_Nx0 = vdotq_s32(_gru_Nx0, _w0, _xi0); + _sum2 = vdotq_s32(_sum2, _w1, _xi1); + + kptr += 32; + } + _gru_Nx0 = vaddq_s32(_gru_Nx0, _sum2); +#endif // __ARM_FEATURE_DOTPROD + for (; i + 3 < size; i += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _xi = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(x + i)), 0)); + int8x16_t _w = vld1q_s8(kptr); + _gru_Nx0 = vdotq_s32(_gru_Nx0, _w, _xi); +#else + int16x4_t _xi01 = vreinterpret_s16_s8(vld1_s8(x + i)); + int8x8_t _xi0 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 0)); + int8x8_t _xi1 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 1)); + int8x16_t _w01 = vld1q_s8(kptr); + + int16x8_t _gru_Nx = vmull_s8(vget_low_s8(_w01), _xi0); + _gru_Nx = vmlal_s8(_gru_Nx, vget_high_s8(_w01), _xi1); + _gru_Nx0 = vpadalq_s16(_gru_Nx0, _gru_Nx); +#endif // __ARM_FEATURE_DOTPROD + + kptr += 16; + } + for (; i + 1 < size; i += 2) + { + int8x8_t _xi = 
vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(x + i)), 0)); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _gru_Nx = vmull_s8(_w, _xi); + _gru_Nx0 = vpadalq_s16(_gru_Nx0, _gru_Nx); + + kptr += 8; + } + for (; i < size; i++) + { + int8x8_t _xi = vdup_n_s8(x[i]); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _gru_Nx = vmull_s8(_w, _xi); + _gru_Nx0 = vaddw_s16(_gru_Nx0, vget_low_s16(_gru_Nx)); + + kptr += 4; + } + + float32x4_t _gru_N0 = vld1q_f32(bias_c_RUBNWN + 8); + + float32x4_t _descale_hc_N0 = vld1q_f32(descales_ptr + 16); + + _gru_N0 = vmlaq_f32(_gru_N0, vcvtq_f32_s32(_gru_Nh0), vmulq_f32(_descale_h, _descale_hc_N0)); + + _gru_N0 = vmlaq_f32(vld1q_f32(bias_c_RUBNWN + 12), _gru_R0, _gru_N0); + + float32x4_t _descale_xc_N0 = vld1q_f32(descales_ptr + 20); + + _gru_N0 = vmlaq_f32(_gru_N0, vcvtq_f32_s32(_gru_Nx0), vmulq_f32(_descale_x, _descale_xc_N0)); + + // tanh(N) + _gru_N0 = tanh_ps(_gru_N0); + + float* gates_data = gates.row(q / 4); + + vst1q_f32(gates_data, _gru_U0); + vst1q_f32(gates_data + 4, _gru_N0); + } +#endif // __ARM_NEON + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const signed char* x = bottom_blob_int8.row(ti); + const signed char* hs = hidden_state_int8; + const float descale_x = bottom_blob_int8_descales[ti]; + const float descale_h = hidden_state_int8_descale; + + // gate reset update + const float* bias_c_RUBNWN = (const float*)bias_c + q * 4; + +#if __ARM_NEON + const signed char* kptr = weight_data_tm.row(q / 4 + q % 4); + const float* descales_ptr = weight_data_tm_int8_descales.row(q / 4 + q % 4); +#else + const signed char* kptr = weight_data_tm.row(q); + const float* descales_ptr = weight_data_tm_int8_descales.row(q); +#endif + + const float descale_xc_R = descales_ptr[0]; + const float descale_xc_U = descales_ptr[1]; + const float descale_hc_R = descales_ptr[2]; + const float descale_hc_U = descales_ptr[3]; + const float descale_hc_N = descales_ptr[4]; + const float descale_xc_N = descales_ptr[5]; + + int Rx = 0; + int Ux = 0; + for (int i = 0; i < size; i++) + { + signed char xi = x[i]; + + Rx += kptr[0] * xi; + Ux += kptr[1] * xi; + + kptr += 2; + } + + int Rh = 0; + int Uh = 0; + for (int i = 0; i < num_output; i++) + { + signed char h_cont = hs[i]; + + Rh += kptr[0] * h_cont; + Uh += kptr[1] * h_cont; + + kptr += 2; + } + + float R = bias_c_RUBNWN[0] + Rx * (descale_x * descale_xc_R) + Rh * (descale_h * descale_hc_R); + float U = bias_c_RUBNWN[1] + Ux * (descale_x * descale_xc_U) + Uh * (descale_h * descale_hc_U); + + // sigmoid(R) + // sigmoid(U) + R = 1.f / (1.f + expf(-R)); + U = 1.f / (1.f + expf(-U)); + + // gate new + + int Nh = 0; + for (int i = 0; i < num_output; i++) + { + signed char h_cont = hs[i]; + + Nh += kptr[0] * h_cont; + + kptr += 1; + } + + int Nx = 0; + for (int i = 0; i < size; i++) + { + signed char xi = x[i]; + + Nx += kptr[0] * xi; + + kptr += 1; + } + + float N = bias_c_RUBNWN[2] + Nh * (descale_h * descale_hc_N); + N = bias_c_RUBNWN[3] + R * N + Nx * (descale_x * descale_xc_N); + + // tanh(N) + N = tanhf(N); + +#if __ARM_NEON + float* gates_data = gates.row(q / 4 + q % 4); +#else + float* gates_data = gates.row(q); +#endif + + gates_data[0] = U; + gates_data[1] = N; + } + + // h_t := (1 - update) .* new + update .* h_{t-1} + float* output_data = top_blob.row(ti); + + float* hidden_ptr = hidden_state; + +#if __ARM_NEON + nn_num_output = num_output >> 2; + remain_num_output_start = nn_num_output << 2; + + #pragma omp parallel for 
num_threads(opt.num_threads) + for (int qq = 0; qq < nn_num_output; qq++) + { + int q = qq * 4; + + const float* gates_data = gates.row(q / 4); + + float32x4_t _gru_U0 = vld1q_f32(gates_data); + float32x4_t _gru_N0 = vld1q_f32(gates_data + 4); + + float32x4_t _gru_H0 = vaddq_f32(vmulq_f32(vsubq_f32(vdupq_n_f32(1.f), _gru_U0), _gru_N0), vmulq_f32(_gru_U0, vld1q_f32(hidden_ptr + q))); + + vst1q_f32(hidden_ptr + q, _gru_H0); + + if (elemtype == 1) + { + // fp32 + vst1q_f32(output_data + q, _gru_H0); + } + if (elemtype == 2) + { + // fp16 + vst1_u16((unsigned short*)output_data + q, (uint16x4_t)vcvt_f16_f32(_gru_H0)); + } + if (elemtype == 4) + { + // bf16 + vst1_u16((unsigned short*)output_data + q, float2bfloat(_gru_H0)); + } + } +#endif // __ARM_NEON + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { +#if __ARM_NEON + const float* gates_data = gates.row(q / 4 + q % 4); +#else + const float* gates_data = gates.row(q); +#endif + + float U = gates_data[0]; + float N = gates_data[1]; + + float H = (1 - U) * N + U * hidden_ptr[q]; + + hidden_ptr[q] = H; + + if (elemtype == 1) + { + output_data[q] = H; + } + if (elemtype == 2) + { + ((unsigned short*)output_data)[q] = float32_to_float16(H); + } + if (elemtype == 4) + { + ((unsigned short*)output_data)[q] = float32_to_bfloat16(H); + } + } + } +} diff --git a/src/layer/gru.cpp b/src/layer/gru.cpp index 1276e677222..6da1f715d7a 100644 --- a/src/layer/gru.cpp +++ b/src/layer/gru.cpp @@ -191,12 +191,65 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma if (gates.empty()) return -100; + // dynamic quantize bottom_blob + Mat bottom_blob_int8(size, T, (size_t)1u, 1, opt.workspace_allocator); + Mat bottom_blob_int8_scales(T, (size_t)4u, 1, opt.workspace_allocator); + { + for (int t = 0; t < T; t++) + { + const float* x = bottom_blob.row(t); + + float absmax = 0.f; + for (int i = 0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(x[i])); + } + + bottom_blob_int8_scales[t] = 127.f / absmax; + } + + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_quant); + } + + Mat hidden_state_int8(num_output, (size_t)1u, 1, opt.workspace_allocator); + Mat hidden_state_int8_scales(1, (size_t)4u, 1, opt.workspace_allocator); + // unroll for (int t = 0; t < T; t++) { int ti = reverse ? 
T - 1 - t : t; - const float* x = bottom_blob.row(ti); + // dynamic quantize hidden_state + { + float absmax = 0.f; + for (int i = 0; i < num_output; i++) + { + absmax = std::max(absmax, (float)fabs(hidden_state[i])); + } + + if (absmax == 0.f) + { + hidden_state_int8_scales[0] = 1.f; + hidden_state_int8.fill(0); + } + else + { + hidden_state_int8_scales[0] = 127.f / absmax; + + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + quantize_to_int8(hidden_state, hidden_state_int8, hidden_state_int8_scales, opt_quant); + } + } + + const signed char* x = bottom_blob_int8.row(ti); + const signed char* hs = hidden_state_int8; + const float descale_x = 1.f / bottom_blob_int8_scales[ti]; + const float descale_h = 1.f / hidden_state_int8_scales[0]; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < num_output; q++) { @@ -216,25 +269,29 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma const float descale_hc_R = 1.f / weight_hc_int8_scales[num_output * 0 + q]; const float descale_hc_U = 1.f / weight_hc_int8_scales[num_output * 1 + q]; - float R = bias_c_R[q]; - float U = bias_c_U[q]; - + int Rx = 0; + int Ux = 0; for (int i = 0; i < size; i++) { - float xi = x[i]; + signed char xi = x[i]; - R += weight_xc_int8_R[i] * descale_xc_R * xi; - U += weight_xc_int8_U[i] * descale_xc_U * xi; + Rx += weight_xc_int8_R[i] * xi; + Ux += weight_xc_int8_U[i] * xi; } + int Rh = 0; + int Uh = 0; for (int i = 0; i < num_output; i++) { - float h_cont = hidden_state[i]; + signed char h_cont = hs[i]; - R += weight_hc_int8_R[i] * descale_hc_R * h_cont; - U += weight_hc_int8_U[i] * descale_hc_U * h_cont; + Rh += weight_hc_int8_R[i] * h_cont; + Uh += weight_hc_int8_U[i] * h_cont; } + float R = bias_c_R[q] + Rx * (descale_x * descale_xc_R) + Rh * (descale_h * descale_hc_R); + float U = bias_c_U[q] + Ux * (descale_x * descale_xc_U) + Uh * (descale_h * descale_hc_U); + // sigmoid(R) // sigmoid(U) R = 1.f / (1.f + expf(-R)); @@ -250,24 +307,25 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma const float descale_xc_N = 1.f / weight_xc_int8_scales[num_output * 2 + q]; const float descale_hc_N = 1.f / weight_hc_int8_scales[num_output * 2 + q]; - float N = bias_c_BN[q]; - + int Nh = 0; for (int i = 0; i < num_output; i++) { - float h_cont = hidden_state[i]; + signed char h_cont = hs[i]; - N += weight_hc_int8_N[i] * descale_hc_N * h_cont; + Nh += weight_hc_int8_N[i] * h_cont; } - N = bias_c_WN[q] + R * N; - + int Nx = 0; for (int i = 0; i < size; i++) { - float xi = x[i]; + signed char xi = x[i]; - N += weight_xc_int8_N[i] * descale_xc_N * xi; + Nx += weight_xc_int8_N[i] * xi; } + float N = bias_c_BN[q] + Nh * (descale_h * descale_hc_N); + N = bias_c_WN[q] + R * N + Nx * (descale_x * descale_xc_N); + // tanh(N) N = tanhf(N); diff --git a/tools/quantize/ncnn2int8.cpp b/tools/quantize/ncnn2int8.cpp index 763c1b09252..64b56798a01 100644 --- a/tools/quantize/ncnn2int8.cpp +++ b/tools/quantize/ncnn2int8.cpp @@ -510,7 +510,7 @@ int NetQuantize::quantize_gru() { absmax = std::max(absmax, (float)fabs(weight_xc_ptr[i])); } - weight_xc_data_int8_scales[d * gru->num_output + q] = 127 / absmax; + weight_xc_data_int8_scales[d * gru->num_output * 3 + q] = 127 / absmax; } { @@ -520,7 +520,7 @@ int NetQuantize::quantize_gru() { absmax = std::max(absmax, (float)fabs(weight_hc_ptr[i])); } - weight_hc_data_int8_scales[d * gru->num_output + q] = 127 / absmax; + weight_hc_data_int8_scales[d * 
gru->num_output * 3 + q] = 127 / absmax; } } }
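For reference, the two index fixes above assume the GRU weight scale tables hold num_directions * 3 * num_output entries, one per quantized weight row (R, U and N gates for every output), so the per-direction offset must stride by num_output * 3. A minimal sketch of that bookkeeping under this assumption follows; the helper name and the flat weight layout are illustrative, not code from the patch.

#include <algorithm>
#include <cmath>

// Illustrative sketch: compute one 127/absmax scale per weight row, stored with a
// per-direction stride of num_output * 3, matching the corrected indexing above.
static void gru_weight_int8_scales(const float* weights, int len, int num_output,
                                   int num_directions, float* scales)
{
    for (int d = 0; d < num_directions; d++)
    {
        // q walks all 3 * num_output rows of this direction (R rows, then U rows, then N rows)
        for (int q = 0; q < num_output * 3; q++)
        {
            const float* w = weights + (d * num_output * 3 + q) * len;

            float absmax = 0.f;
            for (int i = 0; i < len; i++)
            {
                absmax = std::max(absmax, std::fabs(w[i]));
            }

            // map the largest magnitude in the row onto the int8 range [-127, 127]
            scales[d * num_output * 3 + q] = absmax == 0.f ? 1.f : 127.f / absmax;
        }
    }
}

The runtime side then derives its descales as the reciprocals of these scales, which is what the packed weight_*_int8_descales tensors in gru_int8.h store.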