Skip to content

Commit

Permalink
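fix: use `_descale_xc` / `_descale_hc` in place of `_descale_xc_IFOG` / `_descale_hc_IFOG` in the NEON paths of `lstm_int8` and `lstm_bf16s_int8`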
Browse files (browse the repository at this point in the history)
  • Loading branch information
nihui committed Apr 23, 2024
1 parent 0d66049 commit 7546165
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions src/layer/arm/lstm_arm.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -458,8 +458,8 @@ static int lstm_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const M
const float* weight_hc_int8_descales_IFOG = weight_hc_int8_descales.row(q);

#if __ARM_NEON
float32x4_t _descale_xc_IFOG = vld1q_f32(weight_xc_int8_descales_IFOG);
float32x4_t _descale_hc_IFOG = vld1q_f32(weight_hc_int8_descales_IFOG);
float32x4_t _descale_xc = vld1q_f32(weight_xc_int8_descales_IFOG);
float32x4_t _descale_hc = vld1q_f32(weight_hc_int8_descales_IFOG);

float32x4_t _IFOG = vld1q_f32(bias_c_IFOG);
float32x4_t _sum1 = vdupq_n_f32(0.f);
Expand Down Expand Up @@ -495,10 +495,10 @@ static int lstm_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const M
float32x4_t _weight_xc_IFOG_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_IFOG01)));
float32x4_t _weight_xc_IFOG_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_IFOG23)));
float32x4_t _weight_xc_IFOG_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_IFOG23)));
_weight_xc_IFOG_0 = vmulq_f32(_weight_xc_IFOG_0, _descale_xc_IFOG);
_weight_xc_IFOG_1 = vmulq_f32(_weight_xc_IFOG_1, _descale_xc_IFOG);
_weight_xc_IFOG_2 = vmulq_f32(_weight_xc_IFOG_2, _descale_xc_IFOG);
_weight_xc_IFOG_3 = vmulq_f32(_weight_xc_IFOG_3, _descale_xc_IFOG);
_weight_xc_IFOG_0 = vmulq_f32(_weight_xc_IFOG_0, _descale_xc);
_weight_xc_IFOG_1 = vmulq_f32(_weight_xc_IFOG_1, _descale_xc);
_weight_xc_IFOG_2 = vmulq_f32(_weight_xc_IFOG_2, _descale_xc);
_weight_xc_IFOG_3 = vmulq_f32(_weight_xc_IFOG_3, _descale_xc);

#if __aarch64__
_IFOG = vfmaq_laneq_f32(_IFOG, _weight_xc_IFOG_0, _xi, 0);
Expand All @@ -522,7 +522,7 @@ static int lstm_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const M
#if __ARM_NEON
float32x4_t _xi = vdupq_n_f32(xi);
float32x4_t _weight_xc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_IFOG)))));
_weight_xc_IFOG = vmulq_f32(_weight_xc_IFOG, _descale_xc_IFOG);
_weight_xc_IFOG = vmulq_f32(_weight_xc_IFOG, _descale_xc);
_IFOG = vmlaq_f32(_IFOG, _weight_xc_IFOG, _xi);
#else
I += weight_xc_int8_IFOG[0] * descale_xc_I * xi;
Expand All @@ -547,10 +547,10 @@ static int lstm_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const M
float32x4_t _weight_hc_IFOG_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_IFOG01)));
float32x4_t _weight_hc_IFOG_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_IFOG23)));
float32x4_t _weight_hc_IFOG_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_IFOG23)));
_weight_hc_IFOG_0 = vmulq_f32(_weight_hc_IFOG_0, _descale_hc_IFOG);
_weight_hc_IFOG_1 = vmulq_f32(_weight_hc_IFOG_1, _descale_hc_IFOG);
_weight_hc_IFOG_2 = vmulq_f32(_weight_hc_IFOG_2, _descale_hc_IFOG);
_weight_hc_IFOG_3 = vmulq_f32(_weight_hc_IFOG_3, _descale_hc_IFOG);
_weight_hc_IFOG_0 = vmulq_f32(_weight_hc_IFOG_0, _descale_hc);
_weight_hc_IFOG_1 = vmulq_f32(_weight_hc_IFOG_1, _descale_hc);
_weight_hc_IFOG_2 = vmulq_f32(_weight_hc_IFOG_2, _descale_hc);
_weight_hc_IFOG_3 = vmulq_f32(_weight_hc_IFOG_3, _descale_hc);

#if __aarch64__
_IFOG = vfmaq_laneq_f32(_IFOG, _weight_hc_IFOG_0, _h_cont, 0);
Expand All @@ -574,7 +574,7 @@ static int lstm_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const M
#if __ARM_NEON
float32x4_t _h_cont = vdupq_n_f32(h_cont);
float32x4_t _weight_hc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_IFOG)))));
_weight_hc_IFOG = vmulq_f32(_weight_hc_IFOG, _descale_hc_IFOG);
_weight_hc_IFOG = vmulq_f32(_weight_hc_IFOG, _descale_hc);
_IFOG = vmlaq_f32(_IFOG, _weight_hc_IFOG, _h_cont);
#else
I += weight_hc_int8_IFOG[0] * descale_hc_I * h_cont;
Expand Down Expand Up @@ -1373,8 +1373,8 @@ static int lstm_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
const float* weight_hc_int8_descales_IFOG = weight_hc_int8_descales.row(q);

#if __ARM_NEON
float32x4_t _descale_xc_IFOG = vld1q_f32(weight_xc_int8_descales_IFOG);
float32x4_t _descale_hc_IFOG = vld1q_f32(weight_hc_int8_descales_IFOG);
float32x4_t _descale_xc = vld1q_f32(weight_xc_int8_descales_IFOG);
float32x4_t _descale_hc = vld1q_f32(weight_hc_int8_descales_IFOG);

float32x4_t _IFOG = bfloat2float(vld1_u16(bias_c_IFOG));
float32x4_t _sum1 = vdupq_n_f32(0.f);
Expand Down Expand Up @@ -1410,10 +1410,10 @@ static int lstm_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
float32x4_t _weight_xc_IFOG_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_IFOG01)));
float32x4_t _weight_xc_IFOG_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_IFOG23)));
float32x4_t _weight_xc_IFOG_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_IFOG23)));
_weight_xc_IFOG_0 = vmulq_f32(_weight_xc_IFOG_0, _descale_xc_IFOG);
_weight_xc_IFOG_1 = vmulq_f32(_weight_xc_IFOG_1, _descale_xc_IFOG);
_weight_xc_IFOG_2 = vmulq_f32(_weight_xc_IFOG_2, _descale_xc_IFOG);
_weight_xc_IFOG_3 = vmulq_f32(_weight_xc_IFOG_3, _descale_xc_IFOG);
_weight_xc_IFOG_0 = vmulq_f32(_weight_xc_IFOG_0, _descale_xc);
_weight_xc_IFOG_1 = vmulq_f32(_weight_xc_IFOG_1, _descale_xc);
_weight_xc_IFOG_2 = vmulq_f32(_weight_xc_IFOG_2, _descale_xc);
_weight_xc_IFOG_3 = vmulq_f32(_weight_xc_IFOG_3, _descale_xc);

#if __aarch64__
_IFOG = vfmaq_laneq_f32(_IFOG, _weight_xc_IFOG_0, _xi, 0);
Expand All @@ -1437,7 +1437,7 @@ static int lstm_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c

float32x4_t _xi = bfloat2float(vdup_n_u16(xi));
float32x4_t _weight_xc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_IFOG)))));
_weight_xc_IFOG = vmulq_f32(_weight_xc_IFOG, _descale_xc_IFOG);
_weight_xc_IFOG = vmulq_f32(_weight_xc_IFOG, _descale_xc);
_IFOG = vmlaq_f32(_IFOG, _weight_xc_IFOG, _xi);
#else
float xi = bfloat16_to_float32(x[i]);
Expand All @@ -1464,10 +1464,10 @@ static int lstm_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
float32x4_t _weight_hc_IFOG_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_IFOG01)));
float32x4_t _weight_hc_IFOG_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_IFOG23)));
float32x4_t _weight_hc_IFOG_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_IFOG23)));
_weight_hc_IFOG_0 = vmulq_f32(_weight_hc_IFOG_0, _descale_hc_IFOG);
_weight_hc_IFOG_1 = vmulq_f32(_weight_hc_IFOG_1, _descale_hc_IFOG);
_weight_hc_IFOG_2 = vmulq_f32(_weight_hc_IFOG_2, _descale_hc_IFOG);
_weight_hc_IFOG_3 = vmulq_f32(_weight_hc_IFOG_3, _descale_hc_IFOG);
_weight_hc_IFOG_0 = vmulq_f32(_weight_hc_IFOG_0, _descale_hc);
_weight_hc_IFOG_1 = vmulq_f32(_weight_hc_IFOG_1, _descale_hc);
_weight_hc_IFOG_2 = vmulq_f32(_weight_hc_IFOG_2, _descale_hc);
_weight_hc_IFOG_3 = vmulq_f32(_weight_hc_IFOG_3, _descale_hc);

#if __aarch64__
_IFOG = vfmaq_laneq_f32(_IFOG, _weight_hc_IFOG_0, _h_cont, 0);
Expand All @@ -1491,7 +1491,7 @@ static int lstm_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
#if __ARM_NEON
float32x4_t _h_cont = vdupq_n_f32(h_cont);
float32x4_t _weight_hc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_IFOG)))));
_weight_hc_IFOG = vmulq_f32(_weight_hc_IFOG, _descale_hc_IFOG);
_weight_hc_IFOG = vmulq_f32(_weight_hc_IFOG, _descale_hc);
_IFOG = vmlaq_f32(_IFOG, _weight_hc_IFOG, _h_cont);
#else
I += weight_hc_int8_IFOG[0] * descale_hc_I * h_cont;
Expand Down

0 comments on commit 7546165

Please sign in to comment.