From 3ddb97b91ecd8e6d25fd0808652fed566ebb3867 Mon Sep 17 00:00:00 2001 From: nihuini Date: Mon, 29 Apr 2024 15:24:01 +0800 Subject: [PATCH] opt --- src/layer/arm/lstm_int8.h | 271 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 259 insertions(+), 12 deletions(-) diff --git a/src/layer/arm/lstm_int8.h b/src/layer/arm/lstm_int8.h index 9a592d4f7f6..9ad7de33551 100644 --- a/src/layer/arm/lstm_int8.h +++ b/src/layer/arm/lstm_int8.h @@ -77,7 +77,42 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x signed char* kptr = weight_data_tm_dr.row(q); float* descales_ptr = weight_data_tm_int8_descales_dr.row(q); - for (int i = 0; i < size; i++) + int i = 0; +#if __ARM_FEATURE_DOTPROD + for (; i + 3 < size; i += 4) + { + kptr[0] = weight_xc_I[i]; + kptr[1] = weight_xc_I[i + 1]; + kptr[2] = weight_xc_I[i + 2]; + kptr[3] = weight_xc_I[i + 3]; + kptr[4] = weight_xc_F[i]; + kptr[5] = weight_xc_F[i + 1]; + kptr[6] = weight_xc_F[i + 2]; + kptr[7] = weight_xc_F[i + 3]; + kptr[8 + 0] = weight_xc_O[i]; + kptr[8 + 1] = weight_xc_O[i + 1]; + kptr[8 + 2] = weight_xc_O[i + 2]; + kptr[8 + 3] = weight_xc_O[i + 3]; + kptr[8 + 4] = weight_xc_G[i]; + kptr[8 + 5] = weight_xc_G[i + 1]; + kptr[8 + 6] = weight_xc_G[i + 2]; + kptr[8 + 7] = weight_xc_G[i + 3]; + kptr += 16; + } +#endif // __ARM_FEATURE_DOTPROD + for (; i + 1 < size; i += 2) + { + kptr[0] = weight_xc_I[i]; + kptr[1] = weight_xc_I[i + 1]; + kptr[2] = weight_xc_F[i]; + kptr[3] = weight_xc_F[i + 1]; + kptr[4] = weight_xc_O[i]; + kptr[5] = weight_xc_O[i + 1]; + kptr[6] = weight_xc_G[i]; + kptr[7] = weight_xc_G[i + 1]; + kptr += 8; + } + for (; i < size; i++) { kptr[0] = weight_xc_I[i]; kptr[1] = weight_xc_F[i]; @@ -86,7 +121,42 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x kptr += 4; } - for (int i = 0; i < num_output; i++) + i = 0; +#if __ARM_FEATURE_DOTPROD + for (; i + 3 < num_output; i += 4) + { + kptr[0] = weight_hc_I[i]; + kptr[1] = weight_hc_I[i + 1]; + kptr[2] = weight_hc_I[i + 2]; + kptr[3] = weight_hc_I[i + 3]; + kptr[4] = weight_hc_F[i]; + kptr[5] = weight_hc_F[i + 1]; + kptr[6] = weight_hc_F[i + 2]; + kptr[7] = weight_hc_F[i + 3]; + kptr[8 + 0] = weight_hc_O[i]; + kptr[8 + 1] = weight_hc_O[i + 1]; + kptr[8 + 2] = weight_hc_O[i + 2]; + kptr[8 + 3] = weight_hc_O[i + 3]; + kptr[8 + 4] = weight_hc_G[i]; + kptr[8 + 5] = weight_hc_G[i + 1]; + kptr[8 + 6] = weight_hc_G[i + 2]; + kptr[8 + 7] = weight_hc_G[i + 3]; + kptr += 16; + } +#endif // __ARM_FEATURE_DOTPROD + for (; i + 1 < num_output; i += 2) + { + kptr[0] = weight_hc_I[i]; + kptr[1] = weight_hc_I[i + 1]; + kptr[2] = weight_hc_F[i]; + kptr[3] = weight_hc_F[i + 1]; + kptr[4] = weight_hc_O[i]; + kptr[5] = weight_hc_O[i + 1]; + kptr[6] = weight_hc_G[i]; + kptr[7] = weight_hc_G[i + 1]; + kptr += 8; + } + for (; i < num_output; i++) { kptr[0] = weight_hc_I[i]; kptr[1] = weight_hc_F[i]; @@ -183,15 +253,184 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d const signed char* kptr = weight_data_tm.row(q); const float* descales_ptr = weight_data_tm_int8_descales.row(q); - const float descale_xc_I = descales_ptr[0]; - const float descale_xc_F = descales_ptr[1]; - const float descale_xc_O = descales_ptr[2]; - const float descale_xc_G = descales_ptr[3]; - const float descale_hc_I = descales_ptr[4]; - const float descale_hc_F = descales_ptr[5]; - const float descale_hc_O = descales_ptr[6]; - const float descale_hc_G = descales_ptr[7]; + float* gates_data = gates.row(q); + +#if __ARM_NEON + int32x4_t _lstm_IFOGx0 = vdupq_n_s32(0); + int i = 0; +#if __ARM_FEATURE_DOTPROD + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + for (; i + 15 < size; i += 16) + { + int32x4_t _xi01 = vreinterpretq_s32_s8(vld1q_s8(x + i)); + int8x16_t _xi0 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 0)); + int8x16_t _xi1 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 1)); + int8x16_t _xi2 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 2)); + int8x16_t _xi3 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 3)); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + _lstm_IFOGx0 = vdotq_s32(_lstm_IFOGx0, _w0, _xi0); + _sum1 = vdotq_s32(_sum1, _w1, _xi1); + _sum2 = vdotq_s32(_sum2, _w2, _xi2); + _sum3 = vdotq_s32(_sum3, _w3, _xi3); + + kptr += 64; + } + for (; i + 7 < size; i += 8) + { + int32x2_t _xi01 = vreinterpret_s32_s8(vld1_s8(x + i)); + int8x16_t _xi0 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 0)); + int8x16_t _xi1 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 1)); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + _lstm_IFOGx0 = vdotq_s32(_lstm_IFOGx0, _w0, _xi0); + _sum1 = vdotq_s32(_sum1, _w1, _xi1); + + kptr += 32; + } + _lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum1); + _lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum2); + _lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum3); +#endif // __ARM_FEATURE_DOTPROD + for (; i + 3 < size; i += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _xi = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(x + i)), 0)); + int8x16_t _w = vld1q_s8(kptr); + _lstm_IFOGx0 = vdotq_s32(_lstm_IFOGx0, _w, _xi); +#else + int16x4_t _xi01 = vreinterpret_s16_s8(vld1_s8(x + i)); + int8x8_t _xi0 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 0)); + int8x8_t _xi1 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 1)); + int8x16_t _w01 = vld1q_s8(kptr); + + int16x8_t _lstm_IFOGx = vmull_s8(vget_low_s8(_w01), _xi0); + _lstm_IFOGx = vmlal_s8(_lstm_IFOGx, vget_high_s8(_w01), _xi1); + _lstm_IFOGx0 = vpadalq_s16(_lstm_IFOGx0, _lstm_IFOGx); +#endif // __ARM_FEATURE_DOTPROD + + kptr += 16; + } + for (; i + 1 < size; i += 2) + { + int8x8_t _xi = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(x + i)), 0)); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _lstm_IFOGx = vmull_s8(_w, _xi); + _lstm_IFOGx0 = vpadalq_s16(_lstm_IFOGx0, _lstm_IFOGx); + + kptr += 8; + } + for (; i < size; i++) + { + int8x8_t _xi = vdup_n_s8(x[i]); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _lstm_IFOGx = vmull_s8(_w, _xi); + _lstm_IFOGx0 = vaddw_s16(_lstm_IFOGx0, vget_low_s16(_lstm_IFOGx)); + kptr += 4; + } + + int32x4_t _lstm_IFOGh0 = vdupq_n_s32(0); + i = 0; +#if __ARM_FEATURE_DOTPROD + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + for (; i + 15 < num_output; i += 16) + { + int32x4_t _h_cont01 = vreinterpretq_s32_s8(vld1q_s8(hs + i)); + int8x16_t _h_cont0 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 0)); + int8x16_t _h_cont1 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 1)); + int8x16_t _h_cont2 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 2)); + int8x16_t _h_cont3 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 3)); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + _lstm_IFOGh0 = vdotq_s32(_lstm_IFOGh0, _w0, _h_cont0); + _sum1 = vdotq_s32(_sum1, _w1, _h_cont1); + _sum2 = vdotq_s32(_sum2, _w2, _h_cont2); + _sum3 = vdotq_s32(_sum3, _w3, _h_cont3); + + kptr += 64; + } + for (; i + 7 < num_output; i += 8) + { + int32x2_t _h_cont01 = vreinterpret_s32_s8(vld1_s8(hs + i)); + int8x16_t _h_cont0 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 0)); + int8x16_t _h_cont1 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 1)); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + _lstm_IFOGh0 = vdotq_s32(_lstm_IFOGh0, _w0, _h_cont0); + _sum1 = vdotq_s32(_sum1, _w1, _h_cont1); + + kptr += 32; + } + _lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum1); + _lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum2); + _lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum3); +#endif // __ARM_FEATURE_DOTPROD + for (; i + 3 < num_output; i += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _h_cont = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(hs + i)), 0)); + int8x16_t _w = vld1q_s8(kptr); + _lstm_IFOGh0 = vdotq_s32(_lstm_IFOGh0, _w, _h_cont); +#else + int16x4_t _h_cont01 = vreinterpret_s16_s8(vld1_s8(hs + i)); + int8x8_t _h_cont0 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 0)); + int8x8_t _h_cont1 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 1)); + int8x16_t _w01 = vld1q_s8(kptr); + + int16x8_t _lstm_IFOGh = vmull_s8(vget_low_s8(_w01), _h_cont0); + _lstm_IFOGh = vmlal_s8(_lstm_IFOGh, vget_high_s8(_w01), _h_cont1); + _lstm_IFOGh0 = vpadalq_s16(_lstm_IFOGh0, _lstm_IFOGh); +#endif // __ARM_FEATURE_DOTPROD + + kptr += 16; + } + for (; i + 1 < num_output; i += 2) + { + int8x8_t _h_cont = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(hs + i)), 0)); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _lstm_IFOGh = vmull_s8(_w, _h_cont); + _lstm_IFOGh0 = vpadalq_s16(_lstm_IFOGh0, _lstm_IFOGh); + + kptr += 8; + } + for (; i < num_output; i++) + { + int8x8_t _h_cont = vdup_n_s8(hs[i]); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _lstm_IFOGh = vmull_s8(_w, _h_cont); + _lstm_IFOGh0 = vaddw_s16(_lstm_IFOGh0, vget_low_s16(_lstm_IFOGh)); + + kptr += 4; + } + + float32x4_t _descale_x = vdupq_n_f32(descale_x); + float32x4_t _descale_h = vdupq_n_f32(descale_h); + + float32x4_t _lstm_IFOG0 = vld1q_f32(bias_c_IFOG); + + float32x4_t _descale_xc_IFOG = vld1q_f32(descales_ptr); + + _lstm_IFOG0 = vmlaq_f32(_lstm_IFOG0, vcvtq_f32_s32(_lstm_IFOGx0), vmulq_f32(_descale_x, _descale_xc_IFOG)); + + float32x4_t _descale_hc_IFOG = vld1q_f32(descales_ptr + 4); + + _lstm_IFOG0 = vmlaq_f32(_lstm_IFOG0, vcvtq_f32_s32(_lstm_IFOGh0), vmulq_f32(_descale_h, _descale_hc_IFOG)); + + vst1q_f32(gates_data, _lstm_IFOG0); +#else int Ix = 0; int Fx = 0; int Ox = 0; @@ -224,17 +463,25 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d kptr += 4; } + const float descale_xc_I = descales_ptr[0]; + const float descale_xc_F = descales_ptr[1]; + const float descale_xc_O = descales_ptr[2]; + const float descale_xc_G = descales_ptr[3]; + const float descale_hc_I = descales_ptr[4]; + const float descale_hc_F = descales_ptr[5]; + const float descale_hc_O = descales_ptr[6]; + const float descale_hc_G = descales_ptr[7]; + float I = bias_c_IFOG[0] + Ix * (descale_x * descale_xc_I) + Ih * (descale_h * descale_hc_I); float F = bias_c_IFOG[1] + Fx * (descale_x * descale_xc_F) + Fh * (descale_h * descale_hc_F); float O = bias_c_IFOG[2] + Ox * (descale_x * descale_xc_O) + Oh * (descale_h * descale_hc_O); float G = bias_c_IFOG[3] + Gx * (descale_x * descale_xc_G) + Gh * (descale_h * descale_hc_G); - float* gates_data = gates.row(q); - gates_data[0] = I; gates_data[1] = F; gates_data[2] = O; gates_data[3] = G; +#endif // __ARM_NEON } // lstm unit